support ltx2 gradient_checkpointing

This commit is contained in:
mi804
2026-02-26 19:19:59 +08:00
parent a87910bc65
commit a18966c300
10 changed files with 36 additions and 33 deletions

View File

@@ -8,6 +8,7 @@ import torch
from einops import rearrange
from .ltx2_common import rms_norm, Modality
from ..core.attention.attention import attention_forward
from ..core import gradient_checkpoint_forward
def get_timestep_embedding(
@@ -1352,28 +1353,21 @@ class LTXModel(torch.nn.Module):
video: TransformerArgs | None,
audio: TransformerArgs | None,
perturbations: BatchedPerturbationConfig,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
) -> tuple[TransformerArgs, TransformerArgs]:
"""Process transformer blocks for LTXAV."""
# Process transformer blocks
for block in self.transformer_blocks:
if self._enable_gradient_checkpointing and self.training:
# Use gradient checkpointing to save memory during training.
# With use_reentrant=False, we can pass dataclasses directly -
# PyTorch will track all tensor leaves in the computation graph.
video, audio = torch.utils.checkpoint.checkpoint(
block,
video,
audio,
perturbations,
use_reentrant=False,
)
else:
video, audio = block(
video=video,
audio=audio,
perturbations=perturbations,
)
video, audio = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
video=video,
audio=audio,
perturbations=perturbations,
)
return video, audio
@@ -1398,7 +1392,12 @@ class LTXModel(torch.nn.Module):
return x
def _forward(
self, video: Modality | None, audio: Modality | None, perturbations: BatchedPerturbationConfig
self,
video: Modality | None,
audio: Modality | None,
perturbations: BatchedPerturbationConfig,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass for LTX models.
@@ -1417,6 +1416,8 @@ class LTXModel(torch.nn.Module):
video=video_args,
audio=audio_args,
perturbations=perturbations,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
# Process output
@@ -1440,12 +1441,12 @@ class LTXModel(torch.nn.Module):
)
return vx, ax
def forward(self, video_latents, video_positions, video_context, video_timesteps, audio_latents, audio_positions, audio_context, audio_timesteps):
def forward(self, video_latents, video_positions, video_context, video_timesteps, audio_latents, audio_positions, audio_context, audio_timesteps, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False):
cross_pe_max_pos = None
if self.model_type.is_video_enabled() and self.model_type.is_audio_enabled():
cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0])
self._init_preprocessors(cross_pe_max_pos)
video = Modality(video_latents, video_timesteps, video_positions, video_context)
audio = Modality(audio_latents, audio_timesteps, audio_positions, audio_context) if audio_latents is not None else None
vx, ax = self._forward(video=video, audio=audio, perturbations=None)
vx, ax = self._forward(video=video, audio=audio, perturbations=None, use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload)
return vx, ax