add audio_vae, audio_vocoder, text_encoder, connector and upsampler for ltx2

This commit is contained in:
mi804
2026-01-28 16:09:22 +08:00
parent 00da4b6c4f
commit 8d303b47e9
8 changed files with 2207 additions and 24 deletions

View File

@@ -0,0 +1,31 @@
def LTX2TextEncoderStateDictConverter(state_dict):
state_dict_ = {}
for key in state_dict:
if key.startswith("language_model.model."):
new_key = key.replace("language_model.model.", "model.language_model.")
elif key.startswith("vision_tower."):
new_key = key.replace("vision_tower.", "model.vision_tower.")
elif key.startswith("multi_modal_projector."):
new_key = key.replace("multi_modal_projector.", "model.multi_modal_projector.")
elif key.startswith("language_model.lm_head."):
new_key = key.replace("language_model.lm_head.", "lm_head.")
else:
continue
state_dict_[new_key] = state_dict[key]
state_dict_["lm_head.weight"] = state_dict_.get("model.language_model.embed_tokens.weight")
return state_dict_
def LTX2TextEncoderPostModulesStateDictConverter(state_dict):
state_dict_ = {}
for key in state_dict:
if key.startswith("text_embedding_projection."):
new_key = key.replace("text_embedding_projection.", "feature_extractor_linear.")
elif key.startswith("model.diffusion_model.video_embeddings_connector."):
new_key = key.replace("model.diffusion_model.video_embeddings_connector.", "embeddings_connector.")
elif key.startswith("model.diffusion_model.audio_embeddings_connector."):
new_key = key.replace("model.diffusion_model.audio_embeddings_connector.", "audio_embeddings_connector.")
else:
continue
state_dict_[new_key] = state_dict[key]
return state_dict_