mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 00:58:11 +00:00
support ltx2 one-stage pipeline
This commit is contained in:
@@ -514,7 +514,7 @@ class Attention(torch.nn.Module):
|
||||
out_pattern="b s n d",
|
||||
attn_mask=mask
|
||||
)
|
||||
|
||||
|
||||
# Reshape back to original format
|
||||
out = out.flatten(2, 3)
|
||||
return self.to_out(out)
|
||||
@@ -1398,7 +1398,7 @@ class LTXModel(torch.nn.Module):
|
||||
x = proj_out(x)
|
||||
return x
|
||||
|
||||
def forward(
|
||||
def _forward(
|
||||
self, video: Modality | None, audio: Modality | None, perturbations: BatchedPerturbationConfig
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
@@ -1440,3 +1440,9 @@ class LTXModel(torch.nn.Module):
|
||||
else None
|
||||
)
|
||||
return vx, ax
|
||||
|
||||
def forward(self, video_latents, video_positions, video_context, video_timesteps, audio_latents, audio_positions, audio_context, audio_timesteps):
|
||||
video = Modality(video_latents, video_timesteps, video_positions, video_context)
|
||||
audio = Modality(audio_latents, audio_timesteps, audio_positions, audio_context)
|
||||
vx, ax = self._forward(video=video, audio=audio, perturbations=None)
|
||||
return vx, ax
|
||||
|
||||
Reference in New Issue
Block a user