mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support ltx-2 t2v and i2v
This commit is contained in:
@@ -1442,6 +1442,10 @@ class LTXModel(torch.nn.Module):
|
||||
return vx, ax
|
||||
|
||||
def forward(self, video_latents, video_positions, video_context, video_timesteps, audio_latents, audio_positions, audio_context, audio_timesteps):
|
||||
cross_pe_max_pos = None
|
||||
if self.model_type.is_video_enabled() and self.model_type.is_audio_enabled():
|
||||
cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0])
|
||||
self._init_preprocessors(cross_pe_max_pos)
|
||||
video = Modality(video_latents, video_timesteps, video_positions, video_context)
|
||||
audio = Modality(audio_latents, audio_timesteps, audio_positions, audio_context)
|
||||
vx, ax = self._forward(video=video, audio=audio, perturbations=None)
|
||||
|
||||
@@ -1648,11 +1648,8 @@ class LTX2VideoEncoder(nn.Module):
|
||||
tile_overlap_in_pixels: Optional[int] = 128,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
device = next(self.parameters()).device
|
||||
vae_dtype = next(self.parameters()).dtype
|
||||
if video.ndim == 4:
|
||||
video = video.unsqueeze(0) # [C, F, H, W] -> [B, C, F, H, W]
|
||||
video = video.to(device=device, dtype=vae_dtype)
|
||||
# Choose encoding method based on tiling flag
|
||||
if tiled:
|
||||
latents = self.tiled_encode_video(
|
||||
|
||||
Reference in New Issue
Block a user