From d40efe897f8a9baf91ddb6a73def321e05b35eca Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Fri, 6 Mar 2026 18:08:42 +0800 Subject: [PATCH] ltx2.3 train --- README.md | 2 +- README_zh.md | 2 +- diffsynth/configs/model_configs.py | 53 +++++++++ .../state_dict_converters/ltx2_audio_vae.py | 1 - docs/en/Model_Details/LTX-2.md | 2 +- docs/zh/Model_Details/LTX-2.md | 2 +- .../model_inference/LTX-2.3-T2AV-OneStage.py | 17 ++- .../model_training/full/LTX-2-T2AV-splited.sh | 4 +- .../full/LTX-2.3-T2AV-splited.sh | 35 ++++++ .../lora/LTX-2.3-T2AV-splited.sh | 39 +++++++ .../scripts/split_model_statedicts_ltx2.3.py | 102 ++++++++++++++++++ .../validate_full/LTX-2.3-T2AV.py | 47 ++++++++ .../validate_lora/LTX-2.3-T2AV.py | 48 +++++++++ 13 files changed, 346 insertions(+), 8 deletions(-) create mode 100644 examples/ltx2/model_training/full/LTX-2.3-T2AV-splited.sh create mode 100644 examples/ltx2/model_training/lora/LTX-2.3-T2AV-splited.sh create mode 100644 examples/ltx2/model_training/scripts/split_model_statedicts_ltx2.3.py create mode 100644 examples/ltx2/model_training/validate_full/LTX-2.3-T2AV.py create mode 100644 examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV.py diff --git a/README.md b/README.md index 5a89111..2001bc3 100644 --- a/README.md +++ b/README.md @@ -705,7 +705,7 @@ Example code for LTX-2 is available at: [/examples/ltx2/](/examples/ltx2/) | Model ID | Extra Args | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation | |-|-|-|-|-|-|-|-| -|[Lightricks/LTX-2.3: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2.3-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-OneStage.py)|-|-|-|-| +|[Lightricks/LTX-2.3: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](examples/ltx2/model_inference/LTX-2.3-I2AV-OneStage.py)|[code](examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-OneStage.py)|[code](examples/ltx2/model_training/full/LTX-2.3-T2AV-splited.sh)|[code](examples/ltx2/model_training/validate_full/LTX-2.3-T2AV.py)|[code](examples/ltx2/model_training/lora/LTX-2.3-T2AV-splited.sh)|[code](examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV.py)| |[Lightricks/LTX-2.3: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2.3-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-TwoStage.py)|-|-|-|-| |[Lightricks/LTX-2.3: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2.3-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-DistilledPipeline.py)|-|-|-|-| |[Lightricks/LTX-2.3: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-OneStage.py)|-|-|-|-| diff --git a/README_zh.md b/README_zh.md index fec510a..e862738 100644 --- a/README_zh.md +++ b/README_zh.md @@ -705,7 +705,7 @@ LTX-2 的示例代码位于:[/examples/ltx2/](/examples/ltx2/) |模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| |-|-|-|-|-|-|-|-| -|[Lightricks/LTX-2.3: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2.3-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-OneStage.py)|-|-|-|-| +|[Lightricks/LTX-2.3: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](examples/ltx2/model_inference/LTX-2.3-I2AV-OneStage.py)|[code](examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-OneStage.py)|[code](examples/ltx2/model_training/full/LTX-2.3-T2AV-splited.sh)|[code](examples/ltx2/model_training/validate_full/LTX-2.3-T2AV.py)|[code](examples/ltx2/model_training/lora/LTX-2.3-T2AV-splited.sh)|[code](examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV.py)| |[Lightricks/LTX-2.3: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2.3-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-TwoStage.py)|-|-|-|-| |[Lightricks/LTX-2.3: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2.3-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-DistilledPipeline.py)|-|-|-|-| |[Lightricks/LTX-2.3: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-OneStage.py)|-|-|-|-| diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index 3c00dca..93d1182 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -778,6 +778,59 @@ ltx2_series = [ "model_class": "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler", "extra_kwargs": {"rational_resampler": False}, }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="transformer.safetensors") + "model_hash": "1c55afad76ed33c112a2978550b524d1", + "model_name": "ltx2_dit", + "model_class": "diffsynth.models.ltx2_dit.LTXModel", + "extra_kwargs": {"apply_gated_attention": True, "cross_attention_adaln": True, "caption_channels": None}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="video_vae_encoder.safetensors") + "model_hash": "eecdc07c2ec30863b8a2b8b2134036cf", + "model_name": "ltx2_video_vae_encoder", + "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder", + "extra_kwargs": {"encoder_version": "ltx-2.3"}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="video_vae_decoder.safetensors") + "model_hash": "deda2f542e17ee25bc8c38fd605316ea", + "model_name": "ltx2_video_vae_decoder", + "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder", + "extra_kwargs": {"decoder_version": "ltx-2.3"}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vocoder.safetensors") + "model_hash": "7d7823dde8f1ea0b50fb07ac329dd4cb", + "model_name": "ltx2_audio_vae_decoder", + "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vae_encoder.safetensors") + "model_hash": "29338f3b95e7e312a3460a482e4f4554", + "model_name": "ltx2_audio_vae_encoder", + "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vocoder.safetensors") + "model_hash": "cd436c99e69ec5c80f050f0944f02a15", + "model_name": "ltx2_audio_vocoder", + "model_class": "diffsynth.models.ltx2_audio_vae.LTX2VocoderWithBWE", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors") + "model_hash": "05da2aab1c4b061f72c426311c165a43", + "model_name": "ltx2_text_encoder_post_modules", + "model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules", + "extra_kwargs": {"separated_audio_video": True, "embedding_dim_gemma": 3840, "num_layers_gemma": 49, "video_attention_heads": 32, "video_attention_head_dim": 128, "audio_attention_heads": 32, "audio_attention_head_dim": 64, "num_connector_layers": 8, "apply_gated_attention": True}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter", + }, ] anima_series = [ { diff --git a/diffsynth/utils/state_dict_converters/ltx2_audio_vae.py b/diffsynth/utils/state_dict_converters/ltx2_audio_vae.py index dc2622c..0218530 100644 --- a/diffsynth/utils/state_dict_converters/ltx2_audio_vae.py +++ b/diffsynth/utils/state_dict_converters/ltx2_audio_vae.py @@ -27,7 +27,6 @@ def LTX2VocoderStateDictConverter(state_dict): state_dict_ = {} for name in state_dict: if name.startswith("vocoder."): - # new_name = name.replace("vocoder.", "") new_name = name[len("vocoder."):] state_dict_[new_name] = state_dict[name] return state_dict_ diff --git a/docs/en/Model_Details/LTX-2.md b/docs/en/Model_Details/LTX-2.md index 007d0b7..889e3fd 100644 --- a/docs/en/Model_Details/LTX-2.md +++ b/docs/en/Model_Details/LTX-2.md @@ -111,7 +111,7 @@ write_video_audio_ltx2( ## Model Overview |Model ID|Additional Parameters|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training| |-|-|-|-|-|-|-|-| -|[Lightricks/LTX-2.3: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-I2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-OneStage.py)|-|-|-|-| +|[Lightricks/LTX-2.3: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-I2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/full/LTX-2.3-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_full/LTX-2.3-T2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2.3-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV.py)| |[Lightricks/LTX-2.3: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-I2AV-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-TwoStage.py)|-|-|-|-| |[Lightricks/LTX-2.3: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-I2AV-DistilledPipeline.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-DistilledPipeline.py)|-|-|-|-| |[Lightricks/LTX-2.3: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-OneStage.py)|-|-|-|-| diff --git a/docs/zh/Model_Details/LTX-2.md b/docs/zh/Model_Details/LTX-2.md index 66755db..18acefd 100644 --- a/docs/zh/Model_Details/LTX-2.md +++ b/docs/zh/Model_Details/LTX-2.md @@ -111,7 +111,7 @@ write_video_audio_ltx2( ## 模型总览 |模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| |-|-|-|-|-|-|-|-| -|[Lightricks/LTX-2.3: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-I2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-OneStage.py)|-|-|-|-| +|[Lightricks/LTX-2.3: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-I2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/full/LTX-2.3-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_full/LTX-2.3-T2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2.3-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV.py)| |[Lightricks/LTX-2.3: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-I2AV-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-TwoStage.py)|-|-|-|-| |[Lightricks/LTX-2.3: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-I2AV-DistilledPipeline.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-DistilledPipeline.py)|-|-|-|-| |[Lightricks/LTX-2.3: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-OneStage.py)|-|-|-|-| diff --git a/examples/ltx2/model_inference/LTX-2.3-T2AV-OneStage.py b/examples/ltx2/model_inference/LTX-2.3-T2AV-OneStage.py index f8a3d58..4311387 100644 --- a/examples/ltx2/model_inference/LTX-2.3-T2AV-OneStage.py +++ b/examples/ltx2/model_inference/LTX-2.3-T2AV-OneStage.py @@ -12,15 +12,30 @@ vram_config = { "computation_dtype": torch.bfloat16, "computation_device": "cuda", } +# use the repackaged modelconfig from "DiffSynth-Studio/LTX-2.3-Repackage" to avoid redundant model loading pipe = LTX2AudioVideoPipeline.from_pretrained( torch_dtype=torch.bfloat16, device="cuda", model_configs=[ ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), - ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), ) +# use the following modelconfig if you want to initialize model from offical checkpoints from "Lightricks/LTX-2.3" +# pipe = LTX2AudioVideoPipeline.from_pretrained( +# torch_dtype=torch.bfloat16, +# device="cuda", +# model_configs=[ +# ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), +# ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors", **vram_config), +# ], +# tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), +# ) prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”" negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." height, width, num_frames = 512, 768, 121 diff --git a/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh b/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh index 04f3b1c..2d37718 100644 --- a/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh +++ b/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh @@ -9,7 +9,7 @@ accelerate launch examples/ltx2/model_training/train.py \ --num_frames 121 \ --dataset_repeat 1 \ --model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors" \ - --learning_rate 1e-4 \ + --learning_rate 1e-5 \ --num_epochs 5 \ --remove_prefix_in_ckpt "pipe.dit." \ --output_path "./models/train/LTX2-T2AV-full-splited-cache" \ @@ -26,7 +26,7 @@ accelerate launch --config_file examples/qwen_image/model_training/full/accelera --num_frames 121 \ --dataset_repeat 100 \ --model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:transformer.safetensors" \ - --learning_rate 1e-4 \ + --learning_rate 1e-5 \ --num_epochs 5 \ --remove_prefix_in_ckpt "pipe.dit." \ --output_path "./models/train/LTX2-T2AV-full" \ diff --git a/examples/ltx2/model_training/full/LTX-2.3-T2AV-splited.sh b/examples/ltx2/model_training/full/LTX-2.3-T2AV-splited.sh new file mode 100644 index 0000000..4d02da2 --- /dev/null +++ b/examples/ltx2/model_training/full/LTX-2.3-T2AV-splited.sh @@ -0,0 +1,35 @@ +# Splited Training +accelerate launch examples/ltx2/model_training/train.py \ + --dataset_base_path data/example_video_dataset/ltx2 \ + --dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \ + --data_file_keys "video,input_audio" \ + --extra_inputs "input_audio" \ + --height 512 \ + --width 768 \ + --num_frames 121 \ + --dataset_repeat 1 \ + --model_id_with_origin_paths "DiffSynth-Studio/LTX-2.3-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2.3-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2.3-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/LTX2.3-T2AV-full-splited-cache" \ + --trainable_models "dit" \ + --use_gradient_checkpointing \ + --task "sft:data_process" + +accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/ltx2/model_training/train.py \ + --dataset_base_path ./models/train/LTX2.3-T2AV-full-splited-cache \ + --data_file_keys "video,input_audio" \ + --extra_inputs "input_audio" \ + --height 512 \ + --width 768 \ + --num_frames 121 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "DiffSynth-Studio/LTX-2.3-Repackage:transformer.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/LTX2.3-T2AV-full" \ + --trainable_models "dit" \ + --use_gradient_checkpointing \ + --task "sft:train" diff --git a/examples/ltx2/model_training/lora/LTX-2.3-T2AV-splited.sh b/examples/ltx2/model_training/lora/LTX-2.3-T2AV-splited.sh new file mode 100644 index 0000000..038d660 --- /dev/null +++ b/examples/ltx2/model_training/lora/LTX-2.3-T2AV-splited.sh @@ -0,0 +1,39 @@ +# Splited Training +accelerate launch examples/ltx2/model_training/train.py \ + --dataset_base_path data/example_video_dataset/ltx2 \ + --dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \ + --data_file_keys "video,input_audio" \ + --extra_inputs "input_audio" \ + --height 512 \ + --width 768 \ + --num_frames 121 \ + --dataset_repeat 1 \ + --model_id_with_origin_paths "DiffSynth-Studio/LTX-2.3-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2.3-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2.3-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/LTX2.3-T2AV_lora-splited-cache" \ + --lora_base_model "dit" \ + --lora_target_modules "to_k,to_q,to_v,to_out.0" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --task "sft:data_process" + +accelerate launch examples/ltx2/model_training/train.py \ + --dataset_base_path ./models/train/LTX2.3-T2AV_lora-splited-cache \ + --data_file_keys "video,input_audio" \ + --extra_inputs "input_audio" \ + --height 512 \ + --width 768 \ + --num_frames 121 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "DiffSynth-Studio/LTX-2.3-Repackage:transformer.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/LTX2.3-T2AV_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_k,to_q,to_v,to_out.0" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --task "sft:train" diff --git a/examples/ltx2/model_training/scripts/split_model_statedicts_ltx2.3.py b/examples/ltx2/model_training/scripts/split_model_statedicts_ltx2.3.py new file mode 100644 index 0000000..1dbc1bb --- /dev/null +++ b/examples/ltx2/model_training/scripts/split_model_statedicts_ltx2.3.py @@ -0,0 +1,102 @@ +from safetensors.torch import save_file +from diffsynth import hash_state_dict_keys +from diffsynth.core import load_state_dict +from diffsynth.models.model_loader import ModelPool +import os + +model_pool = ModelPool() +state_dict = load_state_dict("models/Lightricks/LTX-2.3/ltx-2.3-22b-dev.safetensors") +os.makedirs("models/DiffSynth-Studio/LTX-2.3-Repackage", exist_ok=True) + +dit_state_dict = {} +for name in state_dict: + if name.startswith("model.diffusion_model."): + new_name = name.replace("model.diffusion_model.", "") + if new_name.startswith("audio_embeddings_connector.") or new_name.startswith("video_embeddings_connector."): + continue + dit_state_dict[name] = state_dict[name] + +print(f"dit_state_dict keys hash: {hash_state_dict_keys(dit_state_dict)}") +save_file(dit_state_dict, "models/DiffSynth-Studio/LTX-2.3-Repackage/transformer.safetensors") +model_pool.auto_load_model("models/DiffSynth-Studio/LTX-2.3-Repackage/transformer.safetensors") + + +video_vae_encoder_state_dict = {} +for name in state_dict: + if name.startswith("vae.encoder."): + video_vae_encoder_state_dict[name] = state_dict[name] + elif name.startswith("vae.per_channel_statistics."): + video_vae_encoder_state_dict[name] = state_dict[name] + +save_file(video_vae_encoder_state_dict, "models/DiffSynth-Studio/LTX-2.3-Repackage/video_vae_encoder.safetensors") +print(f"video_vae_encoder keys hash: {hash_state_dict_keys(video_vae_encoder_state_dict)}") +model_pool.auto_load_model("models/DiffSynth-Studio/LTX-2.3-Repackage/video_vae_encoder.safetensors") + + +video_vae_decoder_state_dict = {} +for name in state_dict: + if name.startswith("vae.decoder."): + video_vae_decoder_state_dict[name] = state_dict[name] + elif name.startswith("vae.per_channel_statistics."): + video_vae_decoder_state_dict[name] = state_dict[name] +save_file(video_vae_decoder_state_dict, "models/DiffSynth-Studio/LTX-2.3-Repackage/video_vae_decoder.safetensors") +print(f"video_vae_decoder keys hash: {hash_state_dict_keys(video_vae_decoder_state_dict)}") +model_pool.auto_load_model("models/DiffSynth-Studio/LTX-2.3-Repackage/video_vae_decoder.safetensors") + + +audio_vae_decoder_state_dict = {} +for name in state_dict: + if name.startswith("audio_vae.decoder."): + audio_vae_decoder_state_dict[name] = state_dict[name] + elif name.startswith("audio_vae.per_channel_statistics."): + audio_vae_decoder_state_dict[name] = state_dict[name] +save_file(audio_vae_decoder_state_dict, "models/DiffSynth-Studio/LTX-2.3-Repackage/audio_vae_decoder.safetensors") +print(f"audio_vae_decoder keys hash: {hash_state_dict_keys(audio_vae_decoder_state_dict)}") +model_pool.auto_load_model("models/DiffSynth-Studio/LTX-2.3-Repackage/audio_vae_decoder.safetensors") + + +audio_vae_encoder_state_dict = {} +for name in state_dict: + if name.startswith("audio_vae.encoder."): + audio_vae_encoder_state_dict[name] = state_dict[name] + elif name.startswith("audio_vae.per_channel_statistics."): + audio_vae_encoder_state_dict[name] = state_dict[name] +save_file(audio_vae_encoder_state_dict, "models/DiffSynth-Studio/LTX-2.3-Repackage/audio_vae_encoder.safetensors") +print(f"audio_vae_encoder keys hash: {hash_state_dict_keys(audio_vae_encoder_state_dict)}") +model_pool.auto_load_model("models/DiffSynth-Studio/LTX-2.3-Repackage/audio_vae_encoder.safetensors") + + +audio_vocoder_state_dict = {} +for name in state_dict: + if name.startswith("vocoder."): + audio_vocoder_state_dict[name] = state_dict[name] +save_file(audio_vocoder_state_dict, "models/DiffSynth-Studio/LTX-2.3-Repackage/audio_vocoder.safetensors") +print(f"audio_vocoder keys hash: {hash_state_dict_keys(audio_vocoder_state_dict)}") +model_pool.auto_load_model("models/DiffSynth-Studio/LTX-2.3-Repackage/audio_vocoder.safetensors") + + +text_encoder_post_modules_state_dict = {} +for name in state_dict: + if name.startswith("text_embedding_projection."): + text_encoder_post_modules_state_dict[name] = state_dict[name] + elif name.startswith("model.diffusion_model.video_embeddings_connector."): + text_encoder_post_modules_state_dict[name] = state_dict[name] + elif name.startswith("model.diffusion_model.audio_embeddings_connector."): + text_encoder_post_modules_state_dict[name] = state_dict[name] +save_file(text_encoder_post_modules_state_dict, "models/DiffSynth-Studio/LTX-2.3-Repackage/text_encoder_post_modules.safetensors") +print(f"text_encoder_post_modules keys hash: {hash_state_dict_keys(text_encoder_post_modules_state_dict)}") +model_pool.auto_load_model("models/DiffSynth-Studio/LTX-2.3-Repackage/text_encoder_post_modules.safetensors") + + +state_dict = load_state_dict("models/Lightricks/LTX-2.3/ltx-2.3-22b-distilled.safetensors") +dit_state_dict = {} +for name in state_dict: + if name.startswith("model.diffusion_model."): + new_name = name.replace("model.diffusion_model.", "") + if new_name.startswith("audio_embeddings_connector.") or new_name.startswith("video_embeddings_connector."): + continue + dit_state_dict[name] = state_dict[name] + +print(f"dit_state_dict keys hash: {hash_state_dict_keys(dit_state_dict)}") +save_file(dit_state_dict, "models/DiffSynth-Studio/LTX-2.3-Repackage/transformer_distilled.safetensors") +model_pool.auto_load_model("models/DiffSynth-Studio/LTX-2.3-Repackage/transformer_distilled.safetensors") diff --git a/examples/ltx2/model_training/validate_full/LTX-2.3-T2AV.py b/examples/ltx2/model_training/validate_full/LTX-2.3-T2AV.py new file mode 100644 index 0000000..d839254 --- /dev/null +++ b/examples/ltx2/model_training/validate_full/LTX-2.3-T2AV.py @@ -0,0 +1,47 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(path="./models/train/LTX2.3-T2AV-full/epoch-4.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), +) +prompt = "A beautiful sunset over the ocean." +negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +height, width, num_frames = 512, 768, 121 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + cfg_scale=4.0 +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_onestage.mp4', + fps=24, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV.py b/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV.py new file mode 100644 index 0000000..03d974b --- /dev/null +++ b/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV.py @@ -0,0 +1,48 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), +) +pipe.load_lora(pipe.dit, "models/train/LTX2.3-T2AV_lora/epoch-4.safetensors") +prompt = "A beautiful sunset over the ocean." +negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +height, width, num_frames = 512, 768, 121 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + cfg_scale=4.0 +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_onestage.mp4', + fps=24, + audio_sample_rate=pipe.audio_vocoder.output_sampling_rate, +)