support ltx2 one-stage pipeline

This commit is contained in:
mi804
2026-01-29 16:30:15 +08:00
parent 8d303b47e9
commit b1a2782ad7
7 changed files with 1005 additions and 7 deletions

View File

@@ -514,7 +514,7 @@ class Attention(torch.nn.Module):
out_pattern="b s n d",
attn_mask=mask
)
# Reshape back to original format
out = out.flatten(2, 3)
return self.to_out(out)
@@ -1398,7 +1398,7 @@ class LTXModel(torch.nn.Module):
x = proj_out(x)
return x
def forward(
def _forward(
self, video: Modality | None, audio: Modality | None, perturbations: BatchedPerturbationConfig
) -> tuple[torch.Tensor, torch.Tensor]:
"""
@@ -1440,3 +1440,9 @@ class LTXModel(torch.nn.Module):
else None
)
return vx, ax
def forward(self, video_latents, video_positions, video_context, video_timesteps, audio_latents, audio_positions, audio_context, audio_timesteps):
video = Modality(video_latents, video_timesteps, video_positions, video_context)
audio = Modality(audio_latents, audio_timesteps, audio_positions, audio_context)
vx, ax = self._forward(video=video, audio=audio, perturbations=None)
return vx, ax