support ltx2.3 inference (#1332)

This commit is contained in:
Zhongjie Duan
2026-03-06 16:24:53 +08:00
committed by GitHub
17 changed files with 1608 additions and 351 deletions

View File

@@ -92,13 +92,13 @@ class LTX2AudioVideoPipeline(BasePipeline):
pipe.audio_vae_decoder = model_pool.fetch_model("ltx2_audio_vae_decoder")
pipe.audio_vocoder = model_pool.fetch_model("ltx2_audio_vocoder")
pipe.upsampler = model_pool.fetch_model("ltx2_latent_upsampler")
pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder")
# Stage 2
if stage2_lora_config is not None:
stage2_lora_config.download_if_necessary()
pipe.stage2_lora_path = stage2_lora_config.path
# Optional, currently not used
pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder")
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
@@ -168,8 +168,8 @@ class LTX2AudioVideoPipeline(BasePipeline):
# Shape
height: Optional[int] = 512,
width: Optional[int] = 768,
num_frames=121,
frame_rate=24,
num_frames: Optional[int] = 121,
frame_rate: Optional[int] = 24,
# Classifier-free guidance
cfg_scale: Optional[float] = 3.0,
# Scheduler
@@ -238,7 +238,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
decoded_audio = self.audio_vocoder(decoded_audio).squeeze(0).float()
return video, decoded_audio
def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength, initial_latents=None, num_frames=121):
def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength, initial_latents=None):
b, _, f, h, w = latents.shape
denoise_mask = torch.ones((b, 1, f, h, w), dtype=latents.dtype, device=latents.device)
initial_latents = torch.zeros_like(latents) if initial_latents is None else initial_latents
@@ -306,121 +306,20 @@ class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit):
output_params=("video_context", "audio_context"),
onload_model_names=("text_encoder", "text_encoder_post_modules"),
)
def _convert_to_additive_mask(self, attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
return (attention_mask - 1).to(dtype).reshape(
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(dtype).max
def _run_connectors(self, pipe, encoded_input: torch.Tensor,
                    attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the video and audio embedding connectors on the encoded text.

    Args:
        pipe: Pipeline object exposing ``text_encoder_post_modules``.
        encoded_input: Text features fed to both connectors.
        attention_mask: Mask for ``encoded_input``; converted to additive form
            before being passed to the connectors.

    Returns:
        ``(video_features, audio_features, mask)`` where ``mask`` is the int64
        mask derived from the video connector's output mask.
    """
    additive_mask = self._convert_to_additive_mask(attention_mask, encoded_input.dtype)
    post_modules = pipe.text_encoder_post_modules

    video_features, video_additive_mask = post_modules.embeddings_connector(
        encoded_input,
        additive_mask,
    )
    # Bring the connector's mask back to an int64 mask, one value per
    # (batch, token), and use it to scale the video features.
    # NOTE(review): the 0.000001 threshold is kept verbatim; its meaning
    # depends on what the connector returns -- confirm.
    keep = (video_additive_mask < 0.000001).to(torch.int64)
    keep = keep.reshape([video_features.shape[0], video_features.shape[1], 1])
    video_features = video_features * keep

    # The audio connector receives the same inputs; its mask output is unused.
    audio_features, _ = post_modules.audio_embeddings_connector(encoded_input, additive_mask)
    return video_features, audio_features, keep.squeeze(-1)
def _norm_and_concat_padded_batch(
self,
encoded_text: torch.Tensor,
sequence_lengths: torch.Tensor,
padding_side: str = "right",
) -> torch.Tensor:
"""Normalize and flatten multi-layer hidden states, respecting padding.
Performs per-batch, per-layer normalization using masked mean and range,
then concatenates across the layer dimension.
Args:
encoded_text: Hidden states of shape [batch, seq_len, hidden_dim, num_layers].
sequence_lengths: Number of valid (non-padded) tokens per batch item.
padding_side: Whether padding is on "left" or "right".
Returns:
Normalized tensor of shape [batch, seq_len, hidden_dim * num_layers],
with padded positions zeroed out.
"""
b, t, d, l = encoded_text.shape # noqa: E741
device = encoded_text.device
# Build mask: [B, T, 1, 1]
token_indices = torch.arange(t, device=device)[None, :] # [1, T]
if padding_side == "right":
# For right padding, valid tokens are from 0 to sequence_length-1
mask = token_indices < sequence_lengths[:, None] # [B, T]
elif padding_side == "left":
# For left padding, valid tokens are from (T - sequence_length) to T-1
start_indices = t - sequence_lengths[:, None] # [B, 1]
mask = token_indices >= start_indices # [B, T]
else:
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
mask = rearrange(mask, "b t -> b t 1 1")
eps = 1e-6
# Compute masked mean: [B, 1, 1, L]
masked = encoded_text.masked_fill(~mask, 0.0)
denom = (sequence_lengths * d).view(b, 1, 1, 1)
mean = masked.sum(dim=(1, 2), keepdim=True) / (denom + eps)
# Compute masked min/max: [B, 1, 1, L]
x_min = encoded_text.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
x_max = encoded_text.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
range_ = x_max - x_min
# Normalize only the valid tokens
normed = 8 * (encoded_text - mean) / (range_ + eps)
# concat to be [Batch, T, D * L] - this preserves the original structure
normed = normed.reshape(b, t, -1) # [B, T, D * L]
# Apply mask to preserve original padding (set padded positions to 0)
mask_flattened = rearrange(mask, "b t 1 1 -> b t 1").expand(-1, -1, d * l)
normed = normed.masked_fill(~mask_flattened, 0.0)
return normed
def _run_feature_extractor(self,
                           pipe,
                           hidden_states: torch.Tensor,
                           attention_mask: torch.Tensor,
                           padding_side: str = "right") -> torch.Tensor:
    """Stack hidden states, normalize them and apply the linear feature extractor.

    Args:
        pipe: Pipeline object exposing ``text_encoder_post_modules``.
        hidden_states: Sequence of tensors stacked along a new last axis.
        attention_mask: Mask whose sum over the last axis gives the number of
            valid tokens per batch item.
        padding_side: Forwarded to ``_norm_and_concat_padded_batch``.

    Returns:
        Output of ``pipe.text_encoder_post_modules.feature_extractor_linear``.
    """
    stacked = torch.stack(hidden_states, dim=-1)
    valid_lengths = attention_mask.sum(dim=-1)
    normalized = self._norm_and_concat_padded_batch(
        stacked,
        valid_lengths,
        padding_side=padding_side,
    )
    # Cast back to the dtype of the stacked hidden states before projecting.
    projection = pipe.text_encoder_post_modules.feature_extractor_linear
    return projection(normalized.to(stacked.dtype))
def _preprocess_text(
self,
pipe,
text: str,
padding_side: str = "left",
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
"""
Encode a given string into feature tensors suitable for downstream tasks.
The text is tokenized with ``pipe.tokenizer`` and passed through
``pipe.text_encoder`` with hidden-state output enabled.
Args:
pipe: Pipeline object providing ``tokenizer``, ``text_encoder`` and ``device``.
text (str): Input string to encode.
padding_side (str): Forwarded to ``_run_feature_extractor``.
Returns:
A 2-tuple whose second element is the attention-mask tensor built from the tokenizer output.
"""
# NOTE(review): the return annotation mentions a dict, but the second returned
# value below is a tensor -- confirm and align the annotation.
# Each entry under the "gemma" key is indexed as a pair: element 0 is used as
# the token id, element 1 as the attention-mask value.
token_pairs = pipe.tokenizer.tokenize_with_weights(text)["gemma"]
# The extra list nesting adds a leading batch dimension of size 1.
input_ids = torch.tensor([[t[0] for t in token_pairs]], device=pipe.device)
attention_mask = torch.tensor([[w[1] for w in token_pairs]], device=pipe.device)
outputs = pipe.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
# NOTE(review): two return statements follow back to back. This text looks like
# a diff view with removed and added lines interleaved: presumably the
# `projected` path is the old code and the `outputs.hidden_states` return is
# the new code. As written, the second return is unreachable -- confirm which
# one is intended.
projected = self._run_feature_extractor(pipe,
hidden_states=outputs.hidden_states,
attention_mask=attention_mask,
padding_side=padding_side)
return projected, attention_mask
return outputs.hidden_states, attention_mask
def encode_prompt(self, pipe, text, padding_side="left"):
"""Encode ``text`` into video and audio conditioning tensors plus a mask.
Returns the 3-tuple ``(video_encoding, audio_encoding, attention_mask)``.
"""
# NOTE(review): two computation paths appear back to back. This text looks like
# a diff view with removed and added lines interleaved: presumably the
# `_run_connectors` path is the old code and the `process_hidden_states` path
# is the new code. As written, the results of the first two statements are
# overwritten by the statements that follow -- confirm which path is intended.
encoded_inputs, attention_mask = self._preprocess_text(pipe, text, padding_side)
video_encoding, audio_encoding, attention_mask = self._run_connectors(pipe, encoded_inputs, attention_mask)
# Here `padding_side` is passed to process_hidden_states rather than to
# _preprocess_text, which therefore uses its own default.
hidden_states, attention_mask = self._preprocess_text(pipe, text)
video_encoding, audio_encoding, attention_mask = pipe.text_encoder_post_modules.process_hidden_states(
hidden_states, attention_mask, padding_side)
return video_encoding, audio_encoding, attention_mask
def process(self, pipe: LTX2AudioVideoPipeline, prompt: str):
@@ -539,7 +438,7 @@ class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
self.get_image_latent(pipe, img, stage1_height, stage1_width, tiled, tile_size_in_pixels,
tile_overlap_in_pixels) for img in input_images
]
video_latents, denoise_mask_video, initial_latents = pipe.apply_input_images_to_latents(video_latents, stage1_latents, input_images_indexes, input_images_strength, num_frames=num_frames)
video_latents, denoise_mask_video, initial_latents = pipe.apply_input_images_to_latents(video_latents, stage1_latents, input_images_indexes, input_images_strength)
output_dicts.update({"video_latents": video_latents, "denoise_mask_video": denoise_mask_video, "input_latents_video": initial_latents})
if use_two_stage_pipeline:
stage2_latents = [
@@ -649,6 +548,7 @@ def model_fn_ltx2(
audio_positions=audio_positions,
audio_context=audio_context,
audio_timesteps=audio_timesteps,
sigma=timestep,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)