support ltx-2 training

This commit is contained in:
mi804
2026-02-25 17:19:57 +08:00
parent 288bbc7128
commit 586ac9d8a6
40 changed files with 893 additions and 49 deletions

View File

@@ -1446,6 +1446,6 @@ class LTXModel(torch.nn.Module):
cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0])
self._init_preprocessors(cross_pe_max_pos)
video = Modality(video_latents, video_timesteps, video_positions, video_context)
audio = Modality(audio_latents, audio_timesteps, audio_positions, audio_context)
audio = Modality(audio_latents, audio_timesteps, audio_positions, audio_context) if audio_latents is not None else None
vx, ax = self._forward(video=video, audio=audio, perturbations=None)
return vx, ax