diff --git a/diffsynth/models/ltx2_dit.py b/diffsynth/models/ltx2_dit.py index c12a3f3..cc584ee 100644 --- a/diffsynth/models/ltx2_dit.py +++ b/diffsynth/models/ltx2_dit.py @@ -8,6 +8,7 @@ import torch from einops import rearrange from .ltx2_common import rms_norm, Modality from ..core.attention.attention import attention_forward +from ..core import gradient_checkpoint_forward def get_timestep_embedding( @@ -1352,28 +1353,21 @@ class LTXModel(torch.nn.Module): video: TransformerArgs | None, audio: TransformerArgs | None, perturbations: BatchedPerturbationConfig, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, ) -> tuple[TransformerArgs, TransformerArgs]: """Process transformer blocks for LTXAV.""" # Process transformer blocks for block in self.transformer_blocks: - if self._enable_gradient_checkpointing and self.training: - # Use gradient checkpointing to save memory during training. - # With use_reentrant=False, we can pass dataclasses directly - - # PyTorch will track all tensor leaves in the computation graph. - video, audio = torch.utils.checkpoint.checkpoint( - block, - video, - audio, - perturbations, - use_reentrant=False, - ) - else: - video, audio = block( - video=video, - audio=audio, - perturbations=perturbations, - ) + video, audio = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + video=video, + audio=audio, + perturbations=perturbations, + ) return video, audio @@ -1398,7 +1392,12 @@ class LTXModel(torch.nn.Module): return x def _forward( - self, video: Modality | None, audio: Modality | None, perturbations: BatchedPerturbationConfig + self, + video: Modality | None, + audio: Modality | None, + perturbations: BatchedPerturbationConfig, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: """ Forward pass for LTX models. @@ -1417,6 +1416,8 @@ class LTXModel(torch.nn.Module): video=video_args, audio=audio_args, perturbations=perturbations, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, ) # Process output @@ -1440,12 +1441,12 @@ class LTXModel(torch.nn.Module): ) return vx, ax - def forward(self, video_latents, video_positions, video_context, video_timesteps, audio_latents, audio_positions, audio_context, audio_timesteps): + def forward(self, video_latents, video_positions, video_context, video_timesteps, audio_latents, audio_positions, audio_context, audio_timesteps, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False): cross_pe_max_pos = None if self.model_type.is_video_enabled() and self.model_type.is_audio_enabled(): cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0]) self._init_preprocessors(cross_pe_max_pos) video = Modality(video_latents, video_timesteps, video_positions, video_context) audio = Modality(audio_latents, audio_timesteps, audio_positions, audio_context) if audio_latents is not None else None - vx, ax = self._forward(video=video, audio=audio, perturbations=None) + vx, ax = self._forward(video=video, audio=audio, perturbations=None, use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload) return vx, ax diff --git a/diffsynth/pipelines/ltx2_audio_video.py b/diffsynth/pipelines/ltx2_audio_video.py index 4587449..fc0b969 100644 --- a/diffsynth/pipelines/ltx2_audio_video.py +++ b/diffsynth/pipelines/ltx2_audio_video.py @@ -577,6 +577,8 @@ def model_fn_ltx2( audio_positions=audio_positions, audio_context=audio_context, audio_timesteps=audio_timesteps, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, ) # unpatchify vx = video_patchifier.unpatchify_video(vx, f, h, w) diff --git a/diffsynth/utils/data/__init__.py b/diffsynth/utils/data/__init__.py index c6b9daa..edc3d41 100644 --- a/diffsynth/utils/data/__init__.py +++ b/diffsynth/utils/data/__init__.py @@ -116,7 +116,7 @@ class VideoData: if self.height is not None and self.width is not None: return self.height, self.width else: - height, width, _ = self.__getitem__(0).shape + width, height = self.__getitem__(0).size return height, width def __getitem__(self, item): diff --git a/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh b/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh index 4aec5f8..04f3b1c 100644 --- a/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh +++ b/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh @@ -6,7 +6,7 @@ accelerate launch examples/ltx2/model_training/train.py \ --extra_inputs "input_audio" \ --height 512 \ --width 768 \ - --num_frames 49 \ + --num_frames 121 \ --dataset_repeat 1 \ --model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors" \ --learning_rate 1e-4 \ @@ -23,7 +23,7 @@ accelerate launch --config_file examples/qwen_image/model_training/full/accelera --extra_inputs "input_audio" \ --height 512 \ --width 768 \ - --num_frames 49 \ + --num_frames 121 \ --dataset_repeat 100 \ --model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:transformer.safetensors" \ --learning_rate 1e-4 \ diff --git a/examples/ltx2/model_training/lora/LTX-2-T2AV-noaudio.sh b/examples/ltx2/model_training/lora/LTX-2-T2AV-noaudio.sh index b2a5609..f7362af 100644 --- a/examples/ltx2/model_training/lora/LTX-2-T2AV-noaudio.sh +++ b/examples/ltx2/model_training/lora/LTX-2-T2AV-noaudio.sh @@ -24,7 +24,7 @@ accelerate launch examples/ltx2/model_training/train.py \ --dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \ --height 512 \ --width 768 \ - --num_frames 49\ + --num_frames 121\ --dataset_repeat 1 \ --model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors" \ --learning_rate 1e-4 \ @@ -42,7 +42,7 @@ accelerate launch examples/ltx2/model_training/train.py \ --dataset_base_path ./models/train/LTX2-T2AV-noaudio_lora-splited-cache \ --height 512 \ --width 768 \ - --num_frames 49\ + --num_frames 121\ --dataset_repeat 100 \ --model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:transformer.safetensors" \ --learning_rate 1e-4 \ diff --git a/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh b/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh index 40dae1a..ebee83d 100644 --- a/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh +++ b/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh @@ -27,7 +27,7 @@ accelerate launch examples/ltx2/model_training/train.py \ --extra_inputs "input_audio" \ --height 512 \ --width 768 \ - --num_frames 49 \ + --num_frames 121 \ --dataset_repeat 1 \ --model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors" \ --learning_rate 1e-4 \ @@ -46,7 +46,7 @@ accelerate launch examples/ltx2/model_training/train.py \ --extra_inputs "input_audio" \ --height 512 \ --width 768 \ - --num_frames 49 \ + --num_frames 121 \ --dataset_repeat 100 \ --model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:transformer.safetensors" \ --learning_rate 1e-4 \ diff --git a/examples/ltx2/model_training/train.py b/examples/ltx2/model_training/train.py index a994f7a..26a2925 100644 --- a/examples/ltx2/model_training/train.py +++ b/examples/ltx2/model_training/train.py @@ -118,10 +118,10 @@ if __name__ == "__main__": max_pixels=args.max_pixels, height=args.height, width=args.width, - height_division_factor=16, - width_division_factor=16, + height_division_factor=32, + width_division_factor=32, num_frames=args.num_frames, - time_division_factor=4, + time_division_factor=8, time_division_remainder=1, ), special_operator_map={ diff --git a/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py b/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py index a5da12d..6201ec1 100644 --- a/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py +++ b/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py @@ -27,7 +27,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( ) prompt = "A beautiful sunset over the ocean." negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." -height, width, num_frames = 512, 768, 49 +height, width, num_frames = 512, 768, 121 video, audio = pipe( prompt=prompt, negative_prompt=negative_prompt, diff --git a/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py b/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py index 471a901..d0dab81 100644 --- a/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py +++ b/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py @@ -28,7 +28,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( pipe.load_lora(pipe.dit, "models/train/LTX2-T2AV_lora/epoch-4.safetensors") prompt = "A beautiful sunset over the ocean." negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." -height, width, num_frames = 512, 768, 49 +height, width, num_frames = 512, 768, 121 video, audio = pipe( prompt=prompt, negative_prompt=negative_prompt, diff --git a/examples/ltx2/model_training/validate_lora/LTX-2-T2AV_noaudio.py b/examples/ltx2/model_training/validate_lora/LTX-2-T2AV_noaudio.py index 4c2bccc..336b2bf 100644 --- a/examples/ltx2/model_training/validate_lora/LTX-2-T2AV_noaudio.py +++ b/examples/ltx2/model_training/validate_lora/LTX-2-T2AV_noaudio.py @@ -28,7 +28,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( pipe.load_lora(pipe.dit, "models/train/LTX2-T2AV-noaudio_lora/epoch-4.safetensors") prompt = "A beautiful sunset over the ocean." negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." -height, width, num_frames = 512, 768, 49 +height, width, num_frames = 512, 768, 121 video, audio = pipe( prompt=prompt, negative_prompt=negative_prompt,