mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support ltx2.3 inference (#1332)
This commit is contained in:
@@ -92,13 +92,13 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
pipe.audio_vae_decoder = model_pool.fetch_model("ltx2_audio_vae_decoder")
|
||||
pipe.audio_vocoder = model_pool.fetch_model("ltx2_audio_vocoder")
|
||||
pipe.upsampler = model_pool.fetch_model("ltx2_latent_upsampler")
|
||||
pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder")
|
||||
|
||||
# Stage 2
|
||||
if stage2_lora_config is not None:
|
||||
stage2_lora_config.download_if_necessary()
|
||||
pipe.stage2_lora_path = stage2_lora_config.path
|
||||
# Optional, currently not used
|
||||
pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder")
|
||||
|
||||
# VRAM Management
|
||||
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||
@@ -168,8 +168,8 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
# Shape
|
||||
height: Optional[int] = 512,
|
||||
width: Optional[int] = 768,
|
||||
num_frames=121,
|
||||
frame_rate=24,
|
||||
num_frames: Optional[int] = 121,
|
||||
frame_rate: Optional[int] = 24,
|
||||
# Classifier-free guidance
|
||||
cfg_scale: Optional[float] = 3.0,
|
||||
# Scheduler
|
||||
@@ -238,7 +238,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
decoded_audio = self.audio_vocoder(decoded_audio).squeeze(0).float()
|
||||
return video, decoded_audio
|
||||
|
||||
def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength, initial_latents=None, num_frames=121):
|
||||
def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength, initial_latents=None):
|
||||
b, _, f, h, w = latents.shape
|
||||
denoise_mask = torch.ones((b, 1, f, h, w), dtype=latents.dtype, device=latents.device)
|
||||
initial_latents = torch.zeros_like(latents) if initial_latents is None else initial_latents
|
||||
@@ -306,121 +306,20 @@ class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit):
|
||||
output_params=("video_context", "audio_context"),
|
||||
onload_model_names=("text_encoder", "text_encoder_post_modules"),
|
||||
)
|
||||
|
||||
def _convert_to_additive_mask(self, attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
||||
return (attention_mask - 1).to(dtype).reshape(
|
||||
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(dtype).max
|
||||
|
||||
def _run_connectors(self, pipe, encoded_input: torch.Tensor,
|
||||
attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
connector_attention_mask = self._convert_to_additive_mask(attention_mask, encoded_input.dtype)
|
||||
|
||||
encoded, encoded_connector_attention_mask = pipe.text_encoder_post_modules.embeddings_connector(
|
||||
encoded_input,
|
||||
connector_attention_mask,
|
||||
)
|
||||
|
||||
# restore the mask values to int64
|
||||
attention_mask = (encoded_connector_attention_mask < 0.000001).to(torch.int64)
|
||||
attention_mask = attention_mask.reshape([encoded.shape[0], encoded.shape[1], 1])
|
||||
encoded = encoded * attention_mask
|
||||
|
||||
encoded_for_audio, _ = pipe.text_encoder_post_modules.audio_embeddings_connector(
|
||||
encoded_input, connector_attention_mask)
|
||||
|
||||
return encoded, encoded_for_audio, attention_mask.squeeze(-1)
|
||||
|
||||
def _norm_and_concat_padded_batch(
|
||||
self,
|
||||
encoded_text: torch.Tensor,
|
||||
sequence_lengths: torch.Tensor,
|
||||
padding_side: str = "right",
|
||||
) -> torch.Tensor:
|
||||
"""Normalize and flatten multi-layer hidden states, respecting padding.
|
||||
Performs per-batch, per-layer normalization using masked mean and range,
|
||||
then concatenates across the layer dimension.
|
||||
Args:
|
||||
encoded_text: Hidden states of shape [batch, seq_len, hidden_dim, num_layers].
|
||||
sequence_lengths: Number of valid (non-padded) tokens per batch item.
|
||||
padding_side: Whether padding is on "left" or "right".
|
||||
Returns:
|
||||
Normalized tensor of shape [batch, seq_len, hidden_dim * num_layers],
|
||||
with padded positions zeroed out.
|
||||
"""
|
||||
b, t, d, l = encoded_text.shape # noqa: E741
|
||||
device = encoded_text.device
|
||||
# Build mask: [B, T, 1, 1]
|
||||
token_indices = torch.arange(t, device=device)[None, :] # [1, T]
|
||||
if padding_side == "right":
|
||||
# For right padding, valid tokens are from 0 to sequence_length-1
|
||||
mask = token_indices < sequence_lengths[:, None] # [B, T]
|
||||
elif padding_side == "left":
|
||||
# For left padding, valid tokens are from (T - sequence_length) to T-1
|
||||
start_indices = t - sequence_lengths[:, None] # [B, 1]
|
||||
mask = token_indices >= start_indices # [B, T]
|
||||
else:
|
||||
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
|
||||
mask = rearrange(mask, "b t -> b t 1 1")
|
||||
eps = 1e-6
|
||||
# Compute masked mean: [B, 1, 1, L]
|
||||
masked = encoded_text.masked_fill(~mask, 0.0)
|
||||
denom = (sequence_lengths * d).view(b, 1, 1, 1)
|
||||
mean = masked.sum(dim=(1, 2), keepdim=True) / (denom + eps)
|
||||
# Compute masked min/max: [B, 1, 1, L]
|
||||
x_min = encoded_text.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
|
||||
x_max = encoded_text.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
|
||||
range_ = x_max - x_min
|
||||
# Normalize only the valid tokens
|
||||
normed = 8 * (encoded_text - mean) / (range_ + eps)
|
||||
# concat to be [Batch, T, D * L] - this preserves the original structure
|
||||
normed = normed.reshape(b, t, -1) # [B, T, D * L]
|
||||
# Apply mask to preserve original padding (set padded positions to 0)
|
||||
mask_flattened = rearrange(mask, "b t 1 1 -> b t 1").expand(-1, -1, d * l)
|
||||
normed = normed.masked_fill(~mask_flattened, 0.0)
|
||||
|
||||
return normed
|
||||
|
||||
def _run_feature_extractor(self,
|
||||
pipe,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
padding_side: str = "right") -> torch.Tensor:
|
||||
encoded_text_features = torch.stack(hidden_states, dim=-1)
|
||||
encoded_text_features_dtype = encoded_text_features.dtype
|
||||
sequence_lengths = attention_mask.sum(dim=-1)
|
||||
normed_concated_encoded_text_features = self._norm_and_concat_padded_batch(encoded_text_features,
|
||||
sequence_lengths,
|
||||
padding_side=padding_side)
|
||||
|
||||
return pipe.text_encoder_post_modules.feature_extractor_linear(
|
||||
normed_concated_encoded_text_features.to(encoded_text_features_dtype))
|
||||
|
||||
def _preprocess_text(
|
||||
self,
|
||||
pipe,
|
||||
text: str,
|
||||
padding_side: str = "left",
|
||||
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
|
||||
"""
|
||||
Encode a given string into feature tensors suitable for downstream tasks.
|
||||
Args:
|
||||
text (str): Input string to encode.
|
||||
Returns:
|
||||
tuple[torch.Tensor, dict[str, torch.Tensor]]: Encoded features and a dictionary with attention mask.
|
||||
"""
|
||||
token_pairs = pipe.tokenizer.tokenize_with_weights(text)["gemma"]
|
||||
input_ids = torch.tensor([[t[0] for t in token_pairs]], device=pipe.device)
|
||||
attention_mask = torch.tensor([[w[1] for w in token_pairs]], device=pipe.device)
|
||||
outputs = pipe.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
|
||||
projected = self._run_feature_extractor(pipe,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
padding_side=padding_side)
|
||||
return projected, attention_mask
|
||||
|
||||
return outputs.hidden_states, attention_mask
|
||||
def encode_prompt(self, pipe, text, padding_side="left"):
|
||||
encoded_inputs, attention_mask = self._preprocess_text(pipe, text, padding_side)
|
||||
video_encoding, audio_encoding, attention_mask = self._run_connectors(pipe, encoded_inputs, attention_mask)
|
||||
hidden_states, attention_mask = self._preprocess_text(pipe, text)
|
||||
video_encoding, audio_encoding, attention_mask = pipe.text_encoder_post_modules.process_hidden_states(
|
||||
hidden_states, attention_mask, padding_side)
|
||||
return video_encoding, audio_encoding, attention_mask
|
||||
|
||||
def process(self, pipe: LTX2AudioVideoPipeline, prompt: str):
|
||||
@@ -539,7 +438,7 @@ class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
|
||||
self.get_image_latent(pipe, img, stage1_height, stage1_width, tiled, tile_size_in_pixels,
|
||||
tile_overlap_in_pixels) for img in input_images
|
||||
]
|
||||
video_latents, denoise_mask_video, initial_latents = pipe.apply_input_images_to_latents(video_latents, stage1_latents, input_images_indexes, input_images_strength, num_frames=num_frames)
|
||||
video_latents, denoise_mask_video, initial_latents = pipe.apply_input_images_to_latents(video_latents, stage1_latents, input_images_indexes, input_images_strength)
|
||||
output_dicts.update({"video_latents": video_latents, "denoise_mask_video": denoise_mask_video, "input_latents_video": initial_latents})
|
||||
if use_two_stage_pipeline:
|
||||
stage2_latents = [
|
||||
@@ -649,6 +548,7 @@ def model_fn_ltx2(
|
||||
audio_positions=audio_positions,
|
||||
audio_context=audio_context,
|
||||
audio_timesteps=audio_timesteps,
|
||||
sigma=timestep,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user