From 6383ec358cd26ae5f70fd1d11adc39074273aadb Mon Sep 17 00:00:00 2001 From: Mr-Neutr0n <64578610+Mr-Neutr0n@users.noreply.github.com> Date: Sat, 7 Feb 2026 05:23:11 +0530 Subject: [PATCH 01/23] Fix AttributeError when pipe.dit is None When using split training with 'sft:data_process' task, the DiT model is not loaded but the attribute 'dit' exists with value None. The existing hasattr check returns True but then accessing siglip_embedder fails. Add an explicit None check before accessing pipe.dit.siglip_embedder. Fixes #1246 --- diffsynth/pipelines/z_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffsynth/pipelines/z_image.py b/diffsynth/pipelines/z_image.py index 2c5b687..32089b1 100644 --- a/diffsynth/pipelines/z_image.py +++ b/diffsynth/pipelines/z_image.py @@ -296,7 +296,7 @@ class ZImageUnit_PromptEmbedder(PipelineUnit): def process(self, pipe: ZImagePipeline, prompt, edit_image): pipe.load_models_to_device(self.onload_model_names) - if hasattr(pipe, "dit") and pipe.dit.siglip_embedder is not None: + if hasattr(pipe, "dit") and pipe.dit is not None and pipe.dit.siglip_embedder is not None: # Z-Image-Turbo and Z-Image-Omni-Base use different prompt encoding methods. # We determine which encoding method to use based on the model architecture. 
# If you are using two-stage split training, From 0e6976a0ae1e21062f697e294042d4edab8c581f Mon Sep 17 00:00:00 2001 From: Mr-Neutr0n <64578610+Mr-Neutr0n@users.noreply.github.com> Date: Wed, 11 Feb 2026 19:51:25 +0530 Subject: [PATCH 02/23] fix: prevent division by zero in trajectory imitation loss at last step --- diffsynth/diffusion/loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffsynth/diffusion/loss.py b/diffsynth/diffusion/loss.py index 14fdfd3..4da195f 100644 --- a/diffsynth/diffusion/loss.py +++ b/diffsynth/diffusion/loss.py @@ -91,7 +91,7 @@ class TrajectoryImitationLoss(torch.nn.Module): progress_id_teacher = torch.argmin((timesteps_teacher - pipe.scheduler.timesteps[progress_id + 1]).abs()) latents_ = trajectory_teacher[progress_id_teacher] - target = (latents_ - inputs_shared["latents"]) / (sigma_ - sigma) + target = (latents_ - inputs_shared["latents"]) / (sigma_ - sigma).clamp(min=1e-6) loss = loss + torch.nn.functional.mse_loss(noise_pred.float(), target.float()) * pipe.scheduler.training_weight(timestep) return loss From b68663426ff59a245a56e7cbf2b6b6bf6e7f879d Mon Sep 17 00:00:00 2001 From: Mr-Neutr0n <64578610+Mr-Neutr0n@users.noreply.github.com> Date: Wed, 11 Feb 2026 21:04:55 +0530 Subject: [PATCH 03/23] fix: preserve sign of denominator in clamp to avoid inverting gradient direction The previous .clamp(min=1e-6) on (sigma_ - sigma) flips the sign when the denominator is negative (which is the typical case since sigmas decrease monotonically). This would invert the target and cause training divergence. Use torch.sign(denom) * torch.clamp(denom.abs(), min=1e-6) instead, which prevents division by zero while preserving the correct sign. 
--- diffsynth/diffusion/loss.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/diffsynth/diffusion/loss.py b/diffsynth/diffusion/loss.py index 4da195f..065e589 100644 --- a/diffsynth/diffusion/loss.py +++ b/diffsynth/diffusion/loss.py @@ -91,7 +91,9 @@ class TrajectoryImitationLoss(torch.nn.Module): progress_id_teacher = torch.argmin((timesteps_teacher - pipe.scheduler.timesteps[progress_id + 1]).abs()) latents_ = trajectory_teacher[progress_id_teacher] - target = (latents_ - inputs_shared["latents"]) / (sigma_ - sigma).clamp(min=1e-6) + denom = sigma_ - sigma + denom = torch.sign(denom) * torch.clamp(denom.abs(), min=1e-6) + target = (latents_ - inputs_shared["latents"]) / denom loss = loss + torch.nn.functional.mse_loss(noise_pred.float(), target.float()) * pipe.scheduler.training_weight(timestep) return loss From 96fb0f3afea01ffb773a9d2faed26178ffd18478 Mon Sep 17 00:00:00 2001 From: Mr_Dwj Date: Thu, 12 Feb 2026 23:51:56 +0800 Subject: [PATCH 04/23] fix: unpack Resample38 output --- diffsynth/models/wan_video_vae.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/diffsynth/models/wan_video_vae.py b/diffsynth/models/wan_video_vae.py index 3c2181a..b77f75c 100644 --- a/diffsynth/models/wan_video_vae.py +++ b/diffsynth/models/wan_video_vae.py @@ -469,7 +469,7 @@ class Down_ResidualBlock(nn.Module): def forward(self, x, feat_cache=None, feat_idx=[0]): x_copy = x.clone() for module in self.downsamples: - x = module(x, feat_cache, feat_idx) + x, feat_cache, feat_idx = module(x, feat_cache, feat_idx) return x + self.avg_shortcut(x_copy), feat_cache, feat_idx @@ -506,7 +506,7 @@ class Up_ResidualBlock(nn.Module): def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): x_main = x.clone() for module in self.upsamples: - x_main = module(x_main, feat_cache, feat_idx) + x_main, feat_cache, feat_idx = module(x_main, feat_cache, feat_idx) if self.avg_shortcut is not None: x_shortcut = self.avg_shortcut(x, 
first_chunk) return x_main + x_shortcut From bd3c5822a1f9fda25163796ba74c139113b963a4 Mon Sep 17 00:00:00 2001 From: Mr_Dwj Date: Fri, 13 Feb 2026 01:13:08 +0800 Subject: [PATCH 05/23] fix: WanVAE2.2 decode error --- diffsynth/models/wan_video_vae.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/diffsynth/models/wan_video_vae.py b/diffsynth/models/wan_video_vae.py index b77f75c..19ab6bd 100644 --- a/diffsynth/models/wan_video_vae.py +++ b/diffsynth/models/wan_video_vae.py @@ -509,7 +509,7 @@ class Up_ResidualBlock(nn.Module): x_main, feat_cache, feat_idx = module(x_main, feat_cache, feat_idx) if self.avg_shortcut is not None: x_shortcut = self.avg_shortcut(x, first_chunk) - return x_main + x_shortcut + return x_main + x_shortcut, feat_cache, feat_idx else: return x_main, feat_cache, feat_idx @@ -1336,6 +1336,7 @@ class VideoVAE38_(VideoVAE_): x = self.conv2(z) for i in range(iter_): self._conv_idx = [0] + # breakpoint() if i == 0: out, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :], feat_cache=self._feat_map, From fc11fd42974bc5bda57160a687aab41553bd2d85 Mon Sep 17 00:00:00 2001 From: Mr_Dwj Date: Fri, 13 Feb 2026 09:38:14 +0800 Subject: [PATCH 06/23] chore: remove invalid comment code --- diffsynth/models/wan_video_vae.py | 1 - 1 file changed, 1 deletion(-) diff --git a/diffsynth/models/wan_video_vae.py b/diffsynth/models/wan_video_vae.py index 19ab6bd..3d5db68 100644 --- a/diffsynth/models/wan_video_vae.py +++ b/diffsynth/models/wan_video_vae.py @@ -1336,7 +1336,6 @@ class VideoVAE38_(VideoVAE_): x = self.conv2(z) for i in range(iter_): self._conv_idx = [0] - # breakpoint() if i == 0: out, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :], feat_cache=self._feat_map, From 71cea4371c2bc68fdc7d346eed0ce2af212a6a26 Mon Sep 17 00:00:00 2001 From: feng0w0 Date: Fri, 13 Feb 2026 09:58:27 +0800 Subject: [PATCH 07/23] [doc][NPU]Documentation on modifications, NPU environment installation, and additional 
parameter --- docs/en/Pipeline_Usage/GPU_support.md | 3 ++- docs/en/Pipeline_Usage/Setup.md | 4 ++-- docs/zh/Pipeline_Usage/GPU_support.md | 3 ++- docs/zh/Pipeline_Usage/Setup.md | 4 ++-- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/docs/en/Pipeline_Usage/GPU_support.md b/docs/en/Pipeline_Usage/GPU_support.md index aba5706..d1e77ef 100644 --- a/docs/en/Pipeline_Usage/GPU_support.md +++ b/docs/en/Pipeline_Usage/GPU_support.md @@ -81,4 +81,5 @@ Set 0 or not set: indicates not enabling the binding function #### Parameters for specific models | Model | Parameter | Note | |----------------|---------------------------|-------------------| -| Wan 14B series | --initialize_model_on_cpu | The 14B model needs to be initialized on the CPU | \ No newline at end of file +| Wan 14B series | --initialize_model_on_cpu | The 14B model needs to be initialized on the CPU | +| Z-Image series | --enable_npu_patch | Using NPU fusion operator to replace the corresponding operator in Z-image model to improve the performance of the model on NPU | \ No newline at end of file diff --git a/docs/en/Pipeline_Usage/Setup.md b/docs/en/Pipeline_Usage/Setup.md index 31fb771..dc06364 100644 --- a/docs/en/Pipeline_Usage/Setup.md +++ b/docs/en/Pipeline_Usage/Setup.md @@ -37,9 +37,9 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/rocm6 git clone https://github.com/modelscope/DiffSynth-Studio.git cd DiffSynth-Studio # aarch64/ARM - pip install -e .[npu_aarch64] --extra-index-url "https://download.pytorch.org/whl/cpu" + pip install -e .[npu_aarch64] # x86 - pip install -e .[npu] + pip install -e .[npu] --extra-index-url "https://download.pytorch.org/whl/cpu" When using Ascend NPU, please replace `"cuda"` with `"npu"` in your Python code. For details, see [NPU Support](/docs/en/Pipeline_Usage/GPU_support.md#ascend-npu). 
diff --git a/docs/zh/Pipeline_Usage/GPU_support.md b/docs/zh/Pipeline_Usage/GPU_support.md index 8124147..7a66923 100644 --- a/docs/zh/Pipeline_Usage/GPU_support.md +++ b/docs/zh/Pipeline_Usage/GPU_support.md @@ -81,4 +81,5 @@ export CPU_AFFINITY_CONF=1 #### 特定模型需要开启的参数 | 模型 | 参数 | 备注 | |-----------|------|-------------------| -| Wan 14B系列 | --initialize_model_on_cpu | 14B模型需要在cpu上进行初始化 | \ No newline at end of file +| Wan 14B系列 | --initialize_model_on_cpu | 14B模型需要在cpu上进行初始化 | +| Z-Image 系列 | --enable_npu_patch | 使用NPU融合算子来替换Z-image模型中的对应算子以提升模型在NPU上的性能 | \ No newline at end of file diff --git a/docs/zh/Pipeline_Usage/Setup.md b/docs/zh/Pipeline_Usage/Setup.md index e4a022c..9823593 100644 --- a/docs/zh/Pipeline_Usage/Setup.md +++ b/docs/zh/Pipeline_Usage/Setup.md @@ -37,9 +37,9 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/rocm6 git clone https://github.com/modelscope/DiffSynth-Studio.git cd DiffSynth-Studio # aarch64/ARM - pip install -e .[npu_aarch64] --extra-index-url "https://download.pytorch.org/whl/cpu" + pip install -e .[npu_aarch64] # x86 - pip install -e .[npu] + pip install -e .[npu] --extra-index-url "https://download.pytorch.org/whl/cpu" 使用 Ascend NPU 时,请将 Python 代码中的 `"cuda"` 改为 `"npu"`,详见[NPU 支持](/docs/zh/Pipeline_Usage/GPU_support.md#ascend-npu)。 From 586ac9d8a69353875c8d590427590308dba958d6 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Wed, 25 Feb 2026 17:19:57 +0800 Subject: [PATCH 08/23] support ltx-2 training --- diffsynth/configs/model_configs.py | 56 +++++- diffsynth/diffusion/loss.py | 30 ++++ diffsynth/models/ltx2_audio_vae.py | 57 ++++++ diffsynth/models/ltx2_dit.py | 2 +- diffsynth/pipelines/ltx2_audio_video.py | 65 +++++-- .../LTX-2-I2AV-DistilledPipeline.py | 7 +- .../model_inference/LTX-2-I2AV-OneStage.py | 7 +- .../model_inference/LTX-2-I2AV-TwoStage.py | 7 +- .../LTX-2-T2AV-Camera-Control-Dolly-In.py | 6 +- .../LTX-2-T2AV-Camera-Control-Dolly-Left.py | 6 +- 
.../LTX-2-T2AV-Camera-Control-Dolly-Out.py | 6 +- .../LTX-2-T2AV-Camera-Control-Dolly-Right.py | 6 +- .../LTX-2-T2AV-Camera-Control-Jib-Down.py | 6 +- .../LTX-2-T2AV-Camera-Control-Jib-Up.py | 6 +- .../LTX-2-T2AV-Camera-Control-Static.py | 6 +- .../LTX-2-T2AV-DistilledPipeline.py | 6 +- .../model_inference/LTX-2-T2AV-OneStage.py | 15 +- .../model_inference/LTX-2-T2AV-TwoStage.py | 6 +- .../LTX-2-I2AV-DistilledPipeline.py | 7 +- .../LTX-2-I2AV-OneStage.py | 7 +- .../LTX-2-I2AV-TwoStage.py | 7 +- .../LTX-2-T2AV-Camera-Control-Dolly-In.py | 6 +- .../LTX-2-T2AV-Camera-Control-Dolly-Left.py | 6 +- .../LTX-2-T2AV-Camera-Control-Dolly-Out.py | 6 +- .../LTX-2-T2AV-Camera-Control-Dolly-Right.py | 6 +- .../LTX-2-T2AV-Camera-Control-Jib-Down.py | 6 +- .../LTX-2-T2AV-Camera-Control-Jib-Up.py | 6 +- .../LTX-2-T2AV-Camera-Control-Static.py | 6 +- .../LTX-2-T2AV-DistilledPipeline.py | 6 +- .../LTX-2-T2AV-OneStage.py | 6 +- .../LTX-2-T2AV-TwoStage.py | 6 +- .../model_training/full/LTX-2-T2AV-splited.sh | 35 ++++ .../model_training/lora/LTX-2-T2AV-noaudio.sh | 56 ++++++ .../model_training/lora/LTX-2-T2AV-splited.sh | 40 +++++ .../ltx2/model_training/lora/LTX-2-T2AV.sh | 19 ++ examples/ltx2/model_training/train.py | 162 ++++++++++++++++++ .../validate_full/LTX-2-T2AV.py | 47 +++++ .../validate_lora/LTX-2-T2AV.py | 49 ++++++ .../validate_lora/LTX-2-T2AV_noaudio.py | 49 ++++++ .../ltx2/scripts/split_model_statedicts.py | 104 +++++++++++ 40 files changed, 893 insertions(+), 49 deletions(-) create mode 100644 examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh create mode 100644 examples/ltx2/model_training/lora/LTX-2-T2AV-noaudio.sh create mode 100644 examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh create mode 100644 examples/ltx2/model_training/lora/LTX-2-T2AV.sh create mode 100644 examples/ltx2/model_training/train.py create mode 100644 examples/ltx2/model_training/validate_full/LTX-2-T2AV.py create mode 100644 examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py 
create mode 100644 examples/ltx2/model_training/validate_lora/LTX-2-T2AV_noaudio.py create mode 100644 examples/ltx2/scripts/split_model_statedicts.py diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index 9ff7ea6..dbad638 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -607,6 +607,12 @@ ltx2_series = [ "model_class": "diffsynth.models.ltx2_dit.LTXModel", "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter", }, + { + "model_hash": "c567aaa37d5ed7454c73aa6024458661", + "model_name": "ltx2_dit", + "model_class": "diffsynth.models.ltx2_dit.LTXModel", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter", + }, { # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") "model_hash": "aca7b0bbf8415e9c98360750268915fc", @@ -614,6 +620,12 @@ ltx2_series = [ "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder", "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter", }, + { + "model_hash": "7f7e904a53260ec0351b05f32153754b", + "model_name": "ltx2_video_vae_encoder", + "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter", + }, { # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") "model_hash": "aca7b0bbf8415e9c98360750268915fc", @@ -621,6 +633,12 @@ ltx2_series = [ "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder", "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter", }, + { + "model_hash": "dc6029ca2825147872b45e35a2dc3a97", + "model_name": "ltx2_video_vae_decoder", + "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder", + 
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter", + }, { # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") "model_hash": "aca7b0bbf8415e9c98360750268915fc", @@ -628,6 +646,12 @@ ltx2_series = [ "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder", "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter", }, + { + "model_hash": "7d7823dde8f1ea0b50fb07ac329dd4cb", + "model_name": "ltx2_audio_vae_decoder", + "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter", + }, { # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") "model_hash": "aca7b0bbf8415e9c98360750268915fc", @@ -635,16 +659,34 @@ ltx2_series = [ "model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder", "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter", }, - # { # not used currently - # # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") - # "model_hash": "aca7b0bbf8415e9c98360750268915fc", - # "model_name": "ltx2_audio_vae_encoder", - # "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder", - # "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter", - # }, + { + "model_hash": "f471360f6b24bef702ab73133d9f8bb9", + "model_name": "ltx2_audio_vocoder", + "model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter", + }, { # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") "model_hash": 
"aca7b0bbf8415e9c98360750268915fc", + "model_name": "ltx2_audio_vae_encoder", + "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter", + }, + { + "model_hash": "29338f3b95e7e312a3460a482e4f4554", + "model_name": "ltx2_audio_vae_encoder", + "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") + "model_hash": "aca7b0bbf8415e9c98360750268915fc", + "model_name": "ltx2_text_encoder_post_modules", + "model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter", + }, + { + "model_hash": "981629689c8be92a712ab3c5eb4fc3f6", "model_name": "ltx2_text_encoder_post_modules", "model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules", "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter", diff --git a/diffsynth/diffusion/loss.py b/diffsynth/diffusion/loss.py index 14fdfd3..c9330e5 100644 --- a/diffsynth/diffusion/loss.py +++ b/diffsynth/diffusion/loss.py @@ -28,6 +28,36 @@ def FlowMatchSFTLoss(pipe: BasePipeline, **inputs): return loss +def FlowMatchSFTAudioVideoLoss(pipe: BasePipeline, **inputs): + max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps)) + min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps)) + + timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,)) + timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device) + 
+ # video + noise = torch.randn_like(inputs["input_latents"]) + inputs["video_latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep) + training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep) + + # audio + if inputs.get("audio_input_latents") is not None: + audio_noise = torch.randn_like(inputs["audio_input_latents"]) + inputs["audio_latents"] = pipe.scheduler.add_noise(inputs["audio_input_latents"], audio_noise, timestep) + training_target_audio = pipe.scheduler.training_target(inputs["audio_input_latents"], audio_noise, timestep) + + models = {name: getattr(pipe, name) for name in pipe.in_iteration_models} + noise_pred, noise_pred_audio = pipe.model_fn(**models, **inputs, timestep=timestep) + + loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) + loss = loss * pipe.scheduler.training_weight(timestep) + if inputs.get("audio_input_latents") is not None: + loss_audio = torch.nn.functional.mse_loss(noise_pred_audio.float(), training_target_audio.float()) + loss_audio = loss_audio * pipe.scheduler.training_weight(timestep) + loss = loss + loss_audio + return loss + + def DirectDistillLoss(pipe: BasePipeline, **inputs): pipe.scheduler.set_timesteps(inputs["num_inference_steps"]) pipe.scheduler.training = True diff --git a/diffsynth/models/ltx2_audio_vae.py b/diffsynth/models/ltx2_audio_vae.py index 708ded7..0a12cb5 100644 --- a/diffsynth/models/ltx2_audio_vae.py +++ b/diffsynth/models/ltx2_audio_vae.py @@ -5,8 +5,65 @@ import einops import torch import torch.nn as nn import torch.nn.functional as F +import torchaudio from .ltx2_common import VideoLatentShape, AudioLatentShape, Patchifier, NormType, build_normalization_layer + +class AudioProcessor(nn.Module): + """Converts audio waveforms to log-mel spectrograms with optional resampling.""" + + def __init__( + self, + sample_rate: int = 16000, + mel_bins: int = 64, + mel_hop_length: int = 160, + n_fft: int = 1024, + ) -> None: + 
super().__init__() + self.sample_rate = sample_rate + self.mel_transform = torchaudio.transforms.MelSpectrogram( + sample_rate=sample_rate, + n_fft=n_fft, + win_length=n_fft, + hop_length=mel_hop_length, + f_min=0.0, + f_max=sample_rate / 2.0, + n_mels=mel_bins, + window_fn=torch.hann_window, + center=True, + pad_mode="reflect", + power=1.0, + mel_scale="slaney", + norm="slaney", + ) + + def resample_waveform( + self, + waveform: torch.Tensor, + source_rate: int, + target_rate: int, + ) -> torch.Tensor: + """Resample waveform to target sample rate if needed.""" + if source_rate == target_rate: + return waveform + resampled = torchaudio.functional.resample(waveform, source_rate, target_rate) + return resampled.to(device=waveform.device, dtype=waveform.dtype) + + def waveform_to_mel( + self, + waveform: torch.Tensor, + waveform_sample_rate: int, + ) -> torch.Tensor: + """Convert waveform to log-mel spectrogram [batch, channels, time, n_mels].""" + waveform = self.resample_waveform(waveform, waveform_sample_rate, self.sample_rate) + + mel = self.mel_transform(waveform) + mel = torch.log(torch.clamp(mel, min=1e-5)) + + mel = mel.to(device=waveform.device, dtype=waveform.dtype) + return mel.permute(0, 1, 3, 2).contiguous() + + class AudioPatchifier(Patchifier): def __init__( self, diff --git a/diffsynth/models/ltx2_dit.py b/diffsynth/models/ltx2_dit.py index 2e3c958..c12a3f3 100644 --- a/diffsynth/models/ltx2_dit.py +++ b/diffsynth/models/ltx2_dit.py @@ -1446,6 +1446,6 @@ class LTXModel(torch.nn.Module): cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0]) self._init_preprocessors(cross_pe_max_pos) video = Modality(video_latents, video_timesteps, video_positions, video_context) - audio = Modality(audio_latents, audio_timesteps, audio_positions, audio_context) + audio = Modality(audio_latents, audio_timesteps, audio_positions, audio_context) if audio_latents is not None else None vx, ax = self._forward(video=video, 
audio=audio, perturbations=None) return vx, ax diff --git a/diffsynth/pipelines/ltx2_audio_video.py b/diffsynth/pipelines/ltx2_audio_video.py index 9ed48aa..444e004 100644 --- a/diffsynth/pipelines/ltx2_audio_video.py +++ b/diffsynth/pipelines/ltx2_audio_video.py @@ -18,7 +18,7 @@ from ..diffusion.base_pipeline import BasePipeline, PipelineUnit from ..models.ltx2_text_encoder import LTX2TextEncoder, LTX2TextEncoderPostModules, LTXVGemmaTokenizer from ..models.ltx2_dit import LTXModel from ..models.ltx2_video_vae import LTX2VideoEncoder, LTX2VideoDecoder, VideoLatentPatchifier -from ..models.ltx2_audio_vae import LTX2AudioEncoder, LTX2AudioDecoder, LTX2Vocoder, AudioPatchifier +from ..models.ltx2_audio_vae import LTX2AudioEncoder, LTX2AudioDecoder, LTX2Vocoder, AudioPatchifier, AudioProcessor from ..models.ltx2_upsampler import LTX2LatentUpsampler from ..models.ltx2_common import VideoLatentShape, AudioLatentShape, VideoPixelShape, get_pixel_coords, VIDEO_SCALE_FACTORS from ..utils.data.media_io_ltx2 import ltx2_preprocess @@ -50,6 +50,7 @@ class LTX2AudioVideoPipeline(BasePipeline): self.video_patchifier: VideoLatentPatchifier = VideoLatentPatchifier(patch_size=1) self.audio_patchifier: AudioPatchifier = AudioPatchifier(patch_size=1) + self.audio_processor: AudioProcessor = AudioProcessor() self.in_iteration_models = ("dit",) self.units = [ @@ -57,6 +58,7 @@ class LTX2AudioVideoPipeline(BasePipeline): LTX2AudioVideoUnit_ShapeChecker(), LTX2AudioVideoUnit_PromptEmbedder(), LTX2AudioVideoUnit_NoiseInitializer(), + LTX2AudioVideoUnit_InputAudioEmbedder(), LTX2AudioVideoUnit_InputVideoEmbedder(), LTX2AudioVideoUnit_InputImagesEmbedder(), ] @@ -95,7 +97,7 @@ class LTX2AudioVideoPipeline(BasePipeline): stage2_lora_config.download_if_necessary() pipe.stage2_lora_path = stage2_lora_config.path # Optional, currently not used - # pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder") + pipe.audio_vae_encoder = 
model_pool.fetch_model("ltx2_audio_vae_encoder") # VRAM Management pipe.vram_management_enabled = pipe.check_vram_management_state() @@ -157,7 +159,6 @@ class LTX2AudioVideoPipeline(BasePipeline): num_frames=121, # Classifier-free guidance cfg_scale: Optional[float] = 3.0, - cfg_merge: Optional[bool] = False, # Scheduler num_inference_steps: Optional[int] = 40, # VAE tiling @@ -186,7 +187,7 @@ class LTX2AudioVideoPipeline(BasePipeline): "input_images": input_images, "input_images_indexes": input_images_indexes, "input_images_strength": input_images_strength, "seed": seed, "rand_device": rand_device, "height": height, "width": width, "num_frames": num_frames, - "cfg_scale": cfg_scale, "cfg_merge": cfg_merge, + "cfg_scale": cfg_scale, "tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels, "tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames, "use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline, @@ -422,7 +423,7 @@ class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit): def process_stage(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0): video_pixel_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate) - video_latent_shape = VideoLatentShape.from_pixel_shape(shape=video_pixel_shape, latent_channels=pipe.video_vae_encoder.latent_channels) + video_latent_shape = VideoLatentShape.from_pixel_shape(shape=video_pixel_shape, latent_channels=128) video_noise = pipe.generate_noise(video_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device) latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=video_latent_shape, device=pipe.device) @@ -455,17 +456,48 @@ class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit): class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit): def __init__(self): super().__init__( - 
input_params=("input_video", "video_noise", "audio_noise", "tiled", "tile_size", "tile_stride"), - output_params=("video_latents", "audio_latents"), + input_params=("input_video", "video_noise", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels"), + output_params=("video_latents", "input_latents"), onload_model_names=("video_vae_encoder") ) - def process(self, pipe: LTX2AudioVideoPipeline, input_video, video_noise, audio_noise, tiled, tile_size, tile_stride): + def process(self, pipe: LTX2AudioVideoPipeline, input_video, video_noise, tiled, tile_size_in_pixels, tile_overlap_in_pixels): if input_video is None: - return {"video_latents": video_noise, "audio_latents": audio_noise} + return {"video_latents": video_noise} else: - # TODO: implement video-to-video - raise NotImplementedError("Video-to-video not implemented yet.") + pipe.load_models_to_device(self.onload_model_names) + input_video = pipe.preprocess_video(input_video) + input_latents = pipe.video_vae_encoder.encode(input_video, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(dtype=pipe.torch_dtype, device=pipe.device) + if pipe.scheduler.training: + return {"video_latents": input_latents, "input_latents": input_latents} + else: + # TODO: implement video-to-video + raise NotImplementedError("Video-to-video not implemented yet.") + +class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_audio", "audio_noise"), + output_params=("audio_latents", "audio_input_latents", "audio_positions", "audio_latent_shape"), + onload_model_names=("audio_vae_encoder",) + ) + + def process(self, pipe: LTX2AudioVideoPipeline, input_audio, audio_noise): + if input_audio is None: + return {"audio_latents": audio_noise} + else: + input_audio, sample_rate = input_audio + pipe.load_models_to_device(self.onload_model_names) + input_audio = pipe.audio_processor.waveform_to_mel(input_audio.unsqueeze(0), 
waveform_sample_rate=sample_rate).to(dtype=pipe.torch_dtype) + audio_input_latents = pipe.audio_vae_encoder(input_audio) + audio_noise = torch.randn_like(audio_input_latents) + audio_latent_shape = AudioLatentShape.from_torch_shape(audio_input_latents.shape) + audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device) + if pipe.scheduler.training: + return {"audio_latents": audio_input_latents, "audio_input_latents": audio_input_latents, "audio_positions": audio_positions, "audio_latent_shape": audio_latent_shape} + else: + # TODO: implement video-to-video + raise NotImplementedError("Video-to-video not implemented yet.") class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit): def __init__(self): @@ -530,9 +562,12 @@ def model_fn_ltx2( video_timesteps = timestep.repeat(1, video_latents.shape[1], 1) if denoise_mask_video is not None: video_timesteps = video_patchifier.patchify(denoise_mask_video) * video_timesteps - _, c_a, _, mel_bins = audio_latents.shape - audio_latents = audio_patchifier.patchify(audio_latents) - audio_timesteps = timestep.repeat(1, audio_latents.shape[1], 1) + if audio_latents is not None: + _, c_a, _, mel_bins = audio_latents.shape + audio_latents = audio_patchifier.patchify(audio_latents) + audio_timesteps = timestep.repeat(1, audio_latents.shape[1], 1) + else: + audio_timesteps = None #TODO: support gradient checkpointing in training vx, ax = dit( video_latents=video_latents, @@ -546,5 +581,5 @@ def model_fn_ltx2( ) # unpatchify vx = video_patchifier.unpatchify_video(vx, f, h, w) - ax = audio_patchifier.unpatchify_audio(ax, c_a, mel_bins) + ax = audio_patchifier.unpatchify_audio(ax, c_a, mel_bins) if ax is not None else None return vx, ax diff --git a/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py b/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py index b8e0811..39623dd 100644 --- a/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py +++ 
b/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py @@ -19,7 +19,12 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( device="cuda", model_configs=[ ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), - ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer_distilled.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py b/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py index 1614c1a..c50ff5f 100644 --- a/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py +++ b/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py @@ -19,7 +19,12 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( device="cuda", model_configs=[ ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), - ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + 
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), ) diff --git a/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py b/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py index e73ef3d..bd86b34 100644 --- a/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py +++ b/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py @@ -19,7 +19,12 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( device="cuda", model_configs=[ ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), - ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + 
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-In.py b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-In.py index c1dc94b..da454dc 100644 --- a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-In.py +++ b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-In.py @@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( device="cuda", model_configs=[ ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), - ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git 
a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Left.py b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Left.py index f6b3f0a..e6dd9ba 100644 --- a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Left.py +++ b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Left.py @@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( device="cuda", model_configs=[ ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), - ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Out.py b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Out.py index 6f8fd72..a84efe9 100644 --- a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Out.py +++ b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Out.py @@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( device="cuda", model_configs=[ 
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), - ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Right.py b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Right.py index 2de3233..0a2e968 100644 --- a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Right.py +++ b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Right.py @@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( device="cuda", model_configs=[ ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), - ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", 
**vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Down.py b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Down.py index 571fd6b..4adfdec 100644 --- a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Down.py +++ b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Down.py @@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( device="cuda", model_configs=[ ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), - ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", 
origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Up.py b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Up.py index 18905fe..6dfa664 100644 --- a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Up.py +++ b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Up.py @@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( device="cuda", model_configs=[ ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), - ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Static.py b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Static.py index ffa9b38..f351789 100644 --- a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Static.py +++ b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Static.py @@ -17,7 
+17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( device="cuda", model_configs=[ ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), - ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py b/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py index 2b87dd3..5b57aea 100644 --- a/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py +++ b/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py @@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( device="cuda", model_configs=[ ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), - ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer_distilled.safetensors", **vram_config), + 
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py b/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py index ade78d0..2f73f80 100644 --- a/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py +++ b/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py @@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( device="cuda", model_configs=[ ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), - ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), ], 
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), ) @@ -40,3 +44,12 @@ write_video_audio_ltx2( fps=24, audio_sample_rate=24000, ) + + + # ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + # ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + # ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), + # ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + # ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + # ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_encoder.safetensors", **vram_config), + # ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), \ No newline at end of file diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py b/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py index 84bbc0c..065e92c 100644 --- a/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py +++ b/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py @@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( device="cuda", model_configs=[ ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), - ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + 
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py b/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py index 7020b40..370e371 100644 --- a/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py +++ b/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py @@ -19,7 +19,12 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( device="cuda", model_configs=[ ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), - ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer_distilled.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", 
origin_file_pattern="video_vae_encoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py b/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py index 48ca23b..a42a0bb 100644 --- a/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py +++ b/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py @@ -19,7 +19,12 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( device="cuda", model_configs=[ ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), - ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py b/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py index 5411b8c..912fa93 
100644 --- a/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py +++ b/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py @@ -19,7 +19,12 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( device="cuda", model_configs=[ ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), - ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-In.py b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-In.py index b15e4cf..9f6af9f 100644 --- a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-In.py +++ b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-In.py @@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( device="cuda", model_configs=[ ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", 
origin_file_pattern="model-*.safetensors", **vram_config), - ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Left.py b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Left.py index 4a7a5aa..8ee5fa9 100644 --- a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Left.py +++ b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Left.py @@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( device="cuda", model_configs=[ ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), - ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + 
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Out.py b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Out.py index 9ae6884..3bb9f50 100644 --- a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Out.py +++ b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Out.py @@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( device="cuda", model_configs=[ ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), - ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", 
origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Right.py b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Right.py index ab9f9ae..0301445 100644 --- a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Right.py +++ b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Right.py @@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( device="cuda", model_configs=[ ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), - ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Down.py b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Down.py index 9fc6e41..d1a8db3 100644 --- 
a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Down.py +++ b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Down.py @@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( device="cuda", model_configs=[ ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), - ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Up.py b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Up.py index 628e7c3..0ae8041 100644 --- a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Up.py +++ b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Up.py @@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( device="cuda", model_configs=[ ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), - ModelConfig(model_id="Lightricks/LTX-2", 
origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Static.py b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Static.py index b6394bc..f8130ae 100644 --- a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Static.py +++ b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Static.py @@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( device="cuda", model_configs=[ ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), - ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + 
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py index d8b6a5d..836077b 100644 --- a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py +++ b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py @@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( device="cuda", model_configs=[ ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), - ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer_distilled.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), 
diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py index 894c417..fb08384 100644 --- a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py +++ b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py @@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( device="cuda", model_configs=[ ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), - ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py index 65650d0..f513f46 100644 --- a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py +++ b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py @@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( device="cuda", model_configs=[ ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), - ModelConfig(model_id="Lightricks/LTX-2", 
origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh b/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh new file mode 100644 index 0000000..8dfdd4b --- /dev/null +++ b/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh @@ -0,0 +1,35 @@ +# Splited Training +accelerate launch examples/ltx2/model_training/train.py \ + --dataset_base_path data/example_video_dataset/ltx2 \ + --dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \ + --data_file_keys "video,input_audio" \ + --extra_inputs "input_audio" \ + --height 512 \ + --width 768 \ + --num_frames 49 \ + --dataset_repeat 1 \ + --model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/LTX2-T2AV-full-splited-cache" \ + --trainable_models "dit" \ + --use_gradient_checkpointing \ + --task "sft:data_process" + +accelerate launch examples/ltx2/model_training/train.py \ + --dataset_base_path ./models/train/LTX2-T2AV-full-splited-cache \ + --data_file_keys "video,input_audio" \ + --extra_inputs "input_audio" \ + --height 512 \ + --width 768 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:transformer.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/LTX2-T2AV-full" \ + --trainable_models "dit" \ + --use_gradient_checkpointing \ + --task "sft:train" diff --git a/examples/ltx2/model_training/lora/LTX-2-T2AV-noaudio.sh b/examples/ltx2/model_training/lora/LTX-2-T2AV-noaudio.sh new file mode 100644 index 0000000..4f1e754 --- /dev/null +++ b/examples/ltx2/model_training/lora/LTX-2-T2AV-noaudio.sh @@ -0,0 +1,56 @@ +# single stage training +# accelerate launch examples/ltx2/model_training/train.py \ +# --dataset_base_path data/example_video_dataset/ltx2 \ +# --dataset_metadata_path data/example_video_dataset/ltx2_t2v.csv \ +# --height 256 \ +# --width 384 \ +# --num_frames 25\ +# --dataset_repeat 100 \ +# --model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:transformer.safetensors,DiffSynth-Studio/LTX-2-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors" \ +# --learning_rate 1e-4 \ +# --num_epochs 5 \ +# --remove_prefix_in_ckpt "pipe.dit." 
\ +# --output_path "./models/train/LTX2-T2AV-noaudio_lora" \ +# --lora_base_model "dit" \ +# --lora_target_modules "to_k,to_q,to_v,to_out.0" \ +# --lora_rank 32 \ +# --use_gradient_checkpointing \ +# --find_unused_parameters + + +# Splited Training +accelerate launch examples/ltx2/model_training/train.py \ + --dataset_base_path data/example_video_dataset/ltx2 \ + --dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \ + --height 256 \ + --width 384 \ + --num_frames 49\ + --dataset_repeat 1 \ + --model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/LTX2-T2AV-noaudio_lora-splited-cache" \ + --lora_base_model "dit" \ + --lora_target_modules "to_k,to_q,to_v,to_out.0" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --task "sft:data_process" + + +accelerate launch examples/ltx2/model_training/train.py \ + --dataset_base_path ./models/train/LTX2-T2AV-noaudio_lora-splited-cache \ + --height 256 \ + --width 384 \ + --num_frames 49\ + --dataset_repeat 100 \ + --model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:transformer.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/LTX2-T2AV-noaudio_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_k,to_q,to_v,to_out.0" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --task "sft:train" diff --git a/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh b/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh new file mode 100644 index 0000000..e71494a --- /dev/null +++ b/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh @@ -0,0 +1,40 @@ +# Splited Training +accelerate launch examples/ltx2/model_training/train.py \ + --dataset_base_path data/example_video_dataset/ltx2 \ + --dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \ + --data_file_keys "video,input_audio" \ + --extra_inputs "input_audio" \ + --height 512 \ + --width 768 \ + --num_frames 49 \ + --dataset_repeat 1 \ + --model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/LTX2-T2AV_lora-splited-cache" \ + --lora_base_model "dit" \ + --lora_target_modules "to_k,to_q,to_v,to_out.0" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --task "sft:data_process" + + +accelerate launch examples/ltx2/model_training/train.py \ + --dataset_base_path ./models/train/LTX2-T2AV_lora-splited-cache \ + --data_file_keys "video,input_audio" \ + --extra_inputs "input_audio" \ + --height 512 \ + --width 768 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:transformer.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/LTX2-T2AV_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_k,to_q,to_v,to_out.0" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --task "sft:train" diff --git a/examples/ltx2/model_training/lora/LTX-2-T2AV.sh b/examples/ltx2/model_training/lora/LTX-2-T2AV.sh new file mode 100644 index 0000000..1d8aaaf --- /dev/null +++ b/examples/ltx2/model_training/lora/LTX-2-T2AV.sh @@ -0,0 +1,19 @@ +accelerate launch examples/ltx2/model_training/train.py \ + --dataset_base_path data/example_video_dataset/ltx2 \ + --dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \ + --data_file_keys "video,input_audio" \ + --extra_inputs "input_audio" \ + --height 256 \ + --width 384 \ + --num_frames 25\ + --dataset_repeat 100 \ + --model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:transformer.safetensors,DiffSynth-Studio/LTX-2-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/LTX2-T2AV_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_k,to_q,to_v,to_out.0" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --find_unused_parameters diff --git a/examples/ltx2/model_training/train.py b/examples/ltx2/model_training/train.py new file mode 100644 index 0000000..729b4ed --- /dev/null +++ b/examples/ltx2/model_training/train.py @@ -0,0 +1,162 @@ +import torch, os, argparse, accelerate, warnings +from diffsynth.core import UnifiedDataset +from diffsynth.core.data.operators import LoadAudioWithTorchaudio, ToAbsolutePath +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.diffusion import * +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +class LTX2TrainingModule(DiffusionTrainingModule): + def __init__( + self, + model_paths=None, model_id_with_origin_paths=None, + tokenizer_path=None, + trainable_models=None, + lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None, + preset_lora_path=None, preset_lora_model=None, + use_gradient_checkpointing=True, + use_gradient_checkpointing_offload=False, + extra_inputs=None, + fp8_models=None, + offload_models=None, + device="cpu", + task="sft", + ): + super().__init__() + # Warning + if not use_gradient_checkpointing: + warnings.warn("Gradient checkpointing is detected as disabled. 
To prevent out-of-memory errors, the training framework will forcibly enable gradient checkpointing.") + use_gradient_checkpointing = True + + # Load models + model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device) + tokenizer_config = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized") if tokenizer_path is None else ModelConfig(tokenizer_path) + self.pipe = LTX2AudioVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config) + self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model) + # Training mode + self.switch_pipe_to_training_mode( + self.pipe, trainable_models, + lora_base_model, lora_target_modules, lora_rank, lora_checkpoint, + preset_lora_path, preset_lora_model, + task=task, + ) + + # Store other configs + self.use_gradient_checkpointing = use_gradient_checkpointing + self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload + self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else [] + self.fp8_models = fp8_models + self.task = task + self.task_to_loss = { + "sft:data_process": lambda pipe, *args: args, + "sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTAudioVideoLoss(pipe, **inputs_shared, **inputs_posi), + "sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTAudioVideoLoss(pipe, **inputs_shared, **inputs_posi), + } + + def parse_extra_inputs(self, data, extra_inputs, inputs_shared): + for extra_input in extra_inputs: + inputs_shared[extra_input] = data[extra_input] + return inputs_shared + + def get_pipeline_inputs(self, data): + inputs_posi = {"prompt": data["prompt"]} + inputs_nega = {} + inputs_shared = { + # Assume you are using this pipeline for inference, + # please fill in the input parameters. 
+ "input_video": data["video"], + "height": data["video"][0].size[1], + "width": data["video"][0].size[0], + "num_frames": len(data["video"]), + # Please do not modify the following parameters + # unless you clearly know what this will cause. + "cfg_scale": 1, + "tiled": False, + "rand_device": self.pipe.device, + "use_gradient_checkpointing": self.use_gradient_checkpointing, + "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload, + "video_patchifier": self.pipe.video_patchifier, + "audio_patchifier": self.pipe.audio_patchifier, + } + inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared) + return inputs_shared, inputs_posi, inputs_nega + + def forward(self, data, inputs=None): + if inputs is None: inputs = self.get_pipeline_inputs(data) + inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype) + for unit in self.pipe.units: + inputs = self.pipe.unit_runner(unit, self.pipe, *inputs) + loss = self.task_to_loss[self.task](self.pipe, *inputs) + return loss + + +def ltx2_parser(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser = add_general_config(parser) + parser = add_video_size_config(parser) + parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.") + parser.add_argument("--frame_rate", type=float, default=24, help="frame rate of the training videos. 
If not specified, it will be determined by the dataset.") + parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.") + return parser + + +if __name__ == "__main__": + parser = ltx2_parser() + args = parser.parse_args() + accelerator = accelerate.Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)], + ) + dataset = UnifiedDataset( + base_path=args.dataset_base_path, + metadata_path=args.dataset_metadata_path, + repeat=args.dataset_repeat, + data_file_keys=args.data_file_keys.split(","), + main_data_operator=UnifiedDataset.default_video_operator( + base_path=args.dataset_base_path, + max_pixels=args.max_pixels, + height=args.height, + width=args.width, + height_division_factor=16, + width_division_factor=16, + num_frames=args.num_frames, + time_division_factor=4, + time_division_remainder=1, + ), + special_operator_map={ + "input_audio": ToAbsolutePath(args.dataset_base_path) >> LoadAudioWithTorchaudio(duration=float(args.num_frames) / float(args.frame_rate)), + } + ) + model = LTX2TrainingModule( + model_paths=args.model_paths, + model_id_with_origin_paths=args.model_id_with_origin_paths, + tokenizer_path=args.tokenizer_path, + trainable_models=args.trainable_models, + lora_base_model=args.lora_base_model, + lora_target_modules=args.lora_target_modules, + lora_rank=args.lora_rank, + lora_checkpoint=args.lora_checkpoint, + preset_lora_path=args.preset_lora_path, + preset_lora_model=args.preset_lora_model, + use_gradient_checkpointing=args.use_gradient_checkpointing, + use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, + extra_inputs=args.extra_inputs, + fp8_models=args.fp8_models, + offload_models=args.offload_models, + task=args.task, + device="cpu" if args.initialize_model_on_cpu else accelerator.device, + ) + model_logger = 
ModelLogger( + args.output_path, + remove_prefix_in_ckpt=args.remove_prefix_in_ckpt, + ) + launcher_map = { + "sft:data_process": launch_data_process_task, + "direct_distill:data_process": launch_data_process_task, + "sft": launch_training_task, + "sft:train": launch_training_task, + "direct_distill": launch_training_task, + "direct_distill:train": launch_training_task, + } + launcher_map[args.task](accelerator, dataset, model, model_logger, args=args) diff --git a/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py b/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py new file mode 100644 index 0000000..6201ec1 --- /dev/null +++ b/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py @@ -0,0 +1,47 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + 
ModelConfig(path="./models/train/LTX2-T2AV-full/epoch-4.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), +) +prompt = "A beautiful sunset over the ocean." +negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." 
+height, width, num_frames = 512, 768, 121 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + cfg_scale=4.0 +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_onestage.mp4', + fps=24, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py b/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py new file mode 100644 index 0000000..4372b6c --- /dev/null +++ b/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py @@ -0,0 +1,49 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ], + 
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), +) +pipe.load_lora(pipe.dit, "models/train/LTX2-T2AV_lora/epoch-4.safetensors") +prompt = "A beautiful sunset over the ocean." +negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." 
+height, width, num_frames = 512, 768, 121 +height, width, num_frames = 256, 384, 25 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + cfg_scale=4.0 +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_onestage.mp4', + fps=24, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/model_training/validate_lora/LTX-2-T2AV_noaudio.py b/examples/ltx2/model_training/validate_lora/LTX-2-T2AV_noaudio.py new file mode 100644 index 0000000..b35c23c --- /dev/null +++ b/examples/ltx2/model_training/validate_lora/LTX-2-T2AV_noaudio.py @@ -0,0 +1,49 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", 
**vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), +) +pipe.load_lora(pipe.dit, "models/train/LTX2-T2AV-noaudio_lora/epoch-4.safetensors") +prompt = "A beautiful sunset over the ocean." +negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." 
+height, width, num_frames = 512, 768, 121 +height, width, num_frames = 256, 384, 25 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + cfg_scale=4.0 +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_onestage.mp4', + fps=24, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/scripts/split_model_statedicts.py b/examples/ltx2/scripts/split_model_statedicts.py new file mode 100644 index 0000000..1ca2ff3 --- /dev/null +++ b/examples/ltx2/scripts/split_model_statedicts.py @@ -0,0 +1,104 @@ +from safetensors.torch import save_file +from diffsynth import hash_state_dict_keys +from diffsynth.core import load_state_dict +from diffsynth.models.model_loader import ModelPool + +model_pool = ModelPool() +state_dict = load_state_dict("models/Lightricks/LTX-2/ltx-2-19b-dev.safetensors") + +dit_state_dict = {} +for name in state_dict: + if name.startswith("model.diffusion_model."): + new_name = name.replace("model.diffusion_model.", "") + if new_name.startswith("audio_embeddings_connector.") or new_name.startswith("video_embeddings_connector."): + continue + dit_state_dict[name] = state_dict[name] + +print(f"dit_state_dict keys hash: {hash_state_dict_keys(dit_state_dict)}") +save_file(dit_state_dict, "models/DiffSynth-Studio/LTX-2-Repackage/transformer.safetensors") +model_pool.auto_load_model( + "models/DiffSynth-Studio/LTX-2-Repackage/transformer.safetensors", +) + + +video_vae_encoder_state_dict = {} +for name in state_dict: + if name.startswith("vae.encoder."): + video_vae_encoder_state_dict[name] = state_dict[name] + elif name.startswith("vae.per_channel_statistics."): + video_vae_encoder_state_dict[name] = state_dict[name] + +save_file(video_vae_encoder_state_dict, "models/DiffSynth-Studio/LTX-2-Repackage/video_vae_encoder.safetensors") +print(f"video_vae_encoder keys hash: {hash_state_dict_keys(video_vae_encoder_state_dict)}") 
+model_pool.auto_load_model("models/DiffSynth-Studio/LTX-2-Repackage/video_vae_encoder.safetensors") + + +video_vae_decoder_state_dict = {} +for name in state_dict: + if name.startswith("vae.decoder."): + video_vae_decoder_state_dict[name] = state_dict[name] + elif name.startswith("vae.per_channel_statistics."): + video_vae_decoder_state_dict[name] = state_dict[name] +save_file(video_vae_decoder_state_dict, "models/DiffSynth-Studio/LTX-2-Repackage/video_vae_decoder.safetensors") +print(f"video_vae_decoder keys hash: {hash_state_dict_keys(video_vae_decoder_state_dict)}") +model_pool.auto_load_model("models/DiffSynth-Studio/LTX-2-Repackage/video_vae_decoder.safetensors") + + +audio_vae_decoder_state_dict = {} +for name in state_dict: + if name.startswith("audio_vae.decoder."): + audio_vae_decoder_state_dict[name] = state_dict[name] + elif name.startswith("audio_vae.per_channel_statistics."): + audio_vae_decoder_state_dict[name] = state_dict[name] +save_file(audio_vae_decoder_state_dict, "models/DiffSynth-Studio/LTX-2-Repackage/audio_vae_decoder.safetensors") +print(f"audio_vae_decoder keys hash: {hash_state_dict_keys(audio_vae_decoder_state_dict)}") +model_pool.auto_load_model("models/DiffSynth-Studio/LTX-2-Repackage/audio_vae_decoder.safetensors") + + +audio_vae_encoder_state_dict = {} +for name in state_dict: + if name.startswith("audio_vae.encoder."): + audio_vae_encoder_state_dict[name] = state_dict[name] + elif name.startswith("audio_vae.per_channel_statistics."): + audio_vae_encoder_state_dict[name] = state_dict[name] +save_file(audio_vae_encoder_state_dict, "models/DiffSynth-Studio/LTX-2-Repackage/audio_vae_encoder.safetensors") +print(f"audio_vae_encoder keys hash: {hash_state_dict_keys(audio_vae_encoder_state_dict)}") +model_pool.auto_load_model("models/DiffSynth-Studio/LTX-2-Repackage/audio_vae_encoder.safetensors") + + +audio_vocoder_state_dict = {} +for name in state_dict: + if name.startswith("vocoder."): + audio_vocoder_state_dict[name] = 
state_dict[name] +save_file(audio_vocoder_state_dict, "models/DiffSynth-Studio/LTX-2-Repackage/audio_vocoder.safetensors") +print(f"audio_vocoder keys hash: {hash_state_dict_keys(audio_vocoder_state_dict)}") +model_pool.auto_load_model("models/DiffSynth-Studio/LTX-2-Repackage/audio_vocoder.safetensors") + + +text_encoder_post_modules_state_dict = {} +for name in state_dict: + if name.startswith("text_embedding_projection."): + text_encoder_post_modules_state_dict[name] = state_dict[name] + elif name.startswith("model.diffusion_model.video_embeddings_connector."): + text_encoder_post_modules_state_dict[name] = state_dict[name] + elif name.startswith("model.diffusion_model.audio_embeddings_connector."): + text_encoder_post_modules_state_dict[name] = state_dict[name] +save_file(text_encoder_post_modules_state_dict, "models/DiffSynth-Studio/LTX-2-Repackage/text_encoder_post_modules.safetensors") +print(f"text_encoder_post_modules keys hash: {hash_state_dict_keys(text_encoder_post_modules_state_dict)}") +model_pool.auto_load_model("models/DiffSynth-Studio/LTX-2-Repackage/text_encoder_post_modules.safetensors") + + +state_dict = load_state_dict("models/Lightricks/LTX-2/ltx-2-19b-distilled.safetensors") +dit_state_dict = {} +for name in state_dict: + if name.startswith("model.diffusion_model."): + new_name = name.replace("model.diffusion_model.", "") + if new_name.startswith("audio_embeddings_connector.") or new_name.startswith("video_embeddings_connector."): + continue + dit_state_dict[name] = state_dict[name] + +print(f"dit_state_dict keys hash: {hash_state_dict_keys(dit_state_dict)}") +save_file(dit_state_dict, "models/DiffSynth-Studio/LTX-2-Repackage/transformer_distilled.safetensors") +model_pool.auto_load_model( + "models/DiffSynth-Studio/LTX-2-Repackage/transformer_distilled.safetensors", +) \ No newline at end of file From 8e15dcd2894d6acbecb72ed717009f89b9ac7c7e Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Wed, 25 Feb 2026 18:06:02 +0800 
Subject: [PATCH 09/23] support ltx2 train -2 --- README.md | 4 +- README_zh.md | 4 +- diffsynth/core/data/operators.py | 17 ++++++ docs/en/Model_Details/LTX-2.md | 55 ++++++++++++++++++- docs/zh/Model_Details/LTX-2.md | 55 ++++++++++++++++++- .../LTX-2-T2AV-Camera-Control-Dolly-In.py | 1 + .../LTX-2-T2AV-Camera-Control-Dolly-Left.py | 1 + .../LTX-2-T2AV-Camera-Control-Dolly-Out.py | 1 + .../LTX-2-T2AV-Camera-Control-Dolly-Right.py | 1 + .../LTX-2-T2AV-Camera-Control-Jib-Down.py | 1 + .../LTX-2-T2AV-Camera-Control-Jib-Up.py | 1 + .../LTX-2-T2AV-Camera-Control-Static.py | 1 + .../LTX-2-T2AV-DistilledPipeline.py | 1 + .../model_inference/LTX-2-T2AV-TwoStage.py | 1 + .../LTX-2-T2AV-Camera-Control-Dolly-In.py | 1 + .../LTX-2-T2AV-Camera-Control-Dolly-Left.py | 1 + .../LTX-2-T2AV-Camera-Control-Dolly-Out.py | 1 + .../LTX-2-T2AV-Camera-Control-Dolly-Right.py | 1 + .../LTX-2-T2AV-Camera-Control-Jib-Down.py | 1 + .../LTX-2-T2AV-Camera-Control-Jib-Up.py | 1 + .../LTX-2-T2AV-Camera-Control-Static.py | 1 + .../LTX-2-T2AV-DistilledPipeline.py | 1 + .../LTX-2-T2AV-TwoStage.py | 1 + .../model_training/full/LTX-2-T2AV-splited.sh | 2 +- .../model_training/lora/LTX-2-T2AV-noaudio.sh | 8 +-- .../model_training/lora/LTX-2-T2AV-splited.sh | 22 +++++++- .../ltx2/model_training/lora/LTX-2-T2AV.sh | 19 ------- .../scripts/split_model_statedicts.py | 0 examples/ltx2/model_training/train.py | 2 +- .../validate_full/LTX-2-T2AV.py | 2 +- .../validate_lora/LTX-2-T2AV.py | 3 +- .../validate_lora/LTX-2-T2AV_noaudio.py | 3 +- 32 files changed, 175 insertions(+), 39 deletions(-) delete mode 100644 examples/ltx2/model_training/lora/LTX-2-T2AV.sh rename examples/ltx2/{ => model_training}/scripts/split_model_statedicts.py (100%) diff --git a/README.md b/README.md index 2bc45e6..f266af1 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ We believe that a well-developed open-source code framework can lower the thresh > DiffSynth-Studio has undergone major version updates, and some old features 
are no longer maintained. If you need to use old features, please switch to the [last historical version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3) before the major version update. > Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand. -- **February 10, 2026** Added inference support for the LTX-2 audio-video generation model. See the documentation for details. Support for model training will be implemented in the future. +- **February 10, 2026** Added inference and training support for the LTX-2 audio-video generation model. See the documentation for details. - **February 2, 2026** The first document of the Research Tutorial series is now available, guiding you through training a small 0.1B text-to-image model from scratch. For details, see the [documentation](/docs/en/Research_Tutorial/train_from_scratch.md) and [model](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel). We hope DiffSynth-Studio can evolve into a more powerful training framework for Diffusion models. 
@@ -614,7 +614,7 @@ Example code for LTX-2 is available at: [/examples/ltx2/](/examples/ltx2/) | Model ID | Extra Args | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation | |-|-|-|-|-|-|-|-| -|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|-|-|-|-| +|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py)| |[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-| |[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-| |[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-| diff --git a/README_zh.md b/README_zh.md 
index 72c80d5..93ec2cf 100644 --- a/README_zh.md +++ b/README_zh.md @@ -32,7 +32,7 @@ DiffSynth 目前包括两个开源项目: > DiffSynth-Studio 经历了大版本更新,部分旧功能已停止维护,如需使用旧版功能,请切换到大版本更新前的[最后一个历史版本](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3)。 > 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 负责,因此新功能的开发进展会比较缓慢,issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。 -- **2026年2月10日** 新增对[LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2)音视频生成模型的推理支持,详见[文档](docs/zh/Model_Details/LTX-2.md),后续将推进模型训练的支持。 +- **2026年2月10日** 新增对[LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2)音视频生成模型的推理和训练支持,详见[文档](docs/zh/Model_Details/LTX-2.md)。 - **2026年2月2日** Research Tutorial 的第一篇文档上线,带你从零开始训练一个 0.1B 的小型文生图模型,详见[文档](/docs/zh/Research_Tutorial/train_from_scratch.md)、[模型](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel),我们希望 DiffSynth-Studio 能够成为一个更强大的 Diffusion 模型训练框架。 @@ -614,7 +614,7 @@ LTX-2 的示例代码位于:[/examples/ltx2/](/examples/ltx2/) |模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| |-|-|-|-|-|-|-|-| -|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|-|-|-|-| +|[Lightricks/LTX-2: 
OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py)| |[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-| |[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-| |[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-| diff --git a/diffsynth/core/data/operators.py b/diffsynth/core/data/operators.py index b7a0e7e..36756d3 100644 --- a/diffsynth/core/data/operators.py +++ b/diffsynth/core/data/operators.py @@ -218,3 +218,20 @@ class LoadAudio(DataProcessingOperator): import librosa input_audio, sample_rate = librosa.load(data, sr=self.sr) return input_audio + + +class LoadAudioWithTorchaudio(DataProcessingOperator): + def __init__(self, duration=5): + self.duration = duration + + def __call__(self, data: str): + import torchaudio + waveform, sample_rate = 
torchaudio.load(data) + target_samples = int(self.duration * sample_rate) + current_samples = waveform.shape[-1] + if current_samples > target_samples: + waveform = waveform[..., :target_samples] + elif current_samples < target_samples: + padding = target_samples - current_samples + waveform = torch.nn.functional.pad(waveform, (0, padding)) + return waveform, sample_rate diff --git a/docs/en/Model_Details/LTX-2.md b/docs/en/Model_Details/LTX-2.md index c285a7f..e3f4ca8 100644 --- a/docs/en/Model_Details/LTX-2.md +++ b/docs/en/Model_Details/LTX-2.md @@ -67,7 +67,7 @@ write_video_audio_ltx2( ## Model Overview |Model ID|Additional Parameters|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training| |-|-|-|-|-|-|-|-| -|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|-|-|-|-| +|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py)| |[Lightricks/LTX-2: 
TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-| |[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-| |[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-| @@ -113,4 +113,55 @@ If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_ ## Model Training -The LTX-2 series models currently do not support training functionality. We will add related support as soon as possible. +LTX-2 series models are uniformly trained through [`examples/ltx2/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/train.py), and the script parameters include: + +* General Training Parameters + * Dataset Basic Configuration + * `--dataset_base_path`: Root directory of the dataset. + * `--dataset_metadata_path`: Metadata file path of the dataset. + * `--dataset_repeat`: Number of times the dataset is repeated in each epoch. + * `--dataset_num_workers`: Number of processes for each DataLoader. + * `--data_file_keys`: Field names to be loaded from metadata, usually image or video file paths, separated by `,`. 
+ * Model Loading Configuration + * `--model_paths`: Paths of models to be loaded. JSON format. + * `--model_id_with_origin_paths`: Model IDs with original paths, e.g., `"Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors"`. Separated by commas. + * `--extra_inputs`: Extra input parameters required by the model Pipeline, e.g., extra parameters when training image editing models, separated by `,`. + * `--fp8_models`: Models loaded in FP8 format, consistent with `--model_paths` or `--model_id_with_origin_paths` format. Currently only supports models whose parameters are not updated by gradients (no gradient backpropagation, or gradients only update their LoRA). + * Training Basic Configuration + * `--learning_rate`: Learning rate. + * `--num_epochs`: Number of epochs. + * `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`. + * `--find_unused_parameters`: Whether there are unused parameters in DDP training. Some models contain redundant parameters that do not participate in gradient calculation, and this setting needs to be enabled to avoid errors in multi-GPU training. + * `--weight_decay`: Weight decay size, see [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html). + * `--task`: Training task, default is `sft`. Some models support more training modes, please refer to the documentation of each specific model. + * Output Configuration + * `--output_path`: Model saving path. + * `--remove_prefix_in_ckpt`: Remove prefix in the state dict of the model file. + * `--save_steps`: Interval of training steps to save the model. If this parameter is left blank, the model is saved once per epoch. + * LoRA Configuration + * `--lora_base_model`: Which model to add LoRA to. + * `--lora_target_modules`: Which layers to add LoRA to. + * `--lora_rank`: Rank of LoRA. + * `--lora_checkpoint`: Path of the LoRA checkpoint. If this path is provided, LoRA will be loaded from this checkpoint. 
+ * `--preset_lora_path`: Preset LoRA checkpoint path. If this path is provided, this LoRA will be loaded in the form of being merged into the base model. This parameter is used for LoRA differential training. + * `--preset_lora_model`: Model that the preset LoRA is merged into, e.g., `dit`. + * Gradient Configuration + * `--use_gradient_checkpointing`: Whether to enable gradient checkpointing. + * `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to memory. + * `--gradient_accumulation_steps`: Number of gradient accumulation steps. + * Video Width/Height Configuration + * `--height`: Height of the video. Leave `height` and `width` blank to enable dynamic resolution. + * `--width`: Width of the video. Leave `height` and `width` blank to enable dynamic resolution. + * `--max_pixels`: Maximum pixel area of video frames. When dynamic resolution is enabled, video frames with resolution larger than this value will be downscaled, and video frames with resolution smaller than this value will remain unchanged. + * `--num_frames`: Number of frames in the video. +* LTX-2 Series Specific Parameters + * `--tokenizer_path`: Path of the tokenizer, applicable to text-to-video models, leave blank to automatically download from remote. + * `--frame_rate`: frame rate of the training videos. + +We have built a sample video dataset for your testing. You can download this dataset with the following command: + +```shell +modelscope download --dataset DiffSynth-Studio/example_video_dataset --local_dir ./data/example_video_dataset +``` + +We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/). 
diff --git a/docs/zh/Model_Details/LTX-2.md b/docs/zh/Model_Details/LTX-2.md index 6961931..86abbcd 100644 --- a/docs/zh/Model_Details/LTX-2.md +++ b/docs/zh/Model_Details/LTX-2.md @@ -67,7 +67,7 @@ write_video_audio_ltx2( ## 模型总览 |模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| |-|-|-|-|-|-|-|-| -|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|-|-|-|-| +|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py)| |[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-| |[Lightricks/LTX-2: 
DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-| |[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-| @@ -113,4 +113,55 @@ write_video_audio_ltx2( ## 模型训练 -LTX-2 系列模型目前暂不支持训练功能。我们将尽快添加相关支持。 +LTX-2 系列模型统一通过 [`examples/ltx2/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/train.py) 进行训练,脚本的参数包括: + +* 通用训练参数 + * 数据集基础配置 + * `--dataset_base_path`: 数据集的根目录。 + * `--dataset_metadata_path`: 数据集的元数据文件路径。 + * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。 + * `--dataset_num_workers`: 每个 Dataloder 的进程数量。 + * `--data_file_keys`: 元数据中需要加载的字段名称,通常是图像或视频文件的路径,以 `,` 分隔。 + * 模型加载配置 + * `--model_paths`: 要加载的模型路径。JSON 格式。 + * `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 `"Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors"`。用逗号分隔。 + * `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,例如训练图像编辑模型时需要额外参数,以 `,` 分隔。 + * `--fp8_models`:以 FP8 格式加载的模型,格式与 `--model_paths` 或 `--model_id_with_origin_paths` 一致,目前仅支持参数不被梯度更新的模型(不需要梯度回传,或梯度仅更新其 LoRA)。 + * 训练基础配置 + * `--learning_rate`: 学习率。 + * `--num_epochs`: 轮数(Epoch)。 + * `--trainable_models`: 可训练的模型,例如 `dit`、`vae`、`text_encoder`。 + * `--find_unused_parameters`: DDP 训练中是否存在未使用的参数,少数模型包含不参与梯度计算的冗余参数,需开启这一设置避免在多 GPU 训练中报错。 + * `--weight_decay`:权重衰减大小,详见 [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html)。 + * `--task`: 训练任务,默认为 
`sft`,部分模型支持更多训练模式,请参考每个特定模型的文档。 + * 输出配置 + * `--output_path`: 模型保存路径。 + * `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。 + * `--save_steps`: 保存模型的训练步数间隔,若此参数留空,则每个 epoch 保存一次。 + * LoRA 配置 + * `--lora_base_model`: LoRA 添加到哪个模型上。 + * `--lora_target_modules`: LoRA 添加到哪些层上。 + * `--lora_rank`: LoRA 的秩(Rank)。 + * `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径,LoRA 将从此检查点加载。 + * `--preset_lora_path`: 预置 LoRA 检查点路径,如果提供此路径,这一 LoRA 将会以融入基础模型的形式加载。此参数用于 LoRA 差分训练。 + * `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`。 + * 梯度配置 + * `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。 + * `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。 + * `--gradient_accumulation_steps`: 梯度累积步数。 + * 视频宽高配置 + * `--height`: 视频的高度。将 `height` 和 `width` 留空以启用动态分辨率。 + * `--width`: 视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。 + * `--max_pixels`: 视频帧的最大像素面积,当启用动态分辨率时,分辨率大于这个数值的视频帧都会被缩小,分辨率小于这个数值的视频帧保持不变。 + * `--num_frames`: 视频的帧数。 +* LTX-2 系列特定参数 + * `--tokenizer_path`: 分词器路径,适用于文生视频模型,留空则从远程自动下载。 + * `--frame_rate`: 训练视频的帧率。 + +我们构建了一个样例视频数据集,以方便您进行测试,通过以下命令可以下载这个数据集: + +```shell +modelscope download --dataset DiffSynth-Studio/example_video_dataset --local_dir ./data/example_video_dataset +``` + +我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](../Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。 diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-In.py b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-In.py index da454dc..9932804 100644 --- a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-In.py +++ b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-In.py @@ -22,6 +22,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", 
origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Left.py b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Left.py index e6dd9ba..4f0f273 100644 --- a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Left.py +++ b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Left.py @@ -22,6 +22,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Out.py b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Out.py index a84efe9..3fdd2c3 100644 --- a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Out.py +++ b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Out.py @@ -22,6 +22,7 @@ pipe 
= LTX2AudioVideoPipeline.from_pretrained( ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Right.py b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Right.py index 0a2e968..b4fd099 100644 --- a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Right.py +++ b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Right.py @@ -22,6 +22,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Down.py 
b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Down.py index 4adfdec..1eb2ace 100644 --- a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Down.py +++ b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Down.py @@ -22,6 +22,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Up.py b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Up.py index 6dfa664..cc6ac39 100644 --- a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Up.py +++ b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Up.py @@ -22,6 +22,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", 
origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Static.py b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Static.py index f351789..2299bf7 100644 --- a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Static.py +++ b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Static.py @@ -22,6 +22,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py b/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py index 5b57aea..78b537b 100644 --- a/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py +++ b/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py @@ -22,6 +22,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", 
**vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py b/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py index 065e92c..9a85a9f 100644 --- a/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py +++ b/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py @@ -22,6 +22,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-In.py b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-In.py index 9f6af9f..0454dfd 100644 --- a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-In.py +++ b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-In.py @@ -22,6 +22,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), 
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Left.py b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Left.py index 8ee5fa9..e66a203 100644 --- a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Left.py +++ b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Left.py @@ -22,6 +22,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Out.py b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Out.py index 3bb9f50..94af241 100644 --- 
a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Out.py +++ b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Out.py @@ -22,6 +22,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Right.py b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Right.py index 0301445..b713170 100644 --- a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Right.py +++ b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Right.py @@ -22,6 +22,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", 
**vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Down.py b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Down.py index d1a8db3..fe1423c 100644 --- a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Down.py +++ b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Down.py @@ -22,6 +22,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Up.py b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Up.py index 0ae8041..b6229c1 100644 --- a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Up.py +++ b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Up.py @@ -22,6 +22,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", 
origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Static.py b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Static.py index f8130ae..13d8ea9 100644 --- a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Static.py +++ b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Static.py @@ -22,6 +22,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py index 836077b..daf82f2 100644 --- a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py +++ b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py @@ -22,6 +22,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", 
origin_file_pattern="video_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py index f513f46..a10b6ef 100644 --- a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py +++ b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py @@ -22,6 +22,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), diff --git a/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh b/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh index 8dfdd4b..4aec5f8 100644 --- a/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh +++ 
b/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh @@ -17,7 +17,7 @@ accelerate launch examples/ltx2/model_training/train.py \ --use_gradient_checkpointing \ --task "sft:data_process" -accelerate launch examples/ltx2/model_training/train.py \ +accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/ltx2/model_training/train.py \ --dataset_base_path ./models/train/LTX2-T2AV-full-splited-cache \ --data_file_keys "video,input_audio" \ --extra_inputs "input_audio" \ diff --git a/examples/ltx2/model_training/lora/LTX-2-T2AV-noaudio.sh b/examples/ltx2/model_training/lora/LTX-2-T2AV-noaudio.sh index 4f1e754..b2a5609 100644 --- a/examples/ltx2/model_training/lora/LTX-2-T2AV-noaudio.sh +++ b/examples/ltx2/model_training/lora/LTX-2-T2AV-noaudio.sh @@ -22,8 +22,8 @@ accelerate launch examples/ltx2/model_training/train.py \ --dataset_base_path data/example_video_dataset/ltx2 \ --dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \ - --height 256 \ - --width 384 \ + --height 512 \ + --width 768 \ --num_frames 49\ --dataset_repeat 1 \ --model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors" \ @@ -40,8 +40,8 @@ accelerate launch examples/ltx2/model_training/train.py \ accelerate launch examples/ltx2/model_training/train.py \ --dataset_base_path ./models/train/LTX2-T2AV-noaudio_lora-splited-cache \ - --height 256 \ - --width 384 \ + --height 512 \ + --width 768 \ --num_frames 49\ --dataset_repeat 100 \ --model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:transformer.safetensors" \ diff --git a/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh b/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh index e71494a..40dae1a 100644 --- 
a/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh +++ b/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh @@ -1,3 +1,24 @@ +# Single Stage Training not recommended for T2AV due to the large memory consumption. Please use the Splited Training instead. +# accelerate launch examples/ltx2/model_training/train.py \ +# --dataset_base_path data/example_video_dataset/ltx2 \ +# --dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \ +# --data_file_keys "video,input_audio" \ +# --extra_inputs "input_audio" \ +# --height 256 \ +# --width 384 \ +# --num_frames 25\ +# --dataset_repeat 100 \ +# --model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:transformer.safetensors,DiffSynth-Studio/LTX-2-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors" \ +# --learning_rate 1e-4 \ +# --num_epochs 5 \ +# --remove_prefix_in_ckpt "pipe.dit." 
\ +# --output_path "./models/train/LTX2-T2AV_lora" \ +# --lora_base_model "dit" \ +# --lora_target_modules "to_k,to_q,to_v,to_out.0" \ +# --lora_rank 32 \ +# --use_gradient_checkpointing \ +# --find_unused_parameters + # Splited Training accelerate launch examples/ltx2/model_training/train.py \ --dataset_base_path data/example_video_dataset/ltx2 \ @@ -19,7 +40,6 @@ accelerate launch examples/ltx2/model_training/train.py \ --use_gradient_checkpointing \ --task "sft:data_process" - accelerate launch examples/ltx2/model_training/train.py \ --dataset_base_path ./models/train/LTX2-T2AV_lora-splited-cache \ --data_file_keys "video,input_audio" \ diff --git a/examples/ltx2/model_training/lora/LTX-2-T2AV.sh b/examples/ltx2/model_training/lora/LTX-2-T2AV.sh deleted file mode 100644 index 1d8aaaf..0000000 --- a/examples/ltx2/model_training/lora/LTX-2-T2AV.sh +++ /dev/null @@ -1,19 +0,0 @@ -accelerate launch examples/ltx2/model_training/train.py \ - --dataset_base_path data/example_video_dataset/ltx2 \ - --dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \ - --data_file_keys "video,input_audio" \ - --extra_inputs "input_audio" \ - --height 256 \ - --width 384 \ - --num_frames 25\ - --dataset_repeat 100 \ - --model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:transformer.safetensors,DiffSynth-Studio/LTX-2-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors" \ - --learning_rate 1e-4 \ - --num_epochs 5 \ - --remove_prefix_in_ckpt "pipe.dit." 
\ - --output_path "./models/train/LTX2-T2AV_lora" \ - --lora_base_model "dit" \ - --lora_target_modules "to_k,to_q,to_v,to_out.0" \ - --lora_rank 32 \ - --use_gradient_checkpointing \ - --find_unused_parameters diff --git a/examples/ltx2/scripts/split_model_statedicts.py b/examples/ltx2/model_training/scripts/split_model_statedicts.py similarity index 100% rename from examples/ltx2/scripts/split_model_statedicts.py rename to examples/ltx2/model_training/scripts/split_model_statedicts.py diff --git a/examples/ltx2/model_training/train.py b/examples/ltx2/model_training/train.py index 729b4ed..a994f7a 100644 --- a/examples/ltx2/model_training/train.py +++ b/examples/ltx2/model_training/train.py @@ -96,7 +96,7 @@ def ltx2_parser(): parser = add_general_config(parser) parser = add_video_size_config(parser) parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.") - parser.add_argument("--frame_rate", type=float, default=24, help="frame rate of the training videos. If not specified, it will be determined by the dataset.") + parser.add_argument("--frame_rate", type=float, default=24, help="frame rate of the training videos.") parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.") return parser diff --git a/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py b/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py index 6201ec1..a5da12d 100644 --- a/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py +++ b/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py @@ -27,7 +27,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( ) prompt = "A beautiful sunset over the ocean." 
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." -height, width, num_frames = 512, 768, 121 +height, width, num_frames = 512, 768, 49 video, audio = pipe( prompt=prompt, negative_prompt=negative_prompt, diff --git a/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py b/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py index 4372b6c..471a901 100644 --- a/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py +++ b/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py @@ -28,8 +28,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( pipe.load_lora(pipe.dit, "models/train/LTX2-T2AV_lora/epoch-4.safetensors") prompt = "A beautiful sunset over the ocean." 
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." -height, width, num_frames = 512, 768, 121 -height, width, num_frames = 256, 384, 25 +height, width, num_frames = 512, 768, 49 video, audio = pipe( prompt=prompt, negative_prompt=negative_prompt, diff --git a/examples/ltx2/model_training/validate_lora/LTX-2-T2AV_noaudio.py b/examples/ltx2/model_training/validate_lora/LTX-2-T2AV_noaudio.py index b35c23c..4c2bccc 100644 --- a/examples/ltx2/model_training/validate_lora/LTX-2-T2AV_noaudio.py +++ b/examples/ltx2/model_training/validate_lora/LTX-2-T2AV_noaudio.py @@ -28,8 +28,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( pipe.load_lora(pipe.dit, "models/train/LTX2-T2AV-noaudio_lora/epoch-4.safetensors") prompt = "A beautiful sunset over the ocean." 
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." 
-height, width, num_frames = 512, 768, 121 -height, width, num_frames = 256, 384, 25 +height, width, num_frames = 512, 768, 49 video, audio = pipe( prompt=prompt, negative_prompt=negative_prompt, From 8d8bfc7f540e1493f7d6d229872b7186f7648bad Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Wed, 25 Feb 2026 19:04:10 +0800 Subject: [PATCH 10/23] minor fix --- README.md | 2 +- README_zh.md | 2 +- diffsynth/pipelines/ltx2_audio_video.py | 1 - examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py | 9 --------- 4 files changed, 2 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index f266af1..e055ae7 100644 --- a/README.md +++ b/README.md @@ -614,7 +614,7 @@ Example code for LTX-2 is available at: [/examples/ltx2/](/examples/ltx2/) | Model ID | Extra Args | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation | |-|-|-|-|-|-|-|-| -|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py)| +|[Lightricks/LTX-2: 
OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py)| |[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-| |[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-| |[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-| diff --git a/README_zh.md b/README_zh.md index 93ec2cf..519ebe6 100644 --- a/README_zh.md +++ b/README_zh.md @@ -614,7 +614,7 @@ LTX-2 的示例代码位于:[/examples/ltx2/](/examples/ltx2/) |模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| |-|-|-|-|-|-|-|-| -|[Lightricks/LTX-2: 
OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py)| +|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py)| |[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-| |[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-| |[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-| diff --git a/diffsynth/pipelines/ltx2_audio_video.py 
b/diffsynth/pipelines/ltx2_audio_video.py index 444e004..4587449 100644 --- a/diffsynth/pipelines/ltx2_audio_video.py +++ b/diffsynth/pipelines/ltx2_audio_video.py @@ -490,7 +490,6 @@ class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit): pipe.load_models_to_device(self.onload_model_names) input_audio = pipe.audio_processor.waveform_to_mel(input_audio.unsqueeze(0), waveform_sample_rate=sample_rate).to(dtype=pipe.torch_dtype) audio_input_latents = pipe.audio_vae_encoder(input_audio) - audio_noise = torch.randn_like(audio_input_latents) audio_latent_shape = AudioLatentShape.from_torch_shape(audio_input_latents.shape) audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device) if pipe.scheduler.training: diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py b/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py index 2f73f80..295d73b 100644 --- a/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py +++ b/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py @@ -44,12 +44,3 @@ write_video_audio_ltx2( fps=24, audio_sample_rate=24000, ) - - - # ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), - # ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), - # ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), - # ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), - # ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), - # ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_encoder.safetensors", **vram_config), - # ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", 
**vram_config), \ No newline at end of file From f48662e8637f29147e593e2ff6b6f1d31a1b1b87 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Thu, 26 Feb 2026 11:10:00 +0800 Subject: [PATCH 11/23] update docs --- README.md | 34 +++++++++++- README_zh.md | 34 +++++++++++- diffsynth/configs/model_configs.py | 16 +++++- docs/en/Model_Details/LTX-2.md | 52 ++++++++++++++++-- docs/zh/Model_Details/LTX-2.md | 54 +++++++++++++++++-- .../model_inference/LTX-2-T2AV-OneStage.py | 19 +++++++ .../model_inference/LTX-2-T2AV-TwoStage.py | 22 +++++++- .../LTX-2-T2AV-OneStage.py | 20 +++++++ .../LTX-2-T2AV-TwoStage.py | 22 ++++++++ 9 files changed, 258 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index e055ae7..5fe2017 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,9 @@ We believe that a well-developed open-source code framework can lower the thresh > DiffSynth-Studio has undergone major version updates, and some old features are no longer maintained. If you need to use old features, please switch to the [last historical version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3) before the major version update. > Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand. -- **February 10, 2026** Added inference and training support for the LTX-2 audio-video generation model. See the documentation for details. +- **February 26, 2026** Added full and lora training support for the LTX-2 audio-video generation model. See the [documentation](/docs/en/Model_Details/LTX-2.md) for details. + +- **February 10, 2026** Added inference support for the LTX-2 audio-video generation model. 
See the [documentation](/docs/en/Model_Details/LTX-2.md) for details. Support for model training will be implemented in the future. - **February 2, 2026** The first document of the Research Tutorial series is now available, guiding you through training a small 0.1B text-to-image model from scratch. For details, see the [documentation](/docs/en/Research_Tutorial/train_from_scratch.md) and [model](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel). We hope DiffSynth-Studio can evolve into a more powerful training framework for Diffusion models. @@ -557,12 +559,26 @@ vram_config = { "computation_dtype": torch.bfloat16, "computation_device": "cuda", } +""" +Official model repo: https://www.modelscope.cn/models/Lightricks/LTX-2 +Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage +For base models of LTX-2, official checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")) +and repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are both supported. +We have repackaged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules, +and avoid redundant memory usage when users only want to use part of the model.
+""" +# use the repackaged modelconfig from "DiffSynth-Studio/LTX-2-Repackage" to avoid redundant model loading pipe = LTX2AudioVideoPipeline.from_pretrained( torch_dtype=torch.bfloat16, device="cuda", model_configs=[ ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), - ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), @@ -570,6 +586,20 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, ) +# use the following modelconfig if you want to initialize model from offical checkpoints from "Lightricks/LTX-2" +# pipe = LTX2AudioVideoPipeline.from_pretrained( +# torch_dtype=torch.bfloat16, +# device="cuda", +# model_configs=[ +# ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), +# ModelConfig(model_id="Lightricks/LTX-2", 
origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), +# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), +# ], +# tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), +# stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), +# vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +# ) + prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\"" negative_prompt = ( "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " diff --git a/README_zh.md b/README_zh.md index 519ebe6..74843db 100644 --- a/README_zh.md +++ b/README_zh.md @@ -32,7 +32,9 @@ DiffSynth 目前包括两个开源项目: > DiffSynth-Studio 经历了大版本更新,部分旧功能已停止维护,如需使用旧版功能,请切换到大版本更新前的[最后一个历史版本](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3)。 > 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 负责,因此新功能的开发进展会比较缓慢,issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。 -- **2026年2月10日** 新增对[LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2)音视频生成模型的推理和训练支持,详见[文档](docs/zh/Model_Details/LTX-2.md)。 +- **2026年2月26日** 新增对[LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2)音视频生成模型全量微调与LoRA训练支持,详见[文档](docs/zh/Model_Details/LTX-2.md)。 + +- **2026年2月10日** 新增对[LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2)音视频生成模型的推理支持,详见[文档](docs/zh/Model_Details/LTX-2.md),后续将推进模型训练的支持。 - **2026年2月2日** Research Tutorial 的第一篇文档上线,带你从零开始训练一个 0.1B 的小型文生图模型,详见[文档](/docs/zh/Research_Tutorial/train_from_scratch.md)、[模型](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel),我们希望 DiffSynth-Studio 能够成为一个更强大的 Diffusion 模型训练框架。 @@ -557,12 +559,26 @@ vram_config = { "computation_dtype": torch.bfloat16, "computation_device": "cuda", } +""" +Offical model repo: 
https://www.modelscope.cn/models/Lightricks/LTX-2 +Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage +For base models of LTX-2, offical checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")) +and repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are both supported. +We have repackeged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules, +and avoid redundant memory usage when users only want to use part of the model. +""" +# use the repackaged modelconfig from "DiffSynth-Studio/LTX-2-Repackage" to avoid redundant model loading pipe = LTX2AudioVideoPipeline.from_pretrained( torch_dtype=torch.bfloat16, device="cuda", model_configs=[ ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), - ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), ModelConfig(model_id="Lightricks/LTX-2", 
origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), @@ -570,6 +586,20 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, ) +# use the following modelconfig if you want to initialize model from offical checkpoints from "Lightricks/LTX-2" +# pipe = LTX2AudioVideoPipeline.from_pretrained( +# torch_dtype=torch.bfloat16, +# device="cuda", +# model_configs=[ +# ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), +# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), +# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), +# ], +# tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), +# stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), +# vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +# ) + prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\"" negative_prompt = ( "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index dbad638..fbca133 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -598,7 +598,14 @@ z_image_series = [ "state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter", }, ] - +""" +Offical model repo: https://www.modelscope.cn/models/Lightricks/LTX-2 +Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage +For base models of LTX-2, 
offical checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")) +and repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are both supported. +We have repackeged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules, +and avoid redundant memory usage when users only want to use part of the model. +""" ltx2_series = [ { # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") @@ -608,6 +615,7 @@ ltx2_series = [ "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter", }, { + # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors") "model_hash": "c567aaa37d5ed7454c73aa6024458661", "model_name": "ltx2_dit", "model_class": "diffsynth.models.ltx2_dit.LTXModel", @@ -621,6 +629,7 @@ ltx2_series = [ "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter", }, { + # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors") "model_hash": "7f7e904a53260ec0351b05f32153754b", "model_name": "ltx2_video_vae_encoder", "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder", @@ -634,6 +643,7 @@ ltx2_series = [ "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter", }, { + # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors") "model_hash": "dc6029ca2825147872b45e35a2dc3a97", "model_name": "ltx2_video_vae_decoder", "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder", @@ -647,6 +657,7 @@ ltx2_series = [ "state_dict_converter": 
"diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter", }, { + # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors") "model_hash": "7d7823dde8f1ea0b50fb07ac329dd4cb", "model_name": "ltx2_audio_vae_decoder", "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder", @@ -660,6 +671,7 @@ ltx2_series = [ "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter", }, { + # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors") "model_hash": "f471360f6b24bef702ab73133d9f8bb9", "model_name": "ltx2_audio_vocoder", "model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder", @@ -673,6 +685,7 @@ ltx2_series = [ "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter", }, { + # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_encoder.safetensors") "model_hash": "29338f3b95e7e312a3460a482e4f4554", "model_name": "ltx2_audio_vae_encoder", "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder", @@ -686,6 +699,7 @@ ltx2_series = [ "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter", }, { + # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors") "model_hash": "981629689c8be92a712ab3c5eb4fc3f6", "model_name": "ltx2_text_encoder_post_modules", "model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules", diff --git a/docs/en/Model_Details/LTX-2.md b/docs/en/Model_Details/LTX-2.md index e3f4ca8..68ab351 100644 --- a/docs/en/Model_Details/LTX-2.md +++ b/docs/en/Model_Details/LTX-2.md @@ -33,19 +33,62 @@ vram_config = { "computation_dtype": torch.bfloat16, "computation_device": "cuda", } +""" 
+Offical model repo: https://www.modelscope.cn/models/Lightricks/LTX-2 +Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage +For base models of LTX-2, offical checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")) +and repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are both supported. +We have repackeged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules, +and avoid redundant memory usage when users only want to use part of the model. +""" +# use the repackaged modelconfig from "DiffSynth-Studio/LTX-2-Repackage" to avoid redundant model loading pipe = LTX2AudioVideoPipeline.from_pretrained( torch_dtype=torch.bfloat16, device="cuda", model_configs=[ ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), - ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", 
origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, ) + +# use the following modelconfig if you want to initialize model from offical checkpoints from "Lightricks/LTX-2" +# pipe = LTX2AudioVideoPipeline.from_pretrained( +# torch_dtype=torch.bfloat16, +# device="cuda", +# model_configs=[ +# ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), +# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), +# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), +# ], +# tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), +# stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), +# vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +# ) + prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\"" -negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, 
uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." -height, width, num_frames = 512, 768, 121 +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." 
+) +height, width, num_frames = 512 * 2, 768 * 2, 121 video, audio = pipe( prompt=prompt, negative_prompt=negative_prompt, @@ -54,11 +97,12 @@ video, audio = pipe( width=width, num_frames=num_frames, tiled=True, + use_two_stage_pipeline=True, ) write_video_audio_ltx2( video=video, audio=audio, - output_path='ltx2_onestage.mp4', + output_path='ltx2_twostage.mp4', fps=24, audio_sample_rate=24000, ) diff --git a/docs/zh/Model_Details/LTX-2.md b/docs/zh/Model_Details/LTX-2.md index 86abbcd..558de9d 100644 --- a/docs/zh/Model_Details/LTX-2.md +++ b/docs/zh/Model_Details/LTX-2.md @@ -33,19 +33,62 @@ vram_config = { "computation_dtype": torch.bfloat16, "computation_device": "cuda", } +""" +Offical model repo: https://www.modelscope.cn/models/Lightricks/LTX-2 +Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage +For base models of LTX-2, offical checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")) +and repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are both supported. +We have repackeged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules, +and avoid redundant memory usage when users only want to use part of the model. 
+""" +# use the repackaged modelconfig from "DiffSynth-Studio/LTX-2-Repackage" to avoid redundant model loading pipe = LTX2AudioVideoPipeline.from_pretrained( torch_dtype=torch.bfloat16, device="cuda", model_configs=[ ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), - ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, ) -prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”" -negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, 
asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." -height, width, num_frames = 512, 768, 121 + +# use the following modelconfig if you want to initialize model from offical checkpoints from "Lightricks/LTX-2" +# pipe = LTX2AudioVideoPipeline.from_pretrained( +# torch_dtype=torch.bfloat16, +# device="cuda", +# model_configs=[ +# ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), +# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), +# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), +# ], +# tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), +# stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), +# vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +# ) + +prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\"" +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, 
washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." 
+) +height, width, num_frames = 512 * 2, 768 * 2, 121 video, audio = pipe( prompt=prompt, negative_prompt=negative_prompt, @@ -54,11 +97,12 @@ video, audio = pipe( width=width, num_frames=num_frames, tiled=True, + use_two_stage_pipeline=True, ) write_video_audio_ltx2( video=video, audio=audio, - output_path='ltx2_onestage.mp4', + output_path='ltx2_twostage.mp4', fps=24, audio_sample_rate=24000, ) diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py b/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py index 295d73b..9e56209 100644 --- a/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py +++ b/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py @@ -12,6 +12,15 @@ vram_config = { "computation_dtype": torch.bfloat16, "computation_device": "cuda", } +""" +Offical model repo: https://www.modelscope.cn/models/Lightricks/LTX-2 +Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage +For base models of LTX-2, offical checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")) +and repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are both supported. +We have repackeged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules, +and avoid redundant memory usage when users only want to use part of the model. 
+""" +# use the repackaged modelconfig from "DiffSynth-Studio/LTX-2-Repackage" to avoid redundant model loading pipe = LTX2AudioVideoPipeline.from_pretrained( torch_dtype=torch.bfloat16, device="cuda", @@ -25,6 +34,16 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), ) +# use the following modelconfig if you want to initialize model from offical checkpoints from "Lightricks/LTX-2" +# pipe = LTX2AudioVideoPipeline.from_pretrained( +# torch_dtype=torch.bfloat16, +# device="cuda", +# model_configs=[ +# ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), +# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), +# ], +# tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), +# ) prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”" negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural 
transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." height, width, num_frames = 512, 768, 121 diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py b/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py index 9a85a9f..e08ed50 100644 --- a/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py +++ b/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py @@ -12,6 +12,15 @@ vram_config = { "computation_dtype": torch.bfloat16, "computation_device": "cuda", } +""" +Offical model repo: https://www.modelscope.cn/models/Lightricks/LTX-2 +Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage +For base models of LTX-2, offical checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")) +and repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are both supported. +We have repackeged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules, +and avoid redundant memory usage when users only want to use part of the model. 
+""" +# use the repackaged modelconfig from "DiffSynth-Studio/LTX-2-Repackage" to avoid redundant model loading pipe = LTX2AudioVideoPipeline.from_pretrained( torch_dtype=torch.bfloat16, device="cuda", @@ -28,7 +37,18 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), ) - +# use the following modelconfig if you want to initialize model from offical checkpoints from "Lightricks/LTX-2" +# pipe = LTX2AudioVideoPipeline.from_pretrained( +# torch_dtype=torch.bfloat16, +# device="cuda", +# model_configs=[ +# ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), +# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), +# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), +# ], +# tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), +# stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), +# ) prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”" negative_prompt = ( "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py index fb08384..c08332f 100644 --- a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py +++ b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py @@ -12,6 +12,15 @@ vram_config = { "computation_dtype": torch.bfloat16, "computation_device": "cuda", } +""" +Offical model repo: 
https://www.modelscope.cn/models/Lightricks/LTX-2 +Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage +For base models of LTX-2, offical checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")) +and repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are both supported. +We have repackeged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules, +and avoid redundant memory usage when users only want to use part of the model. +""" +# use the repackaged modelconfig from "DiffSynth-Studio/LTX-2-Repackage" to avoid redundant model loading pipe = LTX2AudioVideoPipeline.from_pretrained( torch_dtype=torch.bfloat16, device="cuda", @@ -26,6 +35,17 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, ) +# use the following modelconfig if you want to initialize model from offical checkpoints from "Lightricks/LTX-2" +# pipe = LTX2AudioVideoPipeline.from_pretrained( +# torch_dtype=torch.bfloat16, +# device="cuda", +# model_configs=[ +# ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), +# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), +# ], +# tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), +# vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +# ) prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”" negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor 
lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." height, width, num_frames = 512, 768, 121 diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py index a10b6ef..98ed966 100644 --- a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py +++ b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py @@ -12,6 +12,15 @@ vram_config = { "computation_dtype": torch.bfloat16, "computation_device": "cuda", } +""" +Offical model repo: https://www.modelscope.cn/models/Lightricks/LTX-2 +Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage +For base models of LTX-2, offical checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")) +and repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are both supported. 
+We have repackeged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules, +and avoid redundant memory usage when users only want to use part of the model. +""" +# use the repackaged modelconfig from "DiffSynth-Studio/LTX-2-Repackage" to avoid redundant model loading pipe = LTX2AudioVideoPipeline.from_pretrained( torch_dtype=torch.bfloat16, device="cuda", @@ -29,6 +38,19 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, ) +# use the following modelconfig if you want to initialize model from offical checkpoints from "Lightricks/LTX-2" +# pipe = LTX2AudioVideoPipeline.from_pretrained( +# torch_dtype=torch.bfloat16, +# device="cuda", +# model_configs=[ +# ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), +# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), +# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), +# ], +# tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), +# stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), +# vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +# ) prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”" negative_prompt = ( From a18966c30084fb9877d278deb4c8882d503ee779 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Thu, 26 Feb 2026 19:19:59 +0800 Subject: [PATCH 12/23] support ltx2 gradient_checkpointing --- diffsynth/models/ltx2_dit.py | 41 ++++++++++--------- 
diffsynth/pipelines/ltx2_audio_video.py | 2 + diffsynth/utils/data/__init__.py | 2 +- .../model_training/full/LTX-2-T2AV-splited.sh | 4 +- .../model_training/lora/LTX-2-T2AV-noaudio.sh | 4 +- .../model_training/lora/LTX-2-T2AV-splited.sh | 4 +- examples/ltx2/model_training/train.py | 6 +-- .../validate_full/LTX-2-T2AV.py | 2 +- .../validate_lora/LTX-2-T2AV.py | 2 +- .../validate_lora/LTX-2-T2AV_noaudio.py | 2 +- 10 files changed, 36 insertions(+), 33 deletions(-) diff --git a/diffsynth/models/ltx2_dit.py b/diffsynth/models/ltx2_dit.py index c12a3f3..cc584ee 100644 --- a/diffsynth/models/ltx2_dit.py +++ b/diffsynth/models/ltx2_dit.py @@ -8,6 +8,7 @@ import torch from einops import rearrange from .ltx2_common import rms_norm, Modality from ..core.attention.attention import attention_forward +from ..core import gradient_checkpoint_forward def get_timestep_embedding( @@ -1352,28 +1353,21 @@ class LTXModel(torch.nn.Module): video: TransformerArgs | None, audio: TransformerArgs | None, perturbations: BatchedPerturbationConfig, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, ) -> tuple[TransformerArgs, TransformerArgs]: """Process transformer blocks for LTXAV.""" # Process transformer blocks for block in self.transformer_blocks: - if self._enable_gradient_checkpointing and self.training: - # Use gradient checkpointing to save memory during training. - # With use_reentrant=False, we can pass dataclasses directly - - # PyTorch will track all tensor leaves in the computation graph. 
- video, audio = torch.utils.checkpoint.checkpoint( - block, - video, - audio, - perturbations, - use_reentrant=False, - ) - else: - video, audio = block( - video=video, - audio=audio, - perturbations=perturbations, - ) + video, audio = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + video=video, + audio=audio, + perturbations=perturbations, + ) return video, audio @@ -1398,7 +1392,12 @@ class LTXModel(torch.nn.Module): return x def _forward( - self, video: Modality | None, audio: Modality | None, perturbations: BatchedPerturbationConfig + self, + video: Modality | None, + audio: Modality | None, + perturbations: BatchedPerturbationConfig, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: """ Forward pass for LTX models. @@ -1417,6 +1416,8 @@ class LTXModel(torch.nn.Module): video=video_args, audio=audio_args, perturbations=perturbations, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, ) # Process output @@ -1440,12 +1441,12 @@ class LTXModel(torch.nn.Module): ) return vx, ax - def forward(self, video_latents, video_positions, video_context, video_timesteps, audio_latents, audio_positions, audio_context, audio_timesteps): + def forward(self, video_latents, video_positions, video_context, video_timesteps, audio_latents, audio_positions, audio_context, audio_timesteps, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False): cross_pe_max_pos = None if self.model_type.is_video_enabled() and self.model_type.is_audio_enabled(): cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0]) self._init_preprocessors(cross_pe_max_pos) video = Modality(video_latents, video_timesteps, video_positions, video_context) audio = Modality(audio_latents, audio_timesteps, audio_positions, 
audio_context) if audio_latents is not None else None - vx, ax = self._forward(video=video, audio=audio, perturbations=None) + vx, ax = self._forward(video=video, audio=audio, perturbations=None, use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload) return vx, ax diff --git a/diffsynth/pipelines/ltx2_audio_video.py b/diffsynth/pipelines/ltx2_audio_video.py index 4587449..fc0b969 100644 --- a/diffsynth/pipelines/ltx2_audio_video.py +++ b/diffsynth/pipelines/ltx2_audio_video.py @@ -577,6 +577,8 @@ def model_fn_ltx2( audio_positions=audio_positions, audio_context=audio_context, audio_timesteps=audio_timesteps, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, ) # unpatchify vx = video_patchifier.unpatchify_video(vx, f, h, w) diff --git a/diffsynth/utils/data/__init__.py b/diffsynth/utils/data/__init__.py index c6b9daa..edc3d41 100644 --- a/diffsynth/utils/data/__init__.py +++ b/diffsynth/utils/data/__init__.py @@ -116,7 +116,7 @@ class VideoData: if self.height is not None and self.width is not None: return self.height, self.width else: - height, width, _ = self.__getitem__(0).shape + width, height = self.__getitem__(0).size return height, width def __getitem__(self, item): diff --git a/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh b/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh index 4aec5f8..04f3b1c 100644 --- a/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh +++ b/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh @@ -6,7 +6,7 @@ accelerate launch examples/ltx2/model_training/train.py \ --extra_inputs "input_audio" \ --height 512 \ --width 768 \ - --num_frames 49 \ + --num_frames 121 \ --dataset_repeat 1 \ --model_id_with_origin_paths 
"DiffSynth-Studio/LTX-2-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors" \ --learning_rate 1e-4 \ @@ -23,7 +23,7 @@ accelerate launch --config_file examples/qwen_image/model_training/full/accelera --extra_inputs "input_audio" \ --height 512 \ --width 768 \ - --num_frames 49 \ + --num_frames 121 \ --dataset_repeat 100 \ --model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:transformer.safetensors" \ --learning_rate 1e-4 \ diff --git a/examples/ltx2/model_training/lora/LTX-2-T2AV-noaudio.sh b/examples/ltx2/model_training/lora/LTX-2-T2AV-noaudio.sh index b2a5609..f7362af 100644 --- a/examples/ltx2/model_training/lora/LTX-2-T2AV-noaudio.sh +++ b/examples/ltx2/model_training/lora/LTX-2-T2AV-noaudio.sh @@ -24,7 +24,7 @@ accelerate launch examples/ltx2/model_training/train.py \ --dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \ --height 512 \ --width 768 \ - --num_frames 49\ + --num_frames 121\ --dataset_repeat 1 \ --model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors" \ --learning_rate 1e-4 \ @@ -42,7 +42,7 @@ accelerate launch examples/ltx2/model_training/train.py \ --dataset_base_path ./models/train/LTX2-T2AV-noaudio_lora-splited-cache \ --height 512 \ --width 768 \ - --num_frames 49\ + --num_frames 121\ --dataset_repeat 100 \ --model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:transformer.safetensors" \ --learning_rate 1e-4 \ diff --git a/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh b/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh index 40dae1a..ebee83d 100644 --- 
a/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh +++ b/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh @@ -27,7 +27,7 @@ accelerate launch examples/ltx2/model_training/train.py \ --extra_inputs "input_audio" \ --height 512 \ --width 768 \ - --num_frames 49 \ + --num_frames 121 \ --dataset_repeat 1 \ --model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors" \ --learning_rate 1e-4 \ @@ -46,7 +46,7 @@ accelerate launch examples/ltx2/model_training/train.py \ --extra_inputs "input_audio" \ --height 512 \ --width 768 \ - --num_frames 49 \ + --num_frames 121 \ --dataset_repeat 100 \ --model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:transformer.safetensors" \ --learning_rate 1e-4 \ diff --git a/examples/ltx2/model_training/train.py b/examples/ltx2/model_training/train.py index a994f7a..26a2925 100644 --- a/examples/ltx2/model_training/train.py +++ b/examples/ltx2/model_training/train.py @@ -118,10 +118,10 @@ if __name__ == "__main__": max_pixels=args.max_pixels, height=args.height, width=args.width, - height_division_factor=16, - width_division_factor=16, + height_division_factor=32, + width_division_factor=32, num_frames=args.num_frames, - time_division_factor=4, + time_division_factor=8, time_division_remainder=1, ), special_operator_map={ diff --git a/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py b/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py index a5da12d..6201ec1 100644 --- a/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py +++ b/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py @@ -27,7 +27,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( ) prompt = "A beautiful sunset over the ocean." 
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." -height, width, num_frames = 512, 768, 49 +height, width, num_frames = 512, 768, 121 video, audio = pipe( prompt=prompt, negative_prompt=negative_prompt, diff --git a/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py b/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py index 471a901..d0dab81 100644 --- a/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py +++ b/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py @@ -28,7 +28,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( pipe.load_lora(pipe.dit, "models/train/LTX2-T2AV_lora/epoch-4.safetensors") prompt = "A beautiful sunset over the ocean." 
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." -height, width, num_frames = 512, 768, 49 +height, width, num_frames = 512, 768, 121 video, audio = pipe( prompt=prompt, negative_prompt=negative_prompt, diff --git a/examples/ltx2/model_training/validate_lora/LTX-2-T2AV_noaudio.py b/examples/ltx2/model_training/validate_lora/LTX-2-T2AV_noaudio.py index 4c2bccc..336b2bf 100644 --- a/examples/ltx2/model_training/validate_lora/LTX-2-T2AV_noaudio.py +++ b/examples/ltx2/model_training/validate_lora/LTX-2-T2AV_noaudio.py @@ -28,7 +28,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( pipe.load_lora(pipe.dit, "models/train/LTX2-T2AV-noaudio_lora/epoch-4.safetensors") prompt = "A beautiful sunset over the ocean." 
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." 
-height, width, num_frames = 512, 768, 49 +height, width, num_frames = 512, 768, 121 video, audio = pipe( prompt=prompt, negative_prompt=negative_prompt, From 5996c2b0689cdd7cc634fe6ae04b604ac7468784 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Fri, 27 Feb 2026 16:48:16 +0800 Subject: [PATCH 13/23] support inference --- diffsynth/pipelines/ltx2_audio_video.py | 94 +++++++++++++++++-- .../model_inference/LTX-2-I2AV-TwoStage.py | 1 - .../LTX-2-T2AV-IC-LoRA-Detailer.py | 77 +++++++++++++++ .../LTX-2-T2AV-IC-LoRA-Union-Control.py | 77 +++++++++++++++ 4 files changed, 238 insertions(+), 11 deletions(-) create mode 100644 examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Detailer.py create mode 100644 examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Union-Control.py diff --git a/diffsynth/pipelines/ltx2_audio_video.py b/diffsynth/pipelines/ltx2_audio_video.py index fc0b969..c662016 100644 --- a/diffsynth/pipelines/ltx2_audio_video.py +++ b/diffsynth/pipelines/ltx2_audio_video.py @@ -61,6 +61,7 @@ class LTX2AudioVideoPipeline(BasePipeline): LTX2AudioVideoUnit_InputAudioEmbedder(), LTX2AudioVideoUnit_InputVideoEmbedder(), LTX2AudioVideoUnit_InputImagesEmbedder(), + LTX2AudioVideoUnit_InContextVideoEmbedder(), ] self.model_fn = model_fn_ltx2 @@ -105,6 +106,8 @@ class LTX2AudioVideoPipeline(BasePipeline): def stage2_denoise(self, inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd=tqdm): if inputs_shared["use_two_stage_pipeline"]: + if inputs_shared.get("clear_lora_before_state_two", False): + self.clear_lora() latent = self.video_vae_encoder.per_channel_statistics.un_normalize(inputs_shared["video_latents"]) self.load_models_to_device('upsampler',) latent = self.upsampler(latent) @@ -112,11 +115,17 @@ class LTX2AudioVideoPipeline(BasePipeline): self.scheduler.set_timesteps(special_case="stage2") inputs_shared.update({k.replace("stage2_", ""): v for k, v in inputs_shared.items() if k.startswith("stage2_")}) denoise_mask_video = 1.0 + # input image if 
inputs_shared.get("input_images", None) is not None: latent, denoise_mask_video, initial_latents = self.apply_input_images_to_latents( latent, inputs_shared.pop("input_latents"), inputs_shared["input_images_indexes"], inputs_shared["input_images_strength"], latent.clone()) inputs_shared.update({"input_latents_video": initial_latents, "denoise_mask_video": denoise_mask_video}) + # remove in-context video control in stage 2 + inputs_shared.pop("in_context_video_latents") + inputs_shared.pop("in_context_video_positions") + + # initialize latents for stage 2 inputs_shared["video_latents"] = self.scheduler.sigmas[0] * denoise_mask_video * inputs_shared[ "video_noise"] + (1 - self.scheduler.sigmas[0] * denoise_mask_video) * latent inputs_shared["audio_latents"] = self.scheduler.sigmas[0] * inputs_shared["audio_noise"] + ( @@ -145,11 +154,14 @@ class LTX2AudioVideoPipeline(BasePipeline): # Prompt prompt: str, negative_prompt: Optional[str] = "", - # Image-to-video denoising_strength: float = 1.0, + # Image-to-video input_images: Optional[list[Image.Image]] = None, input_images_indexes: Optional[list[int]] = None, input_images_strength: Optional[float] = 1.0, + # In-Context Video Control + in_context_videos: Optional[list[list[Image.Image]]] = None, + in_context_downsample_factor: Optional[int] = 2, # Randomness seed: Optional[int] = None, rand_device: Optional[str] = "cpu", @@ -157,6 +169,7 @@ class LTX2AudioVideoPipeline(BasePipeline): height: Optional[int] = 512, width: Optional[int] = 768, num_frames=121, + frame_rate=24, # Classifier-free guidance cfg_scale: Optional[float] = 3.0, # Scheduler @@ -169,6 +182,7 @@ class LTX2AudioVideoPipeline(BasePipeline): tile_overlap_in_frames: Optional[int] = 24, # Special Pipelines use_two_stage_pipeline: Optional[bool] = False, + clear_lora_before_state_two: Optional[bool] = False, use_distilled_pipeline: Optional[bool] = False, # progress_bar progress_bar_cmd=tqdm, @@ -185,12 +199,13 @@ class LTX2AudioVideoPipeline(BasePipeline): 
} inputs_shared = { "input_images": input_images, "input_images_indexes": input_images_indexes, "input_images_strength": input_images_strength, + "in_context_videos": in_context_videos, "in_context_downsample_factor": in_context_downsample_factor, "seed": seed, "rand_device": rand_device, - "height": height, "width": width, "num_frames": num_frames, + "height": height, "width": width, "num_frames": num_frames, "frame_rate": frame_rate, "cfg_scale": cfg_scale, "tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels, "tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames, - "use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline, + "use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline, "clear_lora_before_state_two": clear_lora_before_state_two, "video_patchifier": self.video_patchifier, "audio_patchifier": self.audio_patchifier, } for unit in self.units: @@ -417,8 +432,8 @@ class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit): class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit): def __init__(self): super().__init__( - input_params=("height", "width", "num_frames", "seed", "rand_device", "use_two_stage_pipeline"), - output_params=("video_noise", "audio_noise",), + input_params=("height", "width", "num_frames", "seed", "rand_device", "frame_rate", "use_two_stage_pipeline"), + output_params=("video_noise", "audio_noise", "video_positions", "audio_positions", "video_latent_shape", "audio_latent_shape") ) def process_stage(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0): @@ -471,7 +486,6 @@ class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit): if pipe.scheduler.training: return {"video_latents": input_latents, "input_latents": input_latents} else: - # TODO: implement video-to-video raise NotImplementedError("Video-to-video not 
implemented yet.") class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit): @@ -495,14 +509,13 @@ class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit): if pipe.scheduler.training: return {"audio_latents": audio_input_latents, "audio_input_latents": audio_input_latents, "audio_positions": audio_positions, "audio_latent_shape": audio_latent_shape} else: - # TODO: implement video-to-video - raise NotImplementedError("Video-to-video not implemented yet.") + raise NotImplementedError("Audio-to-video not supported.") class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit): def __init__(self): super().__init__( input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "num_frames", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "use_two_stage_pipeline"), - output_params=("video_latents"), + output_params=("video_latents", "denoise_mask_video", "input_latents_video", "stage2_input_latents"), onload_model_names=("video_vae_encoder") ) @@ -537,6 +550,54 @@ class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit): return output_dicts +class LTX2AudioVideoUnit_InContextVideoEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("in_context_videos", "height", "width", "num_frames", "frame_rate", "in_context_downsample_factor", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "use_two_stage_pipeline"), + output_params=("in_context_video_latents", "in_context_video_positions"), + onload_model_names=("video_vae_encoder") + ) + + def check_in_context_video(self, pipe, in_context_video, height, width, num_frames, in_context_downsample_factor, use_two_stage_pipeline=True): + if in_context_video is None or len(in_context_video) == 0: + raise ValueError("In-context video is None or empty.") + in_context_video = in_context_video[:num_frames] + expected_height = height // in_context_downsample_factor // 2 if use_two_stage_pipeline else height // in_context_downsample_factor 
+ expected_width = width // in_context_downsample_factor // 2 if use_two_stage_pipeline else width // in_context_downsample_factor + current_h, current_w, current_f = in_context_video[0].size[1], in_context_video[0].size[0], len(in_context_video) + h, w, f = pipe.check_resize_height_width(expected_height, expected_width, current_f) + if current_h != h or current_w != w: + in_context_video = [img.resize((w, h)) for img in in_context_video] + if current_f != f: + # pad black frames at the end + in_context_video = in_context_video + [Image.new("RGB", (w, h), (0, 0, 0))] * (f - current_f) + return in_context_video + + def process(self, pipe: LTX2AudioVideoPipeline, in_context_videos, height, width, num_frames, frame_rate, in_context_downsample_factor, tiled, tile_size_in_pixels, tile_overlap_in_pixels, use_two_stage_pipeline=True): + if in_context_videos is None or len(in_context_videos) == 0: + return {} + else: + pipe.load_models_to_device(self.onload_model_names) + latents, positions = [], [] + for in_context_video in in_context_videos: + in_context_video = self.check_in_context_video(pipe, in_context_video, height, width, num_frames, in_context_downsample_factor, use_two_stage_pipeline) + in_context_video = pipe.preprocess_video(in_context_video) + in_context_latents = pipe.video_vae_encoder.encode(in_context_video, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(dtype=pipe.torch_dtype, device=pipe.device) + + latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=VideoLatentShape.from_torch_shape(in_context_latents.shape), device=pipe.device) + video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, True).float() + video_positions[:, 0, ...] = video_positions[:, 0, ...] / frame_rate + video_positions[:, 1, ...] *= in_context_downsample_factor # height axis + video_positions[:, 2, ...] 
*= in_context_downsample_factor # width axis + video_positions = video_positions.to(pipe.torch_dtype) + + latents.append(in_context_latents) + positions.append(video_positions) + latents = torch.cat(latents, dim=1) + positions = torch.cat(positions, dim=1) + return {"in_context_video_latents": latents, "in_context_video_positions": positions} + + def model_fn_ltx2( dit: LTXModel, video_latents=None, @@ -549,6 +610,8 @@ def model_fn_ltx2( audio_patchifier=None, timestep=None, denoise_mask_video=None, + in_context_video_latents=None, + in_context_video_positions=None, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False, **kwargs, @@ -558,16 +621,25 @@ def model_fn_ltx2( # patchify b, c_v, f, h, w = video_latents.shape video_latents = video_patchifier.patchify(video_latents) + seq_len_video = video_latents.shape[1] video_timesteps = timestep.repeat(1, video_latents.shape[1], 1) if denoise_mask_video is not None: video_timesteps = video_patchifier.patchify(denoise_mask_video) * video_timesteps + + if in_context_video_latents is not None: + in_context_video_latents = video_patchifier.patchify(in_context_video_latents) + in_context_video_timesteps = timestep.repeat(1, in_context_video_latents.shape[1], 1) * 0. 
+ video_latents = torch.cat([video_latents, in_context_video_latents], dim=1) + video_positions = torch.cat([video_positions, in_context_video_positions], dim=2) + video_timesteps = torch.cat([video_timesteps, in_context_video_timesteps], dim=1) + if audio_latents is not None: _, c_a, _, mel_bins = audio_latents.shape audio_latents = audio_patchifier.patchify(audio_latents) audio_timesteps = timestep.repeat(1, audio_latents.shape[1], 1) else: audio_timesteps = None - #TODO: support gradient checkpointing in training + vx, ax = dit( video_latents=video_latents, video_positions=video_positions, @@ -580,6 +652,8 @@ def model_fn_ltx2( use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, ) + + vx = vx[:, :seq_len_video, ...] # unpatchify vx = video_patchifier.unpatchify_video(vx, f, h, w) ax = audio_patchifier.unpatchify_audio(ax, c_a, mel_bins) if ax is not None else None diff --git a/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py b/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py index bd86b34..0465803 100644 --- a/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py +++ b/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py @@ -46,7 +46,6 @@ negative_prompt = ( "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." 
) height, width, num_frames = 512 * 2, 768 * 2, 121 -height, width, num_frames = 512 * 2, 768 * 2, 121 dataset_snapshot_download( dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Detailer.py b/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Detailer.py new file mode 100644 index 0000000..687e216 --- /dev/null +++ b/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Detailer.py @@ -0,0 +1,77 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 +from diffsynth.utils.data import VideoData +from modelscope import dataset_snapshot_download + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", 
origin_file_pattern="video_vae_encoder.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), +) +pipe.load_lora(pipe.dit, ModelConfig(model_id="Lightricks/LTX-2-19b-IC-LoRA-Detailer", origin_file_pattern="ltx-2-19b-ic-lora-detailer.safetensors")) +dataset_snapshot_download("DiffSynth-Studio/example_video_dataset", allow_file_pattern="ltx2/*", local_dir="data/example_video_dataset") + +prompt = "[VISUAL]:Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other. [SOUNDS]:the sound of two cats boxing" +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI 
artifacts." +) +height, width, num_frames = 512 * 2, 768 * 2, 121 +ref_scale_factor = 1 +frame_rate = 24 +# the frame rate of the video should preferably be the same as that of the reference video +# the spatial resolution of the first frame should be the resolution of stage 1 video generation divided by ref_scale_factor +input_video = VideoData("data/example_video_dataset/ltx2/video1.mp4", height=height // ref_scale_factor // 2, width=width // ref_scale_factor // 2) +input_video = input_video.raw_data() +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + frame_rate=frame_rate, + in_context_videos=[input_video], + in_context_downsample_factor=ref_scale_factor, + tiled=True, + use_two_stage_pipeline=True, + clear_lora_before_state_two=True, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_twostage_iclora.mp4', + fps=frame_rate, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Union-Control.py b/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Union-Control.py new file mode 100644 index 0000000..3021306 --- /dev/null +++ b/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Union-Control.py @@ -0,0 +1,77 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 +from diffsynth.utils.data import VideoData +from modelscope import dataset_snapshot_download + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + 
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), +) +pipe.load_lora(pipe.dit, ModelConfig(model_id="Lightricks/LTX-2-19b-IC-LoRA-Union-Control", origin_file_pattern="ltx-2-19b-ic-lora-union-control-ref0.5.safetensors")) +dataset_snapshot_download("DiffSynth-Studio/example_video_dataset", allow_file_pattern="ltx2/*", local_dir="data/example_video_dataset") + +prompt = "[VISUAL]:Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other. 
[SOUNDS]:the sound of two cats boxing" +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." 
+) +height, width, num_frames = 512 * 2, 768 * 2, 121 +ref_scale_factor = 2 +frame_rate = 24 +# the frame rate of the video should better be the same with the reference video +# the spatial resolution of the first frame should be the resolution of stage 1 video generation divided by ref_scale_factor +input_video = VideoData("data/example_video_dataset/ltx2/depth_video.mp4", height=height // ref_scale_factor // 2, width=width // ref_scale_factor // 2) +input_video = input_video.raw_data() +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + frame_rate=frame_rate, + in_context_videos=[input_video], + in_context_downsample_factor=ref_scale_factor, + tiled=True, + use_two_stage_pipeline=True, + clear_lora_before_state_two=True, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_twostage_iclora.mp4', + fps=frame_rate, + audio_sample_rate=24000, +) From 8b9a094c1b887bca797bf83ad3f576f54842281b Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Fri, 27 Feb 2026 18:43:53 +0800 Subject: [PATCH 14/23] ltx iclora train --- diffsynth/diffusion/base_pipeline.py | 11 ++- diffsynth/pipelines/ltx2_audio_video.py | 2 +- .../LTX-2-T2AV-IC-LoRA-Detailer.py | 77 +++++++++++++++++++ .../LTX-2-T2AV-IC-LoRA-Union-Control.py | 77 +++++++++++++++++++ .../lora/LTX-2-T2AV-IC-LoRA-splited.sh | 39 ++++++++++ examples/ltx2/model_training/train.py | 23 +++--- .../validate_lora/LTX-2-T2AV-IC-LoRA.py | 56 ++++++++++++++ 7 files changed, 271 insertions(+), 14 deletions(-) create mode 100644 examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Detailer.py create mode 100644 examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Union-Control.py create mode 100644 examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh create mode 100644 examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py diff --git a/diffsynth/diffusion/base_pipeline.py 
b/diffsynth/diffusion/base_pipeline.py index 7d41cac..4d046ab 100644 --- a/diffsynth/diffusion/base_pipeline.py +++ b/diffsynth/diffusion/base_pipeline.py @@ -94,20 +94,23 @@ class BasePipeline(torch.nn.Module): return self - def check_resize_height_width(self, height, width, num_frames=None): + def check_resize_height_width(self, height, width, num_frames=None, verbose=1): # Shape check if height % self.height_division_factor != 0: height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor - print(f"height % {self.height_division_factor} != 0. We round it up to {height}.") + if verbose > 0: + print(f"height % {self.height_division_factor} != 0. We round it up to {height}.") if width % self.width_division_factor != 0: width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor - print(f"width % {self.width_division_factor} != 0. We round it up to {width}.") + if verbose > 0: + print(f"width % {self.width_division_factor} != 0. We round it up to {width}.") if num_frames is None: return height, width else: if num_frames % self.time_division_factor != self.time_division_remainder: num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder - print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.") + if verbose > 0: + print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. 
We round it up to {num_frames}.") return height, width, num_frames diff --git a/diffsynth/pipelines/ltx2_audio_video.py b/diffsynth/pipelines/ltx2_audio_video.py index c662016..f18d785 100644 --- a/diffsynth/pipelines/ltx2_audio_video.py +++ b/diffsynth/pipelines/ltx2_audio_video.py @@ -565,7 +565,7 @@ class LTX2AudioVideoUnit_InContextVideoEmbedder(PipelineUnit): expected_height = height // in_context_downsample_factor // 2 if use_two_stage_pipeline else height // in_context_downsample_factor expected_width = width // in_context_downsample_factor // 2 if use_two_stage_pipeline else width // in_context_downsample_factor current_h, current_w, current_f = in_context_video[0].size[1], in_context_video[0].size[0], len(in_context_video) - h, w, f = pipe.check_resize_height_width(expected_height, expected_width, current_f) + h, w, f = pipe.check_resize_height_width(expected_height, expected_width, current_f, verbose=0) if current_h != h or current_w != w: in_context_video = [img.resize((w, h)) for img in in_context_video] if current_f != f: diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Detailer.py b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Detailer.py new file mode 100644 index 0000000..eccddd7 --- /dev/null +++ b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Detailer.py @@ -0,0 +1,77 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 +from diffsynth.utils.data import VideoData +from modelscope import dataset_snapshot_download + +vram_config = { + "offload_dtype": torch.float8_e5m2, + "offload_device": "cpu", + "onload_dtype": torch.float8_e5m2, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e5m2, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + 
device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), +) +pipe.load_lora(pipe.dit, ModelConfig(model_id="Lightricks/LTX-2-19b-IC-LoRA-Detailer", origin_file_pattern="ltx-2-19b-ic-lora-detailer.safetensors")) +dataset_snapshot_download("DiffSynth-Studio/example_video_dataset", allow_file_pattern="ltx2/*", local_dir="data/example_video_dataset") + +prompt = "[VISUAL]:Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other. 
[SOUNDS]:the sound of two cats boxing" +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." 
+) +height, width, num_frames = 512 * 2, 768 * 2, 121 +ref_scale_factor = 1 +frame_rate = 24 +# the frame rate of the video should better be the same with the reference video +# the spatial resolution of the first frame should be the resolution of stage 1 video generation divided by ref_scale_factor +input_video = VideoData("data/example_video_dataset/ltx2/video1.mp4", height=height // ref_scale_factor // 2, width=width // ref_scale_factor // 2) +input_video = input_video.raw_data() +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + frame_rate=frame_rate, + in_context_videos=[input_video], + in_context_downsample_factor=ref_scale_factor, + tiled=True, + use_two_stage_pipeline=True, + clear_lora_before_state_two=True, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_twostage_iclora.mp4', + fps=frame_rate, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Union-Control.py b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Union-Control.py new file mode 100644 index 0000000..37515d8 --- /dev/null +++ b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Union-Control.py @@ -0,0 +1,77 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 +from diffsynth.utils.data import VideoData +from modelscope import dataset_snapshot_download + +vram_config = { + "offload_dtype": torch.float8_e5m2, + "offload_device": "cpu", + "onload_dtype": torch.float8_e5m2, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e5m2, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + 
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), +) +pipe.load_lora(pipe.dit, ModelConfig(model_id="Lightricks/LTX-2-19b-IC-LoRA-Union-Control", origin_file_pattern="ltx-2-19b-ic-lora-union-control-ref0.5.safetensors")) +dataset_snapshot_download("DiffSynth-Studio/example_video_dataset", allow_file_pattern="ltx2/*", local_dir="data/example_video_dataset") + +prompt = "[VISUAL]:Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other. 
[SOUNDS]:the sound of two cats boxing" +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." 
+) +height, width, num_frames = 512 * 2, 768 * 2, 121 +ref_scale_factor = 2 +frame_rate = 24 +# the frame rate of the video should preferably be the same as that of the reference video +# the spatial resolution of the first frame should be the resolution of stage 1 video generation divided by ref_scale_factor +input_video = VideoData("data/example_video_dataset/ltx2/depth_video.mp4", height=height // ref_scale_factor // 2, width=width // ref_scale_factor // 2) +input_video = input_video.raw_data() +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + frame_rate=frame_rate, + in_context_videos=[input_video], + in_context_downsample_factor=ref_scale_factor, + tiled=True, + use_two_stage_pipeline=True, + clear_lora_before_state_two=True, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_twostage_iclora.mp4', + fps=frame_rate, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh b/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh new file mode 100644 index 0000000..c4fdd86 --- /dev/null +++ b/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh @@ -0,0 +1,39 @@ +# Split Training +accelerate launch examples/ltx2/model_training/train.py \ + --dataset_base_path data/example_video_dataset/ltx2 \ + --dataset_metadata_path data/example_video_dataset/ltx2_t2av_iclora.json \ + --data_file_keys "video,input_audio,in_context_videos" \ + --extra_inputs "input_audio,in_context_videos,in_context_downsample_factor,frame_rate" \ + --height 512 \ + --width 768 \ + --num_frames 81 \ + --dataset_repeat 1 \ + --model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors" \ + 
--learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/LTX2-T2AV-IC-LoRA-splited-cache" \ + --lora_base_model "dit" \ + --lora_target_modules "to_k,to_q,to_v,to_out.0" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --task "sft:data_process" + +accelerate launch examples/ltx2/model_training/train.py \ + --dataset_base_path ./models/train/LTX2-T2AV-IC-LoRA-splited-cache \ + --data_file_keys "video,input_audio,in_context_videos" \ + --extra_inputs "input_audio,in_context_videos,in_context_downsample_factor,frame_rate" \ + --height 512 \ + --width 768 \ + --num_frames 81 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:transformer.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/LTX2-T2AV-IC-LoRA" \ + --lora_base_model "dit" \ + --lora_target_modules "to_k,to_q,to_v,to_out.0" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --task "sft:train" diff --git a/examples/ltx2/model_training/train.py b/examples/ltx2/model_training/train.py index 26a2925..3eb023a 100644 --- a/examples/ltx2/model_training/train.py +++ b/examples/ltx2/model_training/train.py @@ -1,7 +1,6 @@ import torch, os, argparse, accelerate, warnings from diffsynth.core import UnifiedDataset -from diffsynth.core.data.operators import LoadAudioWithTorchaudio, ToAbsolutePath -from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from diffsynth.core.data.operators import LoadAudioWithTorchaudio, ToAbsolutePath, RouteByType, SequencialProcess from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig from diffsynth.diffusion import * os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -69,6 +68,7 @@ class LTX2TrainingModule(DiffusionTrainingModule): "height": data["video"][0].size[1], "width": data["video"][0].size[0], "num_frames": len(data["video"]), + "frame_rate": 
data.get("frame_rate", 24), # Please do not modify the following parameters # unless you clearly know what this will cause. "cfg_scale": 1, @@ -108,12 +108,7 @@ if __name__ == "__main__": gradient_accumulation_steps=args.gradient_accumulation_steps, kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)], ) - dataset = UnifiedDataset( - base_path=args.dataset_base_path, - metadata_path=args.dataset_metadata_path, - repeat=args.dataset_repeat, - data_file_keys=args.data_file_keys.split(","), - main_data_operator=UnifiedDataset.default_video_operator( + video_processor = UnifiedDataset.default_video_operator( base_path=args.dataset_base_path, max_pixels=args.max_pixels, height=args.height, @@ -123,9 +118,19 @@ if __name__ == "__main__": num_frames=args.num_frames, time_division_factor=8, time_division_remainder=1, - ), + ) + dataset = UnifiedDataset( + base_path=args.dataset_base_path, + metadata_path=args.dataset_metadata_path, + repeat=args.dataset_repeat, + data_file_keys=args.data_file_keys.split(","), + main_data_operator=video_processor, special_operator_map={ "input_audio": ToAbsolutePath(args.dataset_base_path) >> LoadAudioWithTorchaudio(duration=float(args.num_frames) / float(args.frame_rate)), + "in_context_videos": RouteByType(operator_map=[ + (str, video_processor), + (list, SequencialProcess(video_processor)), + ]), } ) model = LTX2TrainingModule( diff --git a/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py b/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py new file mode 100644 index 0000000..9d793e0 --- /dev/null +++ b/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py @@ -0,0 +1,56 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 +from diffsynth.utils.data import VideoData + +vram_config = { + "offload_dtype": torch.bfloat16, + 
"offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), +) +pipe.load_lora(pipe.dit, "./models/train/LTX2-T2AV-IC-LoRA/epoch-4.safetensors") +prompt = "[VISUAL]:Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other. 
[SOUNDS]:the sound of two cats boxing" +negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." 
+height, width, num_frames = 512, 768, 81 +ref_scale_factor = 2 +frame_rate = 24 +input_video = VideoData("data/examples/wan/depth_video.mp4", height=height // ref_scale_factor // 2, width=width // ref_scale_factor // 2) +input_video = input_video.raw_data() +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + frame_rate=frame_rate, + tiled=True, + in_context_videos=[input_video], + in_context_downsample_factor=ref_scale_factor, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_onestage_ic.mp4', + fps=frame_rate, + audio_sample_rate=24000, +) From 5ca74923e830af7024f12143f5e406b02cae2262 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Sat, 28 Feb 2026 10:56:08 +0800 Subject: [PATCH 15/23] add readme --- README.md | 2 ++ README_zh.md | 2 ++ docs/en/Model_Details/LTX-2.md | 2 ++ docs/zh/Model_Details/LTX-2.md | 2 ++ 4 files changed, 8 insertions(+) diff --git a/README.md b/README.md index 5fe2017..15d6597 100644 --- a/README.md +++ b/README.md @@ -645,6 +645,8 @@ Example code for LTX-2 is available at: [/examples/ltx2/](/examples/ltx2/) | Model ID | Extra Args | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation | |-|-|-|-|-|-|-|-| |[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py)| 
+|[Lightricks/LTX-2-19b-IC-LoRA-Union-Control](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Union-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Union-Control.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Union-Control.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)| +|[Lightricks/LTX-2-19b-IC-LoRA-Detailer](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Detailer)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Detailer.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Detailer.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)| |[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-| |[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-| |[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-| diff --git a/README_zh.md b/README_zh.md index 74843db..ec3cae7 100644 --- a/README_zh.md +++ b/README_zh.md @@ -645,6 +645,8 @@ LTX-2 的示例代码位于:[/examples/ltx2/](/examples/ltx2/) |模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| |-|-|-|-|-|-|-|-| |[Lightricks/LTX-2: 
OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py)| +|[Lightricks/LTX-2-19b-IC-LoRA-Union-Control](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Union-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Union-Control.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Union-Control.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)| +|[Lightricks/LTX-2-19b-IC-LoRA-Detailer](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Detailer)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Detailer.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Detailer.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)| |[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-| |[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-| |[Lightricks/LTX-2: 
OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-| diff --git a/docs/en/Model_Details/LTX-2.md b/docs/en/Model_Details/LTX-2.md index 68ab351..0123dfe 100644 --- a/docs/en/Model_Details/LTX-2.md +++ b/docs/en/Model_Details/LTX-2.md @@ -112,6 +112,8 @@ write_video_audio_ltx2( |Model ID|Additional Parameters|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training| |-|-|-|-|-|-|-|-| |[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py)| 
+|[Lightricks/LTX-2-19b-IC-LoRA-Union-Control](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Union-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Union-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Union-Control.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)| +|[Lightricks/LTX-2-19b-IC-LoRA-Detailer](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Detailer)|`in_context_videos`,`in_context_downsample_factor`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Detailer.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Detailer.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)| |[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-| |[Lightricks/LTX-2: 
DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-| |[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-| diff --git a/docs/zh/Model_Details/LTX-2.md b/docs/zh/Model_Details/LTX-2.md index 558de9d..0c23adb 100644 --- a/docs/zh/Model_Details/LTX-2.md +++ b/docs/zh/Model_Details/LTX-2.md @@ -112,6 +112,8 @@ write_video_audio_ltx2( |模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| |-|-|-|-|-|-|-|-| |[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py)| 
+|[Lightricks/LTX-2-19b-IC-LoRA-Union-Control](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Union-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Union-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Union-Control.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)| +|[Lightricks/LTX-2-19b-IC-LoRA-Detailer](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Detailer)|`in_context_videos`,`in_context_downsample_factor`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Detailer.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Detailer.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)| |[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-| |[Lightricks/LTX-2: 
DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-| |[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-| From 1a380a6b62fc79d6772f6bfd7079617aeb435321 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Sat, 28 Feb 2026 11:09:10 +0800 Subject: [PATCH 16/23] minor fix --- diffsynth/pipelines/ltx2_audio_video.py | 4 ++-- .../ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/diffsynth/pipelines/ltx2_audio_video.py b/diffsynth/pipelines/ltx2_audio_video.py index f18d785..2e0b2cd 100644 --- a/diffsynth/pipelines/ltx2_audio_video.py +++ b/diffsynth/pipelines/ltx2_audio_video.py @@ -122,8 +122,8 @@ class LTX2AudioVideoPipeline(BasePipeline): inputs_shared["input_images_strength"], latent.clone()) inputs_shared.update({"input_latents_video": initial_latents, "denoise_mask_video": denoise_mask_video}) # remove in-context video control in stage 2 - inputs_shared.pop("in_context_video_latents") - inputs_shared.pop("in_context_video_positions") + inputs_shared.pop("in_context_video_latents", None) + inputs_shared.pop("in_context_video_positions", None) # initialize latents for stage 2 inputs_shared["video_latents"] = self.scheduler.sigmas[0] * denoise_mask_video * inputs_shared[ diff --git a/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py 
b/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py index 9d793e0..d6eda1a 100644 --- a/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py +++ b/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py @@ -33,7 +33,7 @@ negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast height, width, num_frames = 512, 768, 81 ref_scale_factor = 2 frame_rate = 24 -input_video = VideoData("data/examples/wan/depth_video.mp4", height=height // ref_scale_factor // 2, width=width // ref_scale_factor // 2) +input_video = VideoData("data/example_video_dataset/ltx2/depth_video.mp4", height=height // ref_scale_factor // 2, width=width // ref_scale_factor // 2) input_video = input_video.raw_data() video, audio = pipe( prompt=prompt, From b3f6c3275f4884acc0a2b7cb0da873c8d360028d Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Mon, 2 Mar 2026 10:58:02 +0800 Subject: [PATCH 17/23] update ltx-2 --- docs/en/index.rst | 1 + docs/zh/index.rst | 1 + 2 files changed, 2 insertions(+) diff --git a/docs/en/index.rst b/docs/en/index.rst index ab195ef..ca38620 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -27,6 +27,7 @@ Welcome to DiffSynth-Studio's Documentation Model_Details/Qwen-Image Model_Details/FLUX2 Model_Details/Z-Image + Model_Details/LTX-2 .. toctree:: :maxdepth: 2 diff --git a/docs/zh/index.rst b/docs/zh/index.rst index 4e82d3e..d2afefc 100644 --- a/docs/zh/index.rst +++ b/docs/zh/index.rst @@ -27,6 +27,7 @@ Model_Details/Qwen-Image Model_Details/FLUX2 Model_Details/Z-Image + Model_Details/LTX-2 .. 
toctree:: :maxdepth: 2 From 6d671db5d250bda8458ba60e375df57ac7c1abbd Mon Sep 17 00:00:00 2001 From: Zhongjie Duan <35051019+Artiprocher@users.noreply.github.com> Date: Mon, 2 Mar 2026 18:49:02 +0800 Subject: [PATCH 18/23] Support Anima (#1317) * support Anima Co-authored-by: mi804 <1576993271@qq.com> --- README.md | 57 + README_zh.md | 57 + diffsynth/configs/model_configs.py | 18 +- .../configs/vram_management_module_maps.py | 6 + diffsynth/models/anima_dit.py | 1304 +++++++++++++++++ diffsynth/pipelines/anima_image.py | 261 ++++ .../utils/state_dict_converters/anima_dit.py | 6 + docs/en/Model_Details/Anima.md | 139 ++ docs/en/index.rst | 1 + docs/zh/Model_Details/Anima.md | 139 ++ docs/zh/index.rst | 1 + .../anima/model_inference/anima-preview.py | 19 + .../model_inference_low_vram/anima-preview.py | 30 + .../model_training/full/anima-preview.sh | 14 + .../model_training/lora/anima-preview.sh | 16 + examples/anima/model_training/train.py | 145 ++ .../validate_full/anima-preview.py | 21 + .../validate_lora/anima-preview.py | 19 + 18 files changed, 2252 insertions(+), 1 deletion(-) create mode 100644 diffsynth/models/anima_dit.py create mode 100644 diffsynth/pipelines/anima_image.py create mode 100644 diffsynth/utils/state_dict_converters/anima_dit.py create mode 100644 docs/en/Model_Details/Anima.md create mode 100644 docs/zh/Model_Details/Anima.md create mode 100644 examples/anima/model_inference/anima-preview.py create mode 100644 examples/anima/model_inference_low_vram/anima-preview.py create mode 100644 examples/anima/model_training/full/anima-preview.sh create mode 100644 examples/anima/model_training/lora/anima-preview.sh create mode 100644 examples/anima/model_training/train.py create mode 100644 examples/anima/model_training/validate_full/anima-preview.py create mode 100644 examples/anima/model_training/validate_lora/anima-preview.py diff --git a/README.md b/README.md index 15d6597..704e9ac 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,9 @@ We believe 
that a well-developed open-source code framework can lower the thresh > DiffSynth-Studio has undergone major version updates, and some old features are no longer maintained. If you need to use old features, please switch to the [last historical version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3) before the major version update. > Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand. + +- **March 2, 2026** Added support for [Anima](https://modelscope.cn/models/circlestone-labs/Anima). For details, please refer to the [documentation](docs/en/Model_Details/Anima.md). This is an interesting anime-style image generation model. We look forward to its future updates. + - **February 26, 2026** Added full and lora training support for the LTX-2 audio-video generation model. See the [documentation](/docs/en/Model_Details/LTX-2.md) for details. - **February 10, 2026** Added inference support for the LTX-2 audio-video generation model. See the [documentation](/docs/en/Model_Details/LTX-2.md) for details. Support for model training will be implemented in the future. @@ -343,6 +346,60 @@ Example code for FLUX.2 is available at: [/examples/flux2/](/examples/flux2/) +#### Anima: [/docs/en/Model_Details/Anima.md](/docs/en/Model_Details/Anima.md) + +
+ +Quick Start + +Run the following code to quickly load the [circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 8GB VRAM. + +```python +from diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": "disk", + "onload_device": "disk", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = AnimaImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors", **vram_config), + ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/text_encoders/qwen_3_06b_base.safetensors", **vram_config), + ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"), + tokenizer_t5xxl_config=ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +prompt = "Masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait." +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," +image = pipe(prompt, seed=0, num_inference_steps=50) +image.save("image.jpg") +``` + +
+ +
+ +Examples + +Example code for Anima is located at: [/examples/anima/](/examples/anima/) + +| Model ID | Inference | Low VRAM Inference | Full Training | Validation after Full Training | LoRA Training | Validation after LoRA Training | +|-|-|-|-|-|-|-| +|[circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference_low_vram/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/full/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_full/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/lora/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_lora/anima-preview.py)| + +
+ #### Qwen-Image: [/docs/en/Model_Details/Qwen-Image.md](/docs/en/Model_Details/Qwen-Image.md)
diff --git a/README_zh.md b/README_zh.md index ec3cae7..9c95503 100644 --- a/README_zh.md +++ b/README_zh.md @@ -32,6 +32,9 @@ DiffSynth 目前包括两个开源项目: > DiffSynth-Studio 经历了大版本更新,部分旧功能已停止维护,如需使用旧版功能,请切换到大版本更新前的[最后一个历史版本](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3)。 > 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 负责,因此新功能的开发进展会比较缓慢,issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。 + +- **2026年3月2日** 新增对[Anima](https://modelscope.cn/models/circlestone-labs/Anima)的支持,详见[文档](docs/zh/Model_Details/Anima.md)。这是一个有趣的动漫风格图像生成模型,我们期待其后续的模型更新。 + - **2026年2月26日** 新增对[LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2)音视频生成模型全量微调与LoRA训练支持,详见[文档](docs/zh/Model_Details/LTX-2.md)。 - **2026年2月10日** 新增对[LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2)音视频生成模型的推理支持,详见[文档](docs/zh/Model_Details/LTX-2.md),后续将推进模型训练的支持。 @@ -343,6 +346,60 @@ FLUX.2 的示例代码位于:[/examples/flux2/](/examples/flux2/)
+#### Anima: [/docs/zh/Model_Details/Anima.md](/docs/zh/Model_Details/Anima.md) + +
+ +快速开始 + +运行以下代码可以快速加载 [circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8G 显存即可运行。 + +```python +from diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": "disk", + "onload_device": "disk", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = AnimaImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors", **vram_config), + ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/text_encoders/qwen_3_06b_base.safetensors", **vram_config), + ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"), + tokenizer_t5xxl_config=ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +prompt = "Masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait." +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," +image = pipe(prompt, seed=0, num_inference_steps=50) +image.save("image.jpg") +``` + +
+ +
+ +示例代码 + +Anima 的示例代码位于:[/examples/anima/](/examples/anima/) + +|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| +|-|-|-|-|-|-|-| +|[circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference_low_vram/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/full/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_full/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/lora/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_lora/anima-preview.py)| + +
+ #### Qwen-Image: [/docs/zh/Model_Details/Qwen-Image.md](/docs/zh/Model_Details/Qwen-Image.md)
diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index fbca133..f9fa595 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -719,4 +719,20 @@ ltx2_series = [ "model_class": "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler", }, ] -MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series +anima_series = [ + { + # Example: ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors") + "model_hash": "a9995952c2d8e63cf82e115005eb61b9", + "model_name": "z_image_text_encoder", + "model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder", + "extra_kwargs": {"model_size": "0.6B"}, + }, + { + # Example: ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors") + "model_hash": "417673936471e79e31ed4d186d7a3f4a", + "model_name": "anima_dit", + "model_class": "diffsynth.models.anima_dit.AnimaDiT", + "state_dict_converter": "diffsynth.utils.state_dict_converters.anima_dit.AnimaDiTStateDictConverter", + } +] +MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series + anima_series diff --git a/diffsynth/configs/vram_management_module_maps.py b/diffsynth/configs/vram_management_module_maps.py index 0f360ef..d86f5fa 100644 --- a/diffsynth/configs/vram_management_module_maps.py +++ b/diffsynth/configs/vram_management_module_maps.py @@ -243,4 +243,10 @@ VRAM_MANAGEMENT_MODULE_MAPS = { "transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", "transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule", }, + "diffsynth.models.anima_dit.AnimaDiT": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.LayerNorm": 
"diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + }, } diff --git a/diffsynth/models/anima_dit.py b/diffsynth/models/anima_dit.py new file mode 100644 index 0000000..dbd1407 --- /dev/null +++ b/diffsynth/models/anima_dit.py @@ -0,0 +1,1304 @@ +# original code from: comfy/ldm/cosmos/predict2.py + +import torch +from torch import nn +from einops import rearrange, repeat +from einops.layers.torch import Rearrange +import logging +from typing import Callable, Optional, Tuple, List +import math +from torchvision import transforms +from ..core.attention import attention_forward +from ..core.gradient import gradient_checkpoint_forward + + +class VideoPositionEmb(nn.Module): + def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor], device=None, dtype=None) -> torch.Tensor: + """ + It delegates the embedding generation to generate_embeddings function. + """ + B_T_H_W_C = x_B_T_H_W_C.shape + embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device, dtype=dtype) + + return embeddings + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None): + raise NotImplementedError + + +def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0) -> torch.Tensor: + """ + Normalizes the input tensor along specified dimensions such that the average square norm of elements is adjusted. + + Args: + x (torch.Tensor): The input tensor to normalize. + dim (list, optional): The dimensions over which to normalize. If None, normalizes over all dimensions except the first. + eps (float, optional): A small constant to ensure numerical stability during division. + + Returns: + torch.Tensor: The normalized tensor. 
+ """ + if dim is None: + dim = list(range(1, x.ndim)) + norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32) + norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel())) + return x / norm.to(x.dtype) + + +class LearnablePosEmbAxis(VideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + interpolation: str, + model_channels: int, + len_h: int, + len_w: int, + len_t: int, + device=None, + dtype=None, + **kwargs, + ): + """ + Args: + interpolation (str): we curretly only support "crop", ideally when we need extrapolation capacity, we should adjust frequency or other more advanced methods. they are not implemented yet. + """ + del kwargs # unused + super().__init__() + self.interpolation = interpolation + assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}" + + self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device, dtype=dtype)) + self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype)) + self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype)) + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None, dtype=None) -> torch.Tensor: + B, T, H, W, _ = B_T_H_W_C + if self.interpolation == "crop": + emb_h_H = self.pos_emb_h[:H].to(device=device, dtype=dtype) + emb_w_W = self.pos_emb_w[:W].to(device=device, dtype=dtype) + emb_t_T = self.pos_emb_t[:T].to(device=device, dtype=dtype) + emb = ( + repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W) + + repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W) + + repeat(emb_w_W, "w d-> b t h w d", b=B, t=T, h=H) + ) + assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}" + else: + raise ValueError(f"Unknown interpolation method {self.interpolation}") + + return normalize(emb, dim=-1, eps=1e-6) + + +class VideoRopePosition3DEmb(VideoPositionEmb): + def __init__( + 
self, + *, # enforce keyword arguments + head_dim: int, + len_h: int, + len_w: int, + len_t: int, + base_fps: int = 24, + h_extrapolation_ratio: float = 1.0, + w_extrapolation_ratio: float = 1.0, + t_extrapolation_ratio: float = 1.0, + enable_fps_modulation: bool = True, + device=None, + **kwargs, # used for compatibility with other positional embeddings; unused in this class + ): + del kwargs + super().__init__() + self.base_fps = base_fps + self.max_h = len_h + self.max_w = len_w + self.enable_fps_modulation = enable_fps_modulation + + dim = head_dim + dim_h = dim // 6 * 2 + dim_w = dim_h + dim_t = dim - 2 * dim_h + assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}" + self.register_buffer( + "dim_spatial_range", + torch.arange(0, dim_h, 2, device=device)[: (dim_h // 2)].float() / dim_h, + persistent=False, + ) + self.register_buffer( + "dim_temporal_range", + torch.arange(0, dim_t, 2, device=device)[: (dim_t // 2)].float() / dim_t, + persistent=False, + ) + + self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2)) + self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2)) + self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2)) + + def generate_embeddings( + self, + B_T_H_W_C: torch.Size, + fps: Optional[torch.Tensor] = None, + h_ntk_factor: Optional[float] = None, + w_ntk_factor: Optional[float] = None, + t_ntk_factor: Optional[float] = None, + device=None, + dtype=None, + ): + """ + Generate embeddings for the given input size. + + Args: + B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels). + fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None. + h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor. + w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor. + t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor. 
+ + Returns: + Not specified in the original code snippet. + """ + h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor + w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor + t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor + + h_theta = 10000.0 * h_ntk_factor + w_theta = 10000.0 * w_ntk_factor + t_theta = 10000.0 * t_ntk_factor + + h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range.to(device=device)) + w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range.to(device=device)) + temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range.to(device=device)) + + B, T, H, W, _ = B_T_H_W_C + seq = torch.arange(max(H, W, T), dtype=torch.float, device=device) + uniform_fps = (fps is None) or isinstance(fps, (int, float)) or (fps.min() == fps.max()) + assert ( + uniform_fps or B == 1 or T == 1 + ), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1" + half_emb_h = torch.outer(seq[:H].to(device=device), h_spatial_freqs) + half_emb_w = torch.outer(seq[:W].to(device=device), w_spatial_freqs) + + # apply sequence scaling in temporal dimension + if fps is None or self.enable_fps_modulation is False: # image case + half_emb_t = torch.outer(seq[:T].to(device=device), temporal_freqs) + else: + half_emb_t = torch.outer(seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs) + + half_emb_h = torch.stack([torch.cos(half_emb_h), -torch.sin(half_emb_h), torch.sin(half_emb_h), torch.cos(half_emb_h)], dim=-1) + half_emb_w = torch.stack([torch.cos(half_emb_w), -torch.sin(half_emb_w), torch.sin(half_emb_w), torch.cos(half_emb_w)], dim=-1) + half_emb_t = torch.stack([torch.cos(half_emb_t), -torch.sin(half_emb_t), torch.sin(half_emb_t), torch.cos(half_emb_t)], dim=-1) + + em_T_H_W_D = torch.cat( + [ + repeat(half_emb_t, "t d x -> t h w d x", h=H, w=W), + repeat(half_emb_h, "h d x -> t h w d x", t=T, w=W), + repeat(half_emb_w, "w d x -> t h w d x", t=T, 
h=H), + ] + , dim=-2, + ) + + return rearrange(em_T_H_W_D, "t h w d (i j) -> (t h w) d i j", i=2, j=2).float() + + +def apply_rotary_pos_emb( + t: torch.Tensor, + freqs: torch.Tensor, +) -> torch.Tensor: + t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float() + t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1] + t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t) + return t_out + + +# ---------------------- Feed Forward Network ----------------------- +class GPT2FeedForward(nn.Module): + def __init__(self, d_model: int, d_ff: int, device=None, dtype=None, operations=None) -> None: + super().__init__() + self.activation = nn.GELU() + self.layer1 = operations.Linear(d_model, d_ff, bias=False, device=device, dtype=dtype) + self.layer2 = operations.Linear(d_ff, d_model, bias=False, device=device, dtype=dtype) + + self._layer_id = None + self._dim = d_model + self._hidden_dim = d_ff + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.layer1(x) + + x = self.activation(x) + x = self.layer2(x) + return x + + +def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor: + """Computes multi-head attention using PyTorch's native implementation. + + This function provides a PyTorch backend alternative to Transformer Engine's attention operation. + It rearranges the input tensors to match PyTorch's expected format, computes scaled dot-product + attention, and rearranges the output back to the original format. 
+ + The input tensor names use the following dimension conventions: + + - B: batch size + - S: sequence length + - H: number of attention heads + - D: head dimension + + Args: + q_B_S_H_D: Query tensor with shape (batch, seq_len, n_heads, head_dim) + k_B_S_H_D: Key tensor with shape (batch, seq_len, n_heads, head_dim) + v_B_S_H_D: Value tensor with shape (batch, seq_len, n_heads, head_dim) + + Returns: + Attention output tensor with shape (batch, seq_len, n_heads * head_dim) + """ + in_q_shape = q_B_S_H_D.shape + in_k_shape = k_B_S_H_D.shape + q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1]) + k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) + v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) + return attention_forward(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, out_pattern="b s (n d)") + + +class Attention(nn.Module): + """ + A flexible attention module supporting both self-attention and cross-attention mechanisms. + + This module implements a multi-head attention layer that can operate in either self-attention + or cross-attention mode. The mode is determined by whether a context dimension is provided. + The implementation uses scaled dot-product attention and supports optional bias terms and + dropout regularization. + + Args: + query_dim (int): The dimensionality of the query vectors. + context_dim (int, optional): The dimensionality of the context (key/value) vectors. + If None, the module operates in self-attention mode using query_dim. Default: None + n_heads (int, optional): Number of attention heads for multi-head attention. Default: 8 + head_dim (int, optional): The dimension of each attention head. Default: 64 + dropout (float, optional): Dropout probability applied to the output. Default: 0.0 + qkv_format (str, optional): Format specification for QKV tensors. 
Default: "bshd" + backend (str, optional): Backend to use for the attention operation. Default: "transformer_engine" + + Examples: + >>> # Self-attention with 512 dimensions and 8 heads + >>> self_attn = Attention(query_dim=512) + >>> x = torch.randn(32, 16, 512) # (batch_size, seq_len, dim) + >>> out = self_attn(x) # (32, 16, 512) + + >>> # Cross-attention + >>> cross_attn = Attention(query_dim=512, context_dim=256) + >>> query = torch.randn(32, 16, 512) + >>> context = torch.randn(32, 8, 256) + >>> out = cross_attn(query, context) # (32, 16, 512) + """ + + def __init__( + self, + query_dim: int, + context_dim: Optional[int] = None, + n_heads: int = 8, + head_dim: int = 64, + dropout: float = 0.0, + device=None, + dtype=None, + operations=None, + ) -> None: + super().__init__() + logging.debug( + f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " + f"{n_heads} heads with a dimension of {head_dim}." + ) + self.is_selfattn = context_dim is None # self attention + + context_dim = query_dim if context_dim is None else context_dim + inner_dim = head_dim * n_heads + + self.n_heads = n_heads + self.head_dim = head_dim + self.query_dim = query_dim + self.context_dim = context_dim + + self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype) + + self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype) + + self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.v_norm = nn.Identity() + + self.output_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype) + self.output_dropout = nn.Dropout(dropout) if dropout > 1e-4 else nn.Identity() + + self.attn_op = torch_attention_op + + self._query_dim = 
query_dim + self._context_dim = context_dim + self._inner_dim = inner_dim + + def compute_qkv( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + rope_emb: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q = self.q_proj(x) + context = x if context is None else context + k = self.k_proj(context) + v = self.v_proj(context) + q, k, v = map( + lambda t: rearrange(t, "b ... (h d) -> b ... h d", h=self.n_heads, d=self.head_dim), + (q, k, v), + ) + + def apply_norm_and_rotary_pos_emb( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, rope_emb: Optional[torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q = self.q_norm(q) + k = self.k_norm(k) + v = self.v_norm(v) + if self.is_selfattn and rope_emb is not None: # only apply to self-attention! + q = apply_rotary_pos_emb(q, rope_emb) + k = apply_rotary_pos_emb(k, rope_emb) + return q, k, v + + q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb) + + return q, k, v + + def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor: + result = self.attn_op(q, k, v, transformer_options=transformer_options) # [B, S, H, D] + return self.output_dropout(self.output_proj(result)) + + def forward( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + rope_emb: Optional[torch.Tensor] = None, + transformer_options: Optional[dict] = {}, + ) -> torch.Tensor: + """ + Args: + x (Tensor): The query tensor of shape [B, Mq, K] + context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None + """ + q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb) + return self.compute_attention(q, k, v, transformer_options=transformer_options) + + +class Timesteps(nn.Module): + def __init__(self, num_channels: int): + super().__init__() + self.num_channels = num_channels + + def forward(self, timesteps_B_T: torch.Tensor) -> 
torch.Tensor: + assert timesteps_B_T.ndim == 2, f"Expected 2D input, got {timesteps_B_T.ndim}" + timesteps = timesteps_B_T.flatten().float() + half_dim = self.num_channels // 2 + exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / (half_dim - 0.0) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + sin_emb = torch.sin(emb) + cos_emb = torch.cos(emb) + emb = torch.cat([cos_emb, sin_emb], dim=-1) + + return rearrange(emb, "(b t) d -> b t d", b=timesteps_B_T.shape[0], t=timesteps_B_T.shape[1]) + + +class TimestepEmbedding(nn.Module): + def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, device=None, dtype=None, operations=None): + super().__init__() + logging.debug( + f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility." + ) + self.in_dim = in_features + self.out_dim = out_features + self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, device=device, dtype=dtype) + self.activation = nn.SiLU() + self.use_adaln_lora = use_adaln_lora + if use_adaln_lora: + self.linear_2 = operations.Linear(out_features, 3 * out_features, bias=False, device=device, dtype=dtype) + else: + self.linear_2 = operations.Linear(out_features, out_features, bias=False, device=device, dtype=dtype) + + def forward(self, sample: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + emb = self.linear_1(sample) + emb = self.activation(emb) + emb = self.linear_2(emb) + + if self.use_adaln_lora: + adaln_lora_B_T_3D = emb + emb_B_T_D = sample + else: + adaln_lora_B_T_3D = None + emb_B_T_D = emb + + return emb_B_T_D, adaln_lora_B_T_3D + + +class PatchEmbed(nn.Module): + """ + PatchEmbed is a module for embedding patches from an input tensor by applying either 3D or 2D convolutional layers, + depending on the . 
This module can process inputs with temporal (video) and spatial (image) dimensions, + making it suitable for video and image processing tasks. It supports dividing the input into patches + and embedding each patch into a vector of size `out_channels`. + + Parameters: + - spatial_patch_size (int): The size of each spatial patch. + - temporal_patch_size (int): The size of each temporal patch. + - in_channels (int): Number of input channels. Default: 3. + - out_channels (int): The dimension of the embedding vector for each patch. Default: 768. + - bias (bool): If True, adds a learnable bias to the output of the convolutional layers. Default: True. + """ + + def __init__( + self, + spatial_patch_size: int, + temporal_patch_size: int, + in_channels: int = 3, + out_channels: int = 768, + device=None, dtype=None, operations=None + ): + super().__init__() + self.spatial_patch_size = spatial_patch_size + self.temporal_patch_size = temporal_patch_size + + self.proj = nn.Sequential( + Rearrange( + "b c (t r) (h m) (w n) -> b t h w (c r m n)", + r=temporal_patch_size, + m=spatial_patch_size, + n=spatial_patch_size, + ), + operations.Linear( + in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False, device=device, dtype=dtype + ), + ) + self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the PatchEmbed module. + + Parameters: + - x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where + B is the batch size, + C is the number of channels, + T is the temporal dimension, + H is the height, and + W is the width of the input. + + Returns: + - torch.Tensor: The embedded patches as a tensor, with shape b t h w c. 
+ """ + assert x.dim() == 5 + _, _, T, H, W = x.shape + assert ( + H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0 + ), f"H,W {(H, W)} should be divisible by spatial_patch_size {self.spatial_patch_size}" + assert T % self.temporal_patch_size == 0 + x = self.proj(x) + return x + + +class FinalLayer(nn.Module): + """ + The final layer of video DiT. + """ + + def __init__( + self, + hidden_size: int, + spatial_patch_size: int, + temporal_patch_size: int, + out_channels: int, + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + device=None, dtype=None, operations=None + ): + super().__init__() + self.layer_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = operations.Linear( + hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype + ) + self.hidden_size = hidden_size + self.n_adaln_chunks = 2 + self.use_adaln_lora = use_adaln_lora + self.adaln_lora_dim = adaln_lora_dim + if use_adaln_lora: + self.adaln_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(hidden_size, adaln_lora_dim, bias=False, device=device, dtype=dtype), + operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype), + ) + else: + self.adaln_modulation = nn.Sequential( + nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype) + ) + + def forward( + self, + x_B_T_H_W_D: torch.Tensor, + emb_B_T_D: torch.Tensor, + adaln_lora_B_T_3D: Optional[torch.Tensor] = None, + ): + if self.use_adaln_lora: + assert adaln_lora_B_T_3D is not None + shift_B_T_D, scale_B_T_D = ( + self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size] + ).chunk(2, dim=-1) + else: + shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1) + + shift_B_T_1_1_D, scale_B_T_1_1_D = rearrange(shift_B_T_D, "b t d -> b t 1 1 d"), 
rearrange( + scale_B_T_D, "b t d -> b t 1 1 d" + ) + + def _fn( + _x_B_T_H_W_D: torch.Tensor, + _norm_layer: nn.Module, + _scale_B_T_1_1_D: torch.Tensor, + _shift_B_T_1_1_D: torch.Tensor, + ) -> torch.Tensor: + return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D + + x_B_T_H_W_D = _fn(x_B_T_H_W_D, self.layer_norm, scale_B_T_1_1_D, shift_B_T_1_1_D) + x_B_T_H_W_O = self.linear(x_B_T_H_W_D) + return x_B_T_H_W_O + + +class Block(nn.Module): + """ + A transformer block that combines self-attention, cross-attention and MLP layers with AdaLN modulation. + Each component (self-attention, cross-attention, MLP) has its own layer normalization and AdaLN modulation. + + Parameters: + x_dim (int): Dimension of input features + context_dim (int): Dimension of context features for cross-attention + num_heads (int): Number of attention heads + mlp_ratio (float): Multiplier for MLP hidden dimension. Default: 4.0 + use_adaln_lora (bool): Whether to use AdaLN-LoRA modulation. Default: False + adaln_lora_dim (int): Hidden dimension for AdaLN-LoRA layers. Default: 256 + + The block applies the following sequence: + 1. Self-attention with AdaLN modulation + 2. Cross-attention with AdaLN modulation + 3. MLP with AdaLN modulation + + Each component uses skip connections and layer normalization. 
+ """ + + def __init__( + self, + x_dim: int, + context_dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + device=None, + dtype=None, + operations=None, + ): + super().__init__() + self.x_dim = x_dim + self.layer_norm_self_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype) + self.self_attn = Attention(x_dim, None, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations) + + self.layer_norm_cross_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype) + self.cross_attn = Attention( + x_dim, context_dim, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations + ) + + self.layer_norm_mlp = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype) + self.mlp = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), device=device, dtype=dtype, operations=operations) + + self.use_adaln_lora = use_adaln_lora + if self.use_adaln_lora: + self.adaln_modulation_self_attn = nn.Sequential( + nn.SiLU(), + operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype), + operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype), + ) + self.adaln_modulation_cross_attn = nn.Sequential( + nn.SiLU(), + operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype), + operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype), + ) + self.adaln_modulation_mlp = nn.Sequential( + nn.SiLU(), + operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype), + operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype), + ) + else: + self.adaln_modulation_self_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype)) + self.adaln_modulation_cross_attn = nn.Sequential(nn.SiLU(), 
operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype)) + self.adaln_modulation_mlp = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype)) + + def forward( + self, + x_B_T_H_W_D: torch.Tensor, + emb_B_T_D: torch.Tensor, + crossattn_emb: torch.Tensor, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_T_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + transformer_options: Optional[dict] = {}, + ) -> torch.Tensor: + residual_dtype = x_B_T_H_W_D.dtype + compute_dtype = emb_B_T_D.dtype + if extra_per_block_pos_emb is not None: + x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb + + if self.use_adaln_lora: + shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = ( + self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D + ).chunk(3, dim=-1) + shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = ( + self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D + ).chunk(3, dim=-1) + shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = ( + self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D + ).chunk(3, dim=-1) + else: + shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn( + emb_B_T_D + ).chunk(3, dim=-1) + shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn( + emb_B_T_D + ).chunk(3, dim=-1) + shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1) + + # Reshape tensors from (B, T, D) to (B, T, 1, 1, D) for broadcasting + shift_self_attn_B_T_1_1_D = rearrange(shift_self_attn_B_T_D, "b t d -> b t 1 1 d") + scale_self_attn_B_T_1_1_D = rearrange(scale_self_attn_B_T_D, "b t d -> b t 1 1 d") + gate_self_attn_B_T_1_1_D = rearrange(gate_self_attn_B_T_D, "b t d -> b t 1 1 d") + + shift_cross_attn_B_T_1_1_D = rearrange(shift_cross_attn_B_T_D, "b t d -> b t 
1 1 d") + scale_cross_attn_B_T_1_1_D = rearrange(scale_cross_attn_B_T_D, "b t d -> b t 1 1 d") + gate_cross_attn_B_T_1_1_D = rearrange(gate_cross_attn_B_T_D, "b t d -> b t 1 1 d") + + shift_mlp_B_T_1_1_D = rearrange(shift_mlp_B_T_D, "b t d -> b t 1 1 d") + scale_mlp_B_T_1_1_D = rearrange(scale_mlp_B_T_D, "b t d -> b t 1 1 d") + gate_mlp_B_T_1_1_D = rearrange(gate_mlp_B_T_D, "b t d -> b t 1 1 d") + + B, T, H, W, D = x_B_T_H_W_D.shape + + def _fn(_x_B_T_H_W_D, _norm_layer, _scale_B_T_1_1_D, _shift_B_T_1_1_D): + return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D + + normalized_x_B_T_H_W_D = _fn( + x_B_T_H_W_D, + self.layer_norm_self_attn, + scale_self_attn_B_T_1_1_D, + shift_self_attn_B_T_1_1_D, + ) + result_B_T_H_W_D = rearrange( + self.self_attn( + # normalized_x_B_T_HW_D, + rearrange(normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"), + None, + rope_emb=rope_emb_L_1_1_D, + transformer_options=transformer_options, + ), + "b (t h w) d -> b t h w d", + t=T, + h=H, + w=W, + ) + x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype) + + def _x_fn( + _x_B_T_H_W_D: torch.Tensor, + layer_norm_cross_attn: Callable, + _scale_cross_attn_B_T_1_1_D: torch.Tensor, + _shift_cross_attn_B_T_1_1_D: torch.Tensor, + transformer_options: Optional[dict] = {}, + ) -> torch.Tensor: + _normalized_x_B_T_H_W_D = _fn( + _x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D + ) + _result_B_T_H_W_D = rearrange( + self.cross_attn( + rearrange(_normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"), + crossattn_emb, + rope_emb=rope_emb_L_1_1_D, + transformer_options=transformer_options, + ), + "b (t h w) d -> b t h w d", + t=T, + h=H, + w=W, + ) + return _result_B_T_H_W_D + + result_B_T_H_W_D = _x_fn( + x_B_T_H_W_D, + self.layer_norm_cross_attn, + scale_cross_attn_B_T_1_1_D, + shift_cross_attn_B_T_1_1_D, + 
transformer_options=transformer_options, + ) + x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D + + normalized_x_B_T_H_W_D = _fn( + x_B_T_H_W_D, + self.layer_norm_mlp, + scale_mlp_B_T_1_1_D, + shift_mlp_B_T_1_1_D, + ) + result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype)) + x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype) + return x_B_T_H_W_D + + +class MiniTrainDIT(nn.Module): + """ + A clean impl of DIT that can load and reproduce the training results of the original DIT model in~(cosmos 1) + A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing. + + Args: + max_img_h (int): Maximum height of the input images. + max_img_w (int): Maximum width of the input images. + max_frames (int): Maximum number of frames in the video sequence. + in_channels (int): Number of input channels (e.g., RGB channels for color images). + out_channels (int): Number of output channels. + patch_spatial (tuple): Spatial resolution of patches for input processing. + patch_temporal (int): Temporal resolution of patches for input processing. + concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding. + model_channels (int): Base number of channels used throughout the model. + num_blocks (int): Number of transformer blocks. + num_heads (int): Number of heads in the multi-head attention layers. + mlp_ratio (float): Expansion ratio for MLP blocks. + crossattn_emb_channels (int): Number of embedding channels for cross-attention. + pos_emb_cls (str): Type of positional embeddings. + pos_emb_learnable (bool): Whether positional embeddings are learnable. + pos_emb_interpolation (str): Method for interpolating positional embeddings. + min_fps (int): Minimum frames per second. + max_fps (int): Maximum frames per second. + use_adaln_lora (bool): Whether to use AdaLN-LoRA. 
+ adaln_lora_dim (int): Dimension for AdaLN-LoRA. + rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE. + rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE. + rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE. + extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings. + extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings. + extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings. + extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings. + """ + + def __init__( + self, + max_img_h: int, + max_img_w: int, + max_frames: int, + in_channels: int, + out_channels: int, + patch_spatial: int, # tuple, + patch_temporal: int, + concat_padding_mask: bool = True, + # attention settings + model_channels: int = 768, + num_blocks: int = 10, + num_heads: int = 16, + mlp_ratio: float = 4.0, + # cross attention settings + crossattn_emb_channels: int = 1024, + # positional embedding settings + pos_emb_cls: str = "sincos", + pos_emb_learnable: bool = False, + pos_emb_interpolation: str = "crop", + min_fps: int = 1, + max_fps: int = 30, + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + rope_h_extrapolation_ratio: float = 1.0, + rope_w_extrapolation_ratio: float = 1.0, + rope_t_extrapolation_ratio: float = 1.0, + extra_per_block_abs_pos_emb: bool = False, + extra_h_extrapolation_ratio: float = 1.0, + extra_w_extrapolation_ratio: float = 1.0, + extra_t_extrapolation_ratio: float = 1.0, + rope_enable_fps_modulation: bool = True, + image_model=None, + device=None, + dtype=None, + operations=None, + ) -> None: + super().__init__() + self.dtype = dtype + self.max_img_h = max_img_h + self.max_img_w = max_img_w + self.max_frames = max_frames + self.in_channels = in_channels + self.out_channels = out_channels + self.patch_spatial = patch_spatial + self.patch_temporal = patch_temporal + 
self.num_heads = num_heads + self.num_blocks = num_blocks + self.model_channels = model_channels + self.concat_padding_mask = concat_padding_mask + # positional embedding settings + self.pos_emb_cls = pos_emb_cls + self.pos_emb_learnable = pos_emb_learnable + self.pos_emb_interpolation = pos_emb_interpolation + self.min_fps = min_fps + self.max_fps = max_fps + self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio + self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio + self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio + self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb + self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio + self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio + self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio + self.rope_enable_fps_modulation = rope_enable_fps_modulation + + self.build_pos_embed(device=device, dtype=dtype) + self.use_adaln_lora = use_adaln_lora + self.adaln_lora_dim = adaln_lora_dim + self.t_embedder = nn.Sequential( + Timesteps(model_channels), + TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, device=device, dtype=dtype, operations=operations,), + ) + + in_channels = in_channels + 1 if concat_padding_mask else in_channels + self.x_embedder = PatchEmbed( + spatial_patch_size=patch_spatial, + temporal_patch_size=patch_temporal, + in_channels=in_channels, + out_channels=model_channels, + device=device, dtype=dtype, operations=operations, + ) + + self.blocks = nn.ModuleList( + [ + Block( + x_dim=model_channels, + context_dim=crossattn_emb_channels, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + use_adaln_lora=use_adaln_lora, + adaln_lora_dim=adaln_lora_dim, + device=device, dtype=dtype, operations=operations, + ) + for _ in range(num_blocks) + ] + ) + + self.final_layer = FinalLayer( + hidden_size=self.model_channels, + spatial_patch_size=self.patch_spatial, + temporal_patch_size=self.patch_temporal, + 
out_channels=self.out_channels, + use_adaln_lora=self.use_adaln_lora, + adaln_lora_dim=self.adaln_lora_dim, + device=device, dtype=dtype, operations=operations, + ) + + self.t_embedding_norm = operations.RMSNorm(model_channels, eps=1e-6, device=device, dtype=dtype) + + def build_pos_embed(self, device=None, dtype=None) -> None: + if self.pos_emb_cls == "rope3d": + cls_type = VideoRopePosition3DEmb + else: + raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}") + + logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}") + kwargs = dict( + model_channels=self.model_channels, + len_h=self.max_img_h // self.patch_spatial, + len_w=self.max_img_w // self.patch_spatial, + len_t=self.max_frames // self.patch_temporal, + max_fps=self.max_fps, + min_fps=self.min_fps, + is_learnable=self.pos_emb_learnable, + interpolation=self.pos_emb_interpolation, + head_dim=self.model_channels // self.num_heads, + h_extrapolation_ratio=self.rope_h_extrapolation_ratio, + w_extrapolation_ratio=self.rope_w_extrapolation_ratio, + t_extrapolation_ratio=self.rope_t_extrapolation_ratio, + enable_fps_modulation=self.rope_enable_fps_modulation, + device=device, + ) + self.pos_embedder = cls_type( + **kwargs, # type: ignore + ) + + if self.extra_per_block_abs_pos_emb: + kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio + kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio + kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio + kwargs["device"] = device + kwargs["dtype"] = dtype + self.extra_pos_embedder = LearnablePosEmbAxis( + **kwargs, # type: ignore + ) + + def prepare_embedded_sequence( + self, + x_B_C_T_H_W: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks. 
+ + Args: + x_B_C_T_H_W (torch.Tensor): video + fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required. + If None, a default value (`self.base_fps`) will be used. + padding_mask (Optional[torch.Tensor]): current it is not used + + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: + - A tensor of shape (B, T, H, W, D) with the embedded sequence. + - An optional positional embedding tensor, returned only if the positional embedding class + (`self.pos_emb_cls`) includes 'rope'. Otherwise, None. + + Notes: + - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor. + - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`. + - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using + the `self.pos_embedder` with the shape [T, H, W]. + - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the + `self.pos_embedder` with the fps tensor. + - Otherwise, the positional embeddings are generated without considering fps. 
+ """ + if self.concat_padding_mask: + if padding_mask is None: + padding_mask = torch.zeros(x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[3], x_B_C_T_H_W.shape[4], dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device) + else: + padding_mask = transforms.functional.resize( + padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST + ) + x_B_C_T_H_W = torch.cat( + [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1 + ) + x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W) + + if self.extra_per_block_abs_pos_emb: + extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype) + else: + extra_pos_emb = None + + if "rope" in self.pos_emb_cls.lower(): + return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device) # [B, T, H, W, D] + + return x_B_T_H_W_D, None, extra_pos_emb + + def unpatchify(self, x_B_T_H_W_M: torch.Tensor) -> torch.Tensor: + x_B_C_Tt_Hp_Wp = rearrange( + x_B_T_H_W_M, + "B T H W (p1 p2 t C) -> B C (T t) (H p1) (W p2)", + p1=self.patch_spatial, + p2=self.patch_spatial, + t=self.patch_temporal, + ) + return x_B_C_Tt_Hp_Wp + + def pad_to_patch_size(self, img, patch_size=(2, 2), padding_mode="circular"): + if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()): + padding_mode = "reflect" + + pad = () + for i in range(img.ndim - 2): + pad = (0, (patch_size[i] - img.shape[i + 2] % patch_size[i]) % patch_size[i]) + pad + + return torch.nn.functional.pad(img, pad, mode=padding_mode) + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + context: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + **kwargs, + ): + orig_shape = list(x.shape) + x = 
self.pad_to_patch_size(x, (self.patch_temporal, self.patch_spatial, self.patch_spatial)) + x_B_C_T_H_W = x + timesteps_B_T = timesteps + crossattn_emb = context + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + """ + x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( + x_B_C_T_H_W, + fps=fps, + padding_mask=padding_mask, + ) + + if timesteps_B_T.ndim == 1: + timesteps_B_T = timesteps_B_T.unsqueeze(1) + t_embedding_B_T_D, adaln_lora_B_T_3D = self.t_embedder[1](self.t_embedder[0](timesteps_B_T).to(x_B_T_H_W_D.dtype)) + t_embedding_B_T_D = self.t_embedding_norm(t_embedding_B_T_D) + + # for logging purpose + affline_scale_log_info = {} + affline_scale_log_info["t_embedding_B_T_D"] = t_embedding_B_T_D.detach() + self.affline_scale_log_info = affline_scale_log_info + self.affline_emb = t_embedding_B_T_D + self.crossattn_emb = crossattn_emb + + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + assert ( + x_B_T_H_W_D.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape + ), f"{x_B_T_H_W_D.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape}" + + block_kwargs = { + "rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0), + "adaln_lora_B_T_3D": adaln_lora_B_T_3D, + "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + "transformer_options": kwargs.get("transformer_options", {}), + } + + # The residual stream for this model has large values. To make fp16 compute_dtype work, we keep the residual stream + # in fp32, but run attention and MLP modules in fp16. + # An alternate method that clamps fp16 values "works" in the sense that it makes coherent images, but there is noticeable + # quality degradation and visual artifacts. 
def rotate_half(x):
    """Swap the two halves of the last dim, negating the (new) first half.

    Standard RoPE helper: maps [a, b] -> [-b, a] along the feature axis.
    """
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)


def apply_rotary_pos_emb2(x, cos, sin, unsqueeze_dim=1):
    """Apply rotary position embeddings to ``x``.

    ``unsqueeze_dim`` inserts a broadcast axis into the cos/sin tables
    (default 1 — the heads axis of a (B, heads, L, head_dim) tensor).
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    return x * cos_b + rotate_half(x) * sin_b


class RotaryEmbedding(nn.Module):
    """Precomputes RoPE cos/sin tables (theta = 10000) for a given head dim."""

    def __init__(self, head_dim):
        super().__init__()
        self.rope_theta = 10000
        exponent = torch.arange(0, head_dim, 2, dtype=torch.int64).to(dtype=torch.float) / head_dim
        # Not persisted: cheap to recompute, so it stays out of checkpoints.
        self.register_buffer("inv_freq", 1.0 / (self.rope_theta ** exponent), persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        # ``x`` contributes only its device/dtype; ``position_ids`` is (B, L).
        batch = position_ids.shape[0]
        inv_freq = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        positions = position_ids[:, None, :].float()

        # autocast is not supported on mps; force fp32 for the trig either way.
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            angles = (inv_freq.float() @ positions.float()).transpose(1, 2)
            table = torch.cat((angles, angles), dim=-1)
            cos_table = table.cos()
            sin_table = table.sin()

        return cos_table.to(dtype=x.dtype), sin_table.to(dtype=x.dtype)


class LLMAdapterAttention(nn.Module):
    """Multi-head attention with per-head RMS-normed Q/K and optional RoPE.

    Acts as self-attention when ``context`` is None, cross-attention otherwise.
    Projection/norm layers come from the injected ``operations`` namespace.
    """

    def __init__(self, query_dim, context_dim, n_heads, head_dim, device=None, dtype=None, operations=None):
        super().__init__()
        inner_dim = head_dim * n_heads
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.query_dim = query_dim
        self.context_dim = context_dim

        self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype)
        self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)

        self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
        self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)

        self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
        self.o_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype)

    def forward(self, x, mask=None, context=None, position_embeddings=None, position_embeddings_context=None):
        kv_source = x if context is None else context
        lead = x.shape[:-1]
        kv_lead = kv_source.shape[:-1]
        heads = (self.n_heads, self.head_dim)

        q = self.q_norm(self.q_proj(x).view(*lead, *heads)).transpose(1, 2)
        k = self.k_norm(self.k_proj(kv_source).view(*kv_lead, *heads)).transpose(1, 2)
        v = self.v_proj(kv_source).view(*kv_lead, *heads).transpose(1, 2)

        if position_embeddings is not None:
            assert position_embeddings_context is not None
            q = apply_rotary_pos_emb2(q, *position_embeddings)
            k = apply_rotary_pos_emb2(k, *position_embeddings_context)

        attn = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        attn = attn.transpose(1, 2).reshape(*lead, -1).contiguous()
        return self.o_proj(attn)

    def init_weights(self):
        # Zeroed output projection: the residual branches that add this module's
        # output to the stream start as a no-op.
        torch.nn.init.zeros_(self.o_proj.weight)
class LLMAdapter(nn.Module):
    """Maps hidden states from a source LM onto T5-style target embeddings.

    Target token ids are embedded (vocab 32128 — presumably the T5 vocabulary;
    verify against the tokenizer), refined by transformer blocks that
    cross-attend to the source hidden states, then projected to ``target_dim``.
    """

    def __init__(
        self,
        source_dim=1024,
        target_dim=1024,
        model_dim=1024,
        num_layers=6,
        num_heads=16,
        use_self_attn=True,
        layer_norm=False,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()

        self.embed = operations.Embedding(32128, target_dim, device=device, dtype=dtype)
        # Only project into the working width when it differs from the embedding width.
        self.in_proj = (
            operations.Linear(target_dim, model_dim, device=device, dtype=dtype)
            if model_dim != target_dim
            else nn.Identity()
        )
        self.rotary_emb = RotaryEmbedding(model_dim // num_heads)
        self.blocks = nn.ModuleList(
            LLMAdapterTransformerBlock(
                source_dim,
                model_dim,
                num_heads=num_heads,
                use_self_attn=use_self_attn,
                layer_norm=layer_norm,
                device=device,
                dtype=dtype,
                operations=operations,
            )
            for _ in range(num_layers)
        )
        self.out_proj = operations.Linear(model_dim, target_dim, device=device, dtype=dtype)
        self.norm = operations.RMSNorm(target_dim, eps=1e-6, device=device, dtype=dtype)

    @staticmethod
    def _expand_mask(mask):
        # Broadcast a (B, L) key-padding mask to boolean (B, 1, 1, L) for SDPA.
        if mask is None:
            return None
        mask = mask.to(torch.bool)
        if mask.ndim == 2:
            mask = mask.unsqueeze(1).unsqueeze(1)
        return mask

    def forward(self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None):
        target_mask = self._expand_mask(target_attention_mask)
        source_mask = self._expand_mask(source_attention_mask)

        context = source_hidden_states
        x = self.in_proj(self.embed(target_input_ids).to(context.dtype))

        # Separate rotary tables for the target and the (longer/shorter) source.
        pos_target = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
        pos_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
        rope_target = self.rotary_emb(x, pos_target)
        rope_context = self.rotary_emb(x, pos_context)

        for block in self.blocks:
            x = block(
                x,
                context,
                target_attention_mask=target_mask,
                source_attention_mask=source_mask,
                position_embeddings=rope_target,
                position_embeddings_context=rope_context,
            )
        return self.norm(self.out_proj(x))
kwargs.pop("t5xxl_ids", None) + if t5xxl_ids is not None: + context = self.preprocess_text_embeds(context, t5xxl_ids, t5xxl_weights=kwargs.pop("t5xxl_weights", None)) + return super().forward( + x, timesteps, context, + use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + **kwargs + ) diff --git a/diffsynth/pipelines/anima_image.py b/diffsynth/pipelines/anima_image.py new file mode 100644 index 0000000..732ede5 --- /dev/null +++ b/diffsynth/pipelines/anima_image.py @@ -0,0 +1,261 @@ +import torch, math +from PIL import Image +from typing import Union +from tqdm import tqdm +from einops import rearrange +import numpy as np +from math import prod +from transformers import AutoTokenizer + +from ..core.device.npu_compatible_device import get_device_type +from ..diffusion import FlowMatchScheduler +from ..core import ModelConfig, gradient_checkpoint_forward +from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput +from ..utils.lora.merge import merge_lora + +from ..models.anima_dit import AnimaDiT +from ..models.z_image_text_encoder import ZImageTextEncoder +from ..models.wan_video_vae import WanVideoVAE + + +class AnimaImagePipeline(BasePipeline): + + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, + ) + self.scheduler = FlowMatchScheduler("Z-Image") + self.text_encoder: ZImageTextEncoder = None + self.dit: AnimaDiT = None + self.vae: WanVideoVAE = None + self.tokenizer: AutoTokenizer = None + self.tokenizer_t5xxl: AutoTokenizer = None + self.in_iteration_models = ("dit",) + self.units = [ + AnimaUnit_ShapeChecker(), + AnimaUnit_NoiseInitializer(), + AnimaUnit_InputImageEmbedder(), + AnimaUnit_PromptEmbedder(), + ] + self.model_fn = model_fn_anima + + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = 
torch.bfloat16, + device: Union[str, torch.device] = get_device_type(), + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"), + tokenizer_t5xxl_config: ModelConfig = ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/"), + vram_limit: float = None, + ): + # Initialize pipeline + pipe = AnimaImagePipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder") + pipe.dit = model_pool.fetch_model("anima_dit") + pipe.vae = model_pool.fetch_model("wan_video_vae") + if tokenizer_config is not None: + tokenizer_config.download_if_necessary() + pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path) + if tokenizer_t5xxl_config is not None: + tokenizer_t5xxl_config.download_if_necessary() + pipe.tokenizer_t5xxl = AutoTokenizer.from_pretrained(tokenizer_t5xxl_config.path) + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: str = "", + cfg_scale: float = 4.0, + # Image + input_image: Image.Image = None, + denoising_strength: float = 1.0, + # Shape + height: int = 1024, + width: int = 1024, + # Randomness + seed: int = None, + rand_device: str = "cpu", + # Steps + num_inference_steps: int = 30, + sigma_shift: float = None, + # Progress bar + progress_bar_cmd = tqdm, + ): + # Scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) + + # Parameters + inputs_posi = { + "prompt": prompt, + } + inputs_nega = { + "negative_prompt": negative_prompt, + } + inputs_shared = { + "cfg_scale": cfg_scale, + "input_image": input_image, "denoising_strength": denoising_strength, + "height": height, 
"width": width, + "seed": seed, "rand_device": rand_device, + "num_inference_steps": num_inference_steps, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared) + + # Decode + self.load_models_to_device(['vae']) + image = self.vae.decode(inputs_shared["latents"].unsqueeze(2), device=self.device).squeeze(2) + image = self.vae_output_to_image(image) + self.load_models_to_device([]) + + return image + + +class AnimaUnit_ShapeChecker(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width"), + output_params=("height", "width"), + ) + + def process(self, pipe: AnimaImagePipeline, height, width): + height, width = pipe.check_resize_height_width(height, width) + return {"height": height, "width": width} + + + +class AnimaUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "seed", "rand_device"), + output_params=("noise",), + ) + + def process(self, pipe: AnimaImagePipeline, height, width, seed, rand_device): + noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype) + return {"noise": noise} + + + +class AnimaUnit_InputImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", 
"noise"), + output_params=("latents", "input_latents"), + onload_model_names=("vae",) + ) + + def process(self, pipe: AnimaImagePipeline, input_image, noise): + if input_image is None: + return {"latents": noise, "input_latents": None} + pipe.load_models_to_device(['vae']) + if isinstance(input_image, list): + input_latents = [] + for image in input_image: + image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype) + input_latents.append(pipe.vae.encode(image)) + input_latents = torch.concat(input_latents, dim=0) + else: + image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype) + input_latents = pipe.vae.encode(image.unsqueeze(2), device=pipe.device).squeeze(2) + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents, "input_latents": input_latents} + + +class AnimaUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt"}, + input_params_nega={"prompt": "negative_prompt"}, + output_params=("prompt_emb",), + onload_model_names=("text_encoder",) + ) + + def encode_prompt( + self, + pipe: AnimaImagePipeline, + prompt, + device = None, + max_sequence_length: int = 512, + ): + if isinstance(prompt, str): + prompt = [prompt] + + text_inputs = pipe.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = pipe.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-1] + + t5xxl_text_inputs = pipe.tokenizer_t5xxl( + prompt, + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + 
def model_fn_anima(
    dit: "AnimaDiT" = None,
    latents=None,
    timestep=None,
    prompt_emb=None,
    t5xxl_ids=None,
    use_gradient_checkpointing=False,
    use_gradient_checkpointing_offload=False,
    **kwargs
):
    """Single denoising step: wraps image latents as a 1-frame video and calls the DiT.

    Args:
        dit: the AnimaDiT model (any callable with the same signature works).
        latents: (B, C, H, W) image latents.
        timestep: (B,) timestep on the [0, 1000) scale; rescaled to [0, 1).
        prompt_emb: text-encoder hidden states used as cross-attention context.
        t5xxl_ids: optional T5 token ids consumed by the DiT's LLM adapter.
        use_gradient_checkpointing: enable gradient checkpointing in the DiT blocks.
        use_gradient_checkpointing_offload: offload checkpointed activations to CPU.

    Returns:
        (B, C, H, W) model prediction.
    """
    latents = latents.unsqueeze(2)  # add a singleton temporal axis: (B, C, 1, H, W)
    timestep = timestep / 1000
    model_output = dit(
        x=latents,
        timesteps=timestep,
        context=prompt_emb,
        t5xxl_ids=t5xxl_ids,
        # Fix: forward the checkpointing flags. They were previously accepted
        # but dropped here, so gradient checkpointing silently never reached
        # the model even though AnimaDiT.forward supports it.
        use_gradient_checkpointing=use_gradient_checkpointing,
        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
    )
    return model_output.squeeze(2)


def AnimaDiTStateDictConverter(state_dict):
    """Strip the ``net.`` wrapper from checkpoint keys.

    NOTE(review): ``str.replace`` removes every occurrence of ``"net."`` in a
    key, not just a leading prefix — confirm no genuine submodule name
    contains ``"net."`` before relying on this for new checkpoints.
    """
    return {key.replace("net.", ""): value for key, value in state_dict.items()}
+ +## Quick Start + +The following code demonstrates how to quickly load the [circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima) model for inference. VRAM management is enabled by default, allowing the framework to automatically control model parameter loading based on available VRAM. Minimum 8GB VRAM required. + +```python +from diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": "disk", + "onload_device": "disk", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = AnimaImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors", **vram_config), + ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/text_encoders/qwen_3_06b_base.safetensors", **vram_config), + ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"), + tokenizer_t5xxl_config=ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +prompt = "Masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait." 
+negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," +image = pipe(prompt, seed=0, num_inference_steps=50) +image.save("image.jpg") +``` + +## Model Overview + +|Model ID|Inference|Low VRAM Inference|Full Training|Post-Full Training Validation|LoRA Training|Post-LoRA Training Validation| +|-|-|-|-|-|-|-| +|[circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference_low_vram/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/full/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_full/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/lora/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_lora/anima-preview.py)| + +Special training scripts: + +* Differential LoRA Training: [doc](../Training/Differential_LoRA.md) +* FP8 Precision Training: [doc](../Training/FP8_Precision.md) +* Two-Stage Split Training: [doc](../Training/Split_Training.md) +* End-to-End Direct Distillation: [doc](../Training/Direct_Distill.md) + +## Model Inference + +Models are loaded through `AnimaImagePipeline.from_pretrained`, see [Model Inference](../Pipeline_Usage/Model_Inference.md#loading-models) for details. + +Input parameters for `AnimaImagePipeline` inference include: + +* `prompt`: Text description of the desired image content. +* `negative_prompt`: Content to exclude from the generated image (default: `""`). +* `cfg_scale`: Classifier-free guidance parameter (default: 4.0). 
+* `input_image`: Input image for image-to-image generation (default: `None`). +* `denoising_strength`: Controls similarity to input image (default: 1.0). +* `height`: Image height (must be multiple of 16, default: 1024). +* `width`: Image width (must be multiple of 16, default: 1024). +* `seed`: Random seed (default: `None`). +* `rand_device`: Device for random noise generation (default: `"cpu"`). +* `num_inference_steps`: Inference steps (default: 30). +* `sigma_shift`: Scheduler sigma offset (default: `None`). +* `progress_bar_cmd`: Progress bar implementation (default: `tqdm.tqdm`). + +For VRAM constraints, enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). Recommended low-VRAM configurations are provided in the "Model Overview" table above. + +## Model Training + +Anima models are trained through [`examples/anima/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/train.py) with parameters including: + +* General Training Parameters + * Dataset Configuration + * `--dataset_base_path`: Dataset root directory. + * `--dataset_metadata_path`: Metadata file path. + * `--dataset_repeat`: Dataset repetition per epoch. + * `--dataset_num_workers`: Dataloader worker count. + * `--data_file_keys`: Metadata fields to load (comma-separated). + * Model Loading + * `--model_paths`: Model paths (JSON format). + * `--model_id_with_origin_paths`: Model IDs with origin paths (e.g., `"anima-team/anima-1B:text_encoder/*.safetensors"`). + * `--extra_inputs`: Additional pipeline inputs (e.g., `controlnet_inputs` for ControlNet). + * `--fp8_models`: FP8-formatted models (same format as `--model_paths`). + * Training Configuration + * `--learning_rate`: Learning rate. + * `--num_epochs`: Training epochs. + * `--trainable_models`: Trainable components (e.g., `dit`, `vae`, `text_encoder`). + * `--find_unused_parameters`: Handle unused parameters in DDP training. + * `--weight_decay`: Weight decay value. 
+ * `--task`: Training task (default: `sft`). + * Output Configuration + * `--output_path`: Model output directory. + * `--remove_prefix_in_ckpt`: Remove state dict prefixes. + * `--save_steps`: Model saving interval. + * LoRA Configuration + * `--lora_base_model`: Target model for LoRA. + * `--lora_target_modules`: Target modules for LoRA. + * `--lora_rank`: LoRA rank. + * `--lora_checkpoint`: LoRA checkpoint path. + * `--preset_lora_path`: Preloaded LoRA checkpoint path. + * `--preset_lora_model`: Model to merge LoRA with (e.g., `dit`). + * Gradient Configuration + * `--use_gradient_checkpointing`: Enable gradient checkpointing. + * `--use_gradient_checkpointing_offload`: Offload checkpointing to CPU. + * `--gradient_accumulation_steps`: Gradient accumulation steps. + * Image Resolution + * `--height`: Image height (empty for dynamic resolution). + * `--width`: Image width (empty for dynamic resolution). + * `--max_pixels`: Maximum pixel area for dynamic resolution. +* Anima-Specific Parameters + * `--tokenizer_path`: Tokenizer path for text-to-image models. + * `--tokenizer_t5xxl_path`: T5-XXL tokenizer path. + +We provide a sample image dataset for testing: + +```shell +modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset +``` + +For training script details, refer to [Model Training](../Pipeline_Usage/Model_Training.md). For advanced training techniques, see [Training Framework Documentation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/). \ No newline at end of file diff --git a/docs/en/index.rst b/docs/en/index.rst index ca38620..c4e2736 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -27,6 +27,7 @@ Welcome to DiffSynth-Studio's Documentation Model_Details/Qwen-Image Model_Details/FLUX2 Model_Details/Z-Image + Model_Details/Anima Model_Details/LTX-2 .. 
toctree:: diff --git a/docs/zh/Model_Details/Anima.md b/docs/zh/Model_Details/Anima.md new file mode 100644 index 0000000..0d5576b --- /dev/null +++ b/docs/zh/Model_Details/Anima.md @@ -0,0 +1,139 @@ +# Anima + +Anima 是由 CircleStone Labs 与 Comfy Org 训练并开源的图像生成模型。 + +## 安装 + +在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。 + +```shell +git clone https://github.com/modelscope/DiffSynth-Studio.git +cd DiffSynth-Studio +pip install -e . +``` + +更多关于安装的信息,请参考[安装依赖](../Pipeline_Usage/Setup.md)。 + +## 快速开始 + +运行以下代码可以快速加载 [circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8G 显存即可运行。 + +```python +from diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": "disk", + "onload_device": "disk", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = AnimaImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors", **vram_config), + ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/text_encoders/qwen_3_06b_base.safetensors", **vram_config), + ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"), + tokenizer_t5xxl_config=ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +prompt = "Masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, 
refraction, portrait." +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," +image = pipe(prompt, seed=0, num_inference_steps=50) +image.save("image.jpg") +``` + +## 模型总览 + +|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| +|-|-|-|-|-|-|-| +|[circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference_low_vram/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/full/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_full/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/lora/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_lora/anima-preview.py)| + +特殊训练脚本: + +* 差分 LoRA 训练:[doc](../Training/Differential_LoRA.md) +* FP8 精度训练:[doc](../Training/FP8_Precision.md) +* 两阶段拆分训练:[doc](../Training/Split_Training.md) +* 端到端直接蒸馏:[doc](../Training/Direct_Distill.md) + +## 模型推理 + +模型通过 `AnimaImagePipeline.from_pretrained` 加载,详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。 + +`AnimaImagePipeline` 推理的输入参数包括: + +* `prompt`: 提示词,描述画面中出现的内容。 +* `negative_prompt`: 负向提示词,描述画面中不应该出现的内容,默认值为 `""`。 +* `cfg_scale`: Classifier-free guidance 的参数,默认值为 4.0。 +* `input_image`: 输入图像,用于图像到图像的生成。默认为 `None`。 +* `denoising_strength`: 去噪强度,控制生成图像与输入图像的相似度,默认值为 1.0。 +* `height`: 图像高度,需保证高度为 16 的倍数,默认值为 1024。 +* `width`: 图像宽度,需保证宽度为 16 的倍数,默认值为 1024。 +* `seed`: 随机种子。默认为 `None`,即完全随机。 +* `rand_device`: 生成随机高斯噪声矩阵的计算设备,默认为 `"cpu"`。当设置为 `cuda` 时,在不同 GPU 上会导致不同的生成结果。 +* `num_inference_steps`: 推理次数,默认值为 30。 +* `sigma_shift`: 调度器的 sigma 偏移量,默认为 `None`。 +* 
`progress_bar_cmd`: 进度条,默认为 `tqdm.tqdm`。可通过设置为 `lambda x:x` 来屏蔽进度条。 + +如果显存不足,请开启[显存管理](../Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文"模型总览"中的表格。 + +## 模型训练 + +Anima 系列模型统一通过 [`examples/anima/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/train.py) 进行训练,脚本的参数包括: + +* 通用训练参数 + * 数据集基础配置 + * `--dataset_base_path`: 数据集的根目录。 + * `--dataset_metadata_path`: 数据集的元数据文件路径。 + * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。 + * `--dataset_num_workers`: 每个 Dataloder 的进程数量。 + * `--data_file_keys`: 元数据中需要加载的字段名称,通常是图像或视频文件的路径,以 `,` 分隔。 + * 模型加载配置 + * `--model_paths`: 要加载的模型路径。JSON 格式。 + * `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 `"anima-team/anima-1B:text_encoder/*.safetensors"`。用逗号分隔。 + * `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,例如训练 ControlNet 模型时需要额外参数 `controlnet_inputs`,以 `,` 分隔。 + * `--fp8_models`:以 FP8 格式加载的模型,格式与 `--model_paths` 或 `--model_id_with_origin_paths` 一致,目前仅支持参数不被梯度更新的模型(不需要梯度回传,或梯度仅更新其 LoRA)。 + * 训练基础配置 + * `--learning_rate`: 学习率。 + * `--num_epochs`: 轮数(Epoch)。 + * `--trainable_models`: 可训练的模型,例如 `dit`、`vae`、`text_encoder`。 + * `--find_unused_parameters`: DDP 训练中是否存在未使用的参数,少数模型包含不参与梯度计算的冗余参数,需开启这一设置避免在多 GPU 训练中报错。 + * `--weight_decay`:权重衰减大小,详见 [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html)。 + * `--task`: 训练任务,默认为 `sft`,部分模型支持更多训练模式,请参考每个特定模型的文档。 + * 输出配置 + * `--output_path`: 模型保存路径。 + * `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。 + * `--save_steps`: 保存模型的训练步数间隔,若此参数留空,则每个 epoch 保存一次。 + * LoRA 配置 + * `--lora_base_model`: LoRA 添加到哪个模型上。 + * `--lora_target_modules`: LoRA 添加到哪些层上。 + * `--lora_rank`: LoRA 的秩(Rank)。 + * `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径,LoRA 将从此检查点加载。 + * `--preset_lora_path`: 预置 LoRA 检查点路径,如果提供此路径,这一 LoRA 将会以融入基础模型的形式加载。此参数用于 LoRA 差分训练。 + * `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`。 + * 梯度配置 + * `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。 + * `--use_gradient_checkpointing_offload`: 是否将 
gradient checkpointing 卸载到内存中。 + * `--gradient_accumulation_steps`: 梯度累积步数。 + * 图像宽高配置(适用于图像生成模型和视频生成模型) + * `--height`: 图像或视频的高度。将 `height` 和 `width` 留空以启用动态分辨率。 + * `--width`: 图像或视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。 + * `--max_pixels`: 图像或视频帧的最大像素面积,当启用动态分辨率时,分辨率大于这个数值的图片都会被缩小,分辨率小于这个数值的图片保持不变。 +* Anima 专有参数 + * `--tokenizer_path`: tokenizer 的路径,适用于文生图模型,留空则自动从远程下载。 + * `--tokenizer_t5xxl_path`: T5-XXL tokenizer 的路径,适用于文生图模型,留空则自动从远程下载。 + +我们构建了一个样例图像数据集,以方便您进行测试,通过以下命令可以下载这个数据集: + +```shell +modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset +``` + +我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](../Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。 diff --git a/docs/zh/index.rst b/docs/zh/index.rst index d2afefc..4ee551a 100644 --- a/docs/zh/index.rst +++ b/docs/zh/index.rst @@ -27,6 +27,7 @@ Model_Details/Qwen-Image Model_Details/FLUX2 Model_Details/Z-Image + Model_Details/Anima Model_Details/LTX-2 .. 
toctree:: diff --git a/examples/anima/model_inference/anima-preview.py b/examples/anima/model_inference/anima-preview.py new file mode 100644 index 0000000..9440bdf --- /dev/null +++ b/examples/anima/model_inference/anima-preview.py @@ -0,0 +1,19 @@ +from diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig +import torch + + +pipe = AnimaImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors"), + ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/text_encoders/qwen_3_06b_base.safetensors"), + ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"), + tokenizer_t5xxl_config=ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/") +) +prompt = "Masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait." 
+negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," +image = pipe(prompt, seed=0, num_inference_steps=50) +image.save("image.jpg") diff --git a/examples/anima/model_inference_low_vram/anima-preview.py b/examples/anima/model_inference_low_vram/anima-preview.py new file mode 100644 index 0000000..bfe8e24 --- /dev/null +++ b/examples/anima/model_inference_low_vram/anima-preview.py @@ -0,0 +1,30 @@ +from diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig +import torch + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": "disk", + "onload_device": "disk", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = AnimaImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors", **vram_config), + ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/text_encoders/qwen_3_06b_base.safetensors", **vram_config), + ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"), + tokenizer_t5xxl_config=ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +prompt = "Masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait." 
+negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," +image = pipe(prompt, seed=0, num_inference_steps=50) +image.save("image.jpg") diff --git a/examples/anima/model_training/full/anima-preview.sh b/examples/anima/model_training/full/anima-preview.sh new file mode 100644 index 0000000..58bf844 --- /dev/null +++ b/examples/anima/model_training/full/anima-preview.sh @@ -0,0 +1,14 @@ +accelerate launch examples/anima/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "circlestone-labs/Anima:split_files/diffusion_models/anima-preview.safetensors,circlestone-labs/Anima:split_files/text_encoders/qwen_3_06b_base.safetensors,circlestone-labs/Anima:split_files/vae/qwen_image_vae.safetensors" \ + --tokenizer_path "Qwen/Qwen3-0.6B:./" \ + --tokenizer_t5xxl_path "stabilityai/stable-diffusion-3.5-large:tokenizer_3/" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/anima-preview_full" \ + --trainable_models "dit" \ + --use_gradient_checkpointing \ No newline at end of file diff --git a/examples/anima/model_training/lora/anima-preview.sh b/examples/anima/model_training/lora/anima-preview.sh new file mode 100644 index 0000000..462a844 --- /dev/null +++ b/examples/anima/model_training/lora/anima-preview.sh @@ -0,0 +1,16 @@ +accelerate launch examples/anima/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "circlestone-labs/Anima:split_files/diffusion_models/anima-preview.safetensors,circlestone-labs/Anima:split_files/text_encoders/qwen_3_06b_base.safetensors,circlestone-labs/Anima:split_files/vae/qwen_image_vae.safetensors" \ + --tokenizer_path "Qwen/Qwen3-0.6B:./" \ + --tokenizer_t5xxl_path "stabilityai/stable-diffusion-3.5-large:tokenizer_3/" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/anima-preview_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ No newline at end of file diff --git a/examples/anima/model_training/train.py b/examples/anima/model_training/train.py new file mode 100644 index 0000000..89e7b72 --- /dev/null +++ b/examples/anima/model_training/train.py @@ -0,0 +1,145 @@ +import torch, os, argparse, accelerate +from diffsynth.core import UnifiedDataset +from diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig +from diffsynth.diffusion import * +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +class AnimaTrainingModule(DiffusionTrainingModule): + def __init__( + self, + model_paths=None, model_id_with_origin_paths=None, + tokenizer_path=None, tokenizer_t5xxl_path=None, + trainable_models=None, + lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None, + preset_lora_path=None, preset_lora_model=None, + use_gradient_checkpointing=True, + use_gradient_checkpointing_offload=False, + extra_inputs=None, + fp8_models=None, + offload_models=None, + device="cpu", + task="sft", + ): + super().__init__() + # Load models + model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device) + tokenizer_config = self.parse_path_or_model_id(tokenizer_path, ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./")) + tokenizer_t5xxl_config = self.parse_path_or_model_id(tokenizer_t5xxl_path, ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/")) + self.pipe = AnimaImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config, tokenizer_t5xxl_config=tokenizer_t5xxl_config) + self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model) + + # Training mode + 
self.switch_pipe_to_training_mode( + self.pipe, trainable_models, + lora_base_model, lora_target_modules, lora_rank, lora_checkpoint, + preset_lora_path, preset_lora_model, + task=task, + ) + + # Other configs + self.use_gradient_checkpointing = use_gradient_checkpointing + self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload + self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else [] + self.fp8_models = fp8_models + self.task = task + self.task_to_loss = { + "sft:data_process": lambda pipe, *args: args, + "direct_distill:data_process": lambda pipe, *args: args, + "sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi), + "sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi), + "direct_distill": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi), + "direct_distill:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi), + } + + def get_pipeline_inputs(self, data): + inputs_posi = {"prompt": data["prompt"]} + inputs_nega = {"negative_prompt": ""} + inputs_shared = { + # Assume you are using this pipeline for inference, + # please fill in the input parameters. + "input_image": data["image"], + "height": data["image"].size[1], + "width": data["image"].size[0], + # Please do not modify the following parameters + # unless you clearly know what this will cause. 
+ "cfg_scale": 1, + "rand_device": self.pipe.device, + "use_gradient_checkpointing": self.use_gradient_checkpointing, + "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload, + } + inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared) + return inputs_shared, inputs_posi, inputs_nega + + def forward(self, data, inputs=None): + if inputs is None: inputs = self.get_pipeline_inputs(data) + inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype) + for unit in self.pipe.units: + inputs = self.pipe.unit_runner(unit, self.pipe, *inputs) + loss = self.task_to_loss[self.task](self.pipe, *inputs) + return loss + + +def anima_parser(): + parser = argparse.ArgumentParser(description="Training script for Anima models.") + parser = add_general_config(parser) + parser = add_image_size_config(parser) + parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.") + parser.add_argument("--tokenizer_t5xxl_path", type=str, default=None, help="Path to tokenizer_t5xxl.") + return parser + + +if __name__ == "__main__": + parser = anima_parser() + args = parser.parse_args() + accelerator = accelerate.Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)], + ) + dataset = UnifiedDataset( + base_path=args.dataset_base_path, + metadata_path=args.dataset_metadata_path, + repeat=args.dataset_repeat, + data_file_keys=args.data_file_keys.split(","), + main_data_operator=UnifiedDataset.default_image_operator( + base_path=args.dataset_base_path, + max_pixels=args.max_pixels, + height=args.height, + width=args.width, + height_division_factor=16, + width_division_factor=16, + ) + ) + model = AnimaTrainingModule( + model_paths=args.model_paths, + model_id_with_origin_paths=args.model_id_with_origin_paths, + tokenizer_path=args.tokenizer_path, + 
tokenizer_t5xxl_path=args.tokenizer_t5xxl_path, + trainable_models=args.trainable_models, + lora_base_model=args.lora_base_model, + lora_target_modules=args.lora_target_modules, + lora_rank=args.lora_rank, + lora_checkpoint=args.lora_checkpoint, + preset_lora_path=args.preset_lora_path, + preset_lora_model=args.preset_lora_model, + use_gradient_checkpointing=args.use_gradient_checkpointing, + use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, + extra_inputs=args.extra_inputs, + fp8_models=args.fp8_models, + offload_models=args.offload_models, + task=args.task, + device=accelerator.device, + ) + model_logger = ModelLogger( + args.output_path, + remove_prefix_in_ckpt=args.remove_prefix_in_ckpt, + ) + launcher_map = { + "sft:data_process": launch_data_process_task, + "direct_distill:data_process": launch_data_process_task, + "sft": launch_training_task, + "sft:train": launch_training_task, + "direct_distill": launch_training_task, + "direct_distill:train": launch_training_task, + } + launcher_map[args.task](accelerator, dataset, model, model_logger, args=args) \ No newline at end of file diff --git a/examples/anima/model_training/validate_full/anima-preview.py b/examples/anima/model_training/validate_full/anima-preview.py new file mode 100644 index 0000000..9f31a5a --- /dev/null +++ b/examples/anima/model_training/validate_full/anima-preview.py @@ -0,0 +1,21 @@ +from diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig +from diffsynth.core import load_state_dict +import torch + + +pipe = AnimaImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors"), + ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/text_encoders/qwen_3_06b_base.safetensors"), + ModelConfig(model_id="circlestone-labs/Anima", 
origin_file_pattern="split_files/vae/qwen_image_vae.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"), + tokenizer_t5xxl_config=ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/") +) +state_dict = load_state_dict("./models/train/anima-preview_full/epoch-1.safetensors", torch_dtype=torch.bfloat16) +pipe.dit.load_state_dict(state_dict) +prompt = "a dog" +image = pipe(prompt=prompt, seed=0) +image.save("image.jpg") \ No newline at end of file diff --git a/examples/anima/model_training/validate_lora/anima-preview.py b/examples/anima/model_training/validate_lora/anima-preview.py new file mode 100644 index 0000000..df107d2 --- /dev/null +++ b/examples/anima/model_training/validate_lora/anima-preview.py @@ -0,0 +1,19 @@ +from diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig +import torch + + +pipe = AnimaImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors"), + ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/text_encoders/qwen_3_06b_base.safetensors"), + ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"), + tokenizer_t5xxl_config=ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/") +) +pipe.load_lora(pipe.dit, "./models/train/anima-preview_lora/epoch-4.safetensors") +prompt = "a dog" +image = pipe(prompt=prompt, seed=0) +image.save("image.jpg") \ No newline at end of file From f43b18ec218ecab917e539d69bcd1f06d4f60eb9 Mon Sep 17 00:00:00 2001 From: Zhongjie Duan <35051019+Artiprocher@users.noreply.github.com> Date: Mon, 2 Mar 2026 18:59:13 +0800 Subject: 
[PATCH 19/23] Update docs (#1318) * update docs --- docs/en/Model_Details/Anima.md | 2 +- docs/en/README.md | 2 ++ docs/zh/README.md | 2 ++ 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/en/Model_Details/Anima.md b/docs/en/Model_Details/Anima.md index 0f3ae5a..91ecd89 100644 --- a/docs/en/Model_Details/Anima.md +++ b/docs/en/Model_Details/Anima.md @@ -52,7 +52,7 @@ image.save("image.jpg") ## Model Overview -|Model ID|Inference|Low VRAM Inference|Full Training|Post-Full Training Validation|LoRA Training|Post-LoRA Training Validation| +|Model ID|Inference|Low VRAM Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training| |-|-|-|-|-|-|-| |[circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference_low_vram/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/full/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_full/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/lora/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_lora/anima-preview.py)| diff --git a/docs/en/README.md b/docs/en/README.md index aac6000..66eabb0 100644 --- a/docs/en/README.md +++ b/docs/en/README.md @@ -42,6 +42,8 @@ This section introduces the Diffusion models supported by `DiffSynth-Studio`. 
So * [Qwen-Image](./Model_Details/Qwen-Image.md) * [FLUX.2](./Model_Details/FLUX2.md) * [Z-Image](./Model_Details/Z-Image.md) +* [Anima](./Model_Details/Anima.md) +* [LTX-2](./Model_Details/LTX-2.md) ## Section 3: Training Framework diff --git a/docs/zh/README.md b/docs/zh/README.md index 825415e..5f28cde 100644 --- a/docs/zh/README.md +++ b/docs/zh/README.md @@ -42,6 +42,8 @@ graph LR; * [Qwen-Image](./Model_Details/Qwen-Image.md) * [FLUX.2](./Model_Details/FLUX2.md) * [Z-Image](./Model_Details/Z-Image.md) +* [Anima](./Model_Details/Anima.md) +* [LTX-2](./Model_Details/LTX-2.md) ## Section 3: 训练框架 From b3ef224042007698278d77a481071f6cd56aedbf Mon Sep 17 00:00:00 2001 From: Zhongjie Duan <35051019+Artiprocher@users.noreply.github.com> Date: Mon, 2 Mar 2026 19:06:55 +0800 Subject: [PATCH 20/23] support Anima gradient checkpointing (#1319) --- diffsynth/pipelines/anima_image.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/diffsynth/pipelines/anima_image.py b/diffsynth/pipelines/anima_image.py index 732ede5..32a3c71 100644 --- a/diffsynth/pipelines/anima_image.py +++ b/diffsynth/pipelines/anima_image.py @@ -256,6 +256,8 @@ def model_fn_anima( timesteps=timestep, context=prompt_emb, t5xxl_ids=t5xxl_ids, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, ) model_output = model_output.squeeze(2) return model_output From 237d17873336f45267ec6bfe267e56d5b2a4bf7d Mon Sep 17 00:00:00 2001 From: Zhongjie Duan <35051019+Artiprocher@users.noreply.github.com> Date: Tue, 3 Mar 2026 11:08:31 +0800 Subject: [PATCH 21/23] Fix LoRA compatibility issues. 
(#1320) --- diffsynth/utils/lora/general.py | 12 ++++++++++-- .../Qwen-Image-Edit-2511-Lightning.py | 2 +- .../Qwen-Image-Edit-2511-Lightning.py | 2 +- pyproject.toml | 2 +- 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/diffsynth/utils/lora/general.py b/diffsynth/utils/lora/general.py index 624549d..85ada77 100644 --- a/diffsynth/utils/lora/general.py +++ b/diffsynth/utils/lora/general.py @@ -1,4 +1,4 @@ -import torch +import torch, warnings class GeneralLoRALoader: @@ -26,7 +26,11 @@ class GeneralLoRALoader: keys.pop(0) keys.pop(-1) target_name = ".".join(keys) - lora_name_dict[target_name] = (key, key.replace(lora_B_key, lora_A_key)) + # Alpha: Deprecated but retained for compatibility. + key_alpha = key.replace(lora_B_key + ".weight", "alpha").replace(lora_B_key + ".default.weight", "alpha") + if key_alpha == key or key_alpha not in lora_state_dict: + key_alpha = None + lora_name_dict[target_name] = (key, key.replace(lora_B_key, lora_A_key), key_alpha) return lora_name_dict @@ -36,6 +40,10 @@ class GeneralLoRALoader: for name in name_dict: weight_up = state_dict[name_dict[name][0]] weight_down = state_dict[name_dict[name][1]] + if name_dict[name][2] is not None: + warnings.warn("Alpha detected in the LoRA file. This may be a LoRA model not trained by DiffSynth-Studio. 
To ensure compatibility, the LoRA weights will be converted to weight * alpha / rank.") + alpha = state_dict[name_dict[name][2]] / weight_down.shape[0] + weight_down = weight_down * alpha state_dict_[name + f".lora_B{suffix}"] = weight_up state_dict_[name + f".lora_A{suffix}"] = weight_down return state_dict_ diff --git a/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py b/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py index 098a77c..c30ccba 100644 --- a/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py +++ b/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py @@ -18,7 +18,7 @@ lora = ModelConfig( model_id="lightx2v/Qwen-Image-Edit-2511-Lightning", origin_file_pattern="Qwen-Image-Edit-2511-Lightning-4steps-V1.0-bf16.safetensors" ) -pipe.load_lora(pipe.dit, lora, alpha=8/64) +pipe.load_lora(pipe.dit, lora, alpha=1) pipe.scheduler = FlowMatchScheduler("Qwen-Image-Lightning") diff --git a/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py index cbe43a2..cd5b4f3 100644 --- a/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py +++ b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py @@ -28,7 +28,7 @@ lora = ModelConfig( model_id="lightx2v/Qwen-Image-Edit-2511-Lightning", origin_file_pattern="Qwen-Image-Edit-2511-Lightning-4steps-V1.0-bf16.safetensors" ) -pipe.load_lora(pipe.dit, lora, alpha=8/64) +pipe.load_lora(pipe.dit, lora, alpha=1) pipe.scheduler = FlowMatchScheduler("Qwen-Image-Lightning") diff --git a/pyproject.toml b/pyproject.toml index 9a5075b..c6ba767 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "diffsynth" -version = "2.0.4" +version = "2.0.5" description = "Enjoy the magic of Diffusion models!" 
authors = [{name = "ModelScope Team"}] license = {text = "Apache-2.0"} From 62ba8a3f2e129ae4d31c36e766a10d5fd13ca092 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Tue, 3 Mar 2026 12:44:36 +0800 Subject: [PATCH 22/23] fix qwen_text_encoder bug in transformers>=5.2.0 --- diffsynth/configs/vram_management_module_maps.py | 14 ++++++++++++++ diffsynth/models/model_loader.py | 5 +++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/diffsynth/configs/vram_management_module_maps.py b/diffsynth/configs/vram_management_module_maps.py index d86f5fa..902c38b 100644 --- a/diffsynth/configs/vram_management_module_maps.py +++ b/diffsynth/configs/vram_management_module_maps.py @@ -250,3 +250,17 @@ VRAM_MANAGEMENT_MODULE_MAPS = { "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", }, } + +def QwenImageTextEncoder_Module_Map_Updater(): + current = VRAM_MANAGEMENT_MODULE_MAPS["diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder"] + from packaging import version + import transformers + if version.parse(transformers.__version__) >= version.parse("5.2.0"): + # The Qwen2RMSNorm in transformers 5.2.0+ has been renamed to Qwen2_5_VLRMSNorm, so we need to update the module map accordingly + current.pop("transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2RMSNorm", None) + current["transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLRMSNorm"] = "diffsynth.core.vram.layers.AutoWrappedModule" + return current + +VERSION_CHECKER_MAPS = { + "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": QwenImageTextEncoder_Module_Map_Updater, +} \ No newline at end of file diff --git a/diffsynth/models/model_loader.py b/diffsynth/models/model_loader.py index 6a58c89..7a716e2 100644 --- a/diffsynth/models/model_loader.py +++ b/diffsynth/models/model_loader.py @@ -1,6 +1,6 @@ from ..core.loader import load_model, hash_model_file from ..core.vram import AutoWrappedModule -from ..configs import MODEL_CONFIGS, 
VRAM_MANAGEMENT_MODULE_MAPS +from ..configs import MODEL_CONFIGS, VRAM_MANAGEMENT_MODULE_MAPS, VERSION_CHECKER_MAPS import importlib, json, torch @@ -22,7 +22,8 @@ class ModelPool: def fetch_module_map(self, model_class, vram_config): if self.need_to_enable_vram_management(vram_config): if model_class in VRAM_MANAGEMENT_MODULE_MAPS: - module_map = {self.import_model_class(source): self.import_model_class(target) for source, target in VRAM_MANAGEMENT_MODULE_MAPS[model_class].items()} + vram_module_map = VRAM_MANAGEMENT_MODULE_MAPS[model_class] if model_class not in VERSION_CHECKER_MAPS else VERSION_CHECKER_MAPS[model_class]() + module_map = {self.import_model_class(source): self.import_model_class(target) for source, target in vram_module_map.items()} else: module_map = {self.import_model_class(model_class): AutoWrappedModule} else: From add6f88324f19a1487089823d8d5345698def298 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Tue, 3 Mar 2026 15:33:42 +0800 Subject: [PATCH 23/23] bugfix --- diffsynth/configs/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffsynth/configs/__init__.py b/diffsynth/configs/__init__.py index 144a822..7ad5b73 100644 --- a/diffsynth/configs/__init__.py +++ b/diffsynth/configs/__init__.py @@ -1,2 +1,2 @@ from .model_configs import MODEL_CONFIGS -from .vram_management_module_maps import VRAM_MANAGEMENT_MODULE_MAPS +from .vram_management_module_maps import VRAM_MANAGEMENT_MODULE_MAPS, VERSION_CHECKER_MAPS