mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support ltx2 gradient_checkpointing
This commit is contained in:
@@ -8,6 +8,7 @@ import torch
|
||||
from einops import rearrange
|
||||
from .ltx2_common import rms_norm, Modality
|
||||
from ..core.attention.attention import attention_forward
|
||||
from ..core import gradient_checkpoint_forward
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
@@ -1352,28 +1353,21 @@ class LTXModel(torch.nn.Module):
|
||||
video: TransformerArgs | None,
|
||||
audio: TransformerArgs | None,
|
||||
perturbations: BatchedPerturbationConfig,
|
||||
use_gradient_checkpointing: bool = False,
|
||||
use_gradient_checkpointing_offload: bool = False,
|
||||
) -> tuple[TransformerArgs, TransformerArgs]:
|
||||
"""Process transformer blocks for LTXAV."""
|
||||
|
||||
# Process transformer blocks
|
||||
for block in self.transformer_blocks:
|
||||
if self._enable_gradient_checkpointing and self.training:
|
||||
# Use gradient checkpointing to save memory during training.
|
||||
# With use_reentrant=False, we can pass dataclasses directly -
|
||||
# PyTorch will track all tensor leaves in the computation graph.
|
||||
video, audio = torch.utils.checkpoint.checkpoint(
|
||||
block,
|
||||
video,
|
||||
audio,
|
||||
perturbations,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
video, audio = block(
|
||||
video=video,
|
||||
audio=audio,
|
||||
perturbations=perturbations,
|
||||
)
|
||||
video, audio = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
video=video,
|
||||
audio=audio,
|
||||
perturbations=perturbations,
|
||||
)
|
||||
|
||||
return video, audio
|
||||
|
||||
@@ -1398,7 +1392,12 @@ class LTXModel(torch.nn.Module):
|
||||
return x
|
||||
|
||||
def _forward(
|
||||
self, video: Modality | None, audio: Modality | None, perturbations: BatchedPerturbationConfig
|
||||
self,
|
||||
video: Modality | None,
|
||||
audio: Modality | None,
|
||||
perturbations: BatchedPerturbationConfig,
|
||||
use_gradient_checkpointing: bool = False,
|
||||
use_gradient_checkpointing_offload: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Forward pass for LTX models.
|
||||
@@ -1417,6 +1416,8 @@ class LTXModel(torch.nn.Module):
|
||||
video=video_args,
|
||||
audio=audio_args,
|
||||
perturbations=perturbations,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)
|
||||
|
||||
# Process output
|
||||
@@ -1440,12 +1441,12 @@ class LTXModel(torch.nn.Module):
|
||||
)
|
||||
return vx, ax
|
||||
|
||||
def forward(self, video_latents, video_positions, video_context, video_timesteps, audio_latents, audio_positions, audio_context, audio_timesteps):
|
||||
def forward(self, video_latents, video_positions, video_context, video_timesteps, audio_latents, audio_positions, audio_context, audio_timesteps, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False):
|
||||
cross_pe_max_pos = None
|
||||
if self.model_type.is_video_enabled() and self.model_type.is_audio_enabled():
|
||||
cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0])
|
||||
self._init_preprocessors(cross_pe_max_pos)
|
||||
video = Modality(video_latents, video_timesteps, video_positions, video_context)
|
||||
audio = Modality(audio_latents, audio_timesteps, audio_positions, audio_context) if audio_latents is not None else None
|
||||
vx, ax = self._forward(video=video, audio=audio, perturbations=None)
|
||||
vx, ax = self._forward(video=video, audio=audio, perturbations=None, use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload)
|
||||
return vx, ax
|
||||
|
||||
Reference in New Issue
Block a user