mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
Support LTX 2.3 inference.
This commit is contained in:
@@ -1,4 +1,7 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from transformers import Gemma3ForConditionalGeneration, Gemma3Config, AutoTokenizer
|
||||
from .ltx2_dit import (LTXRopeType, generate_freq_grid_np, generate_freq_grid_pytorch, precompute_freqs_cis, Attention,
|
||||
FeedForward)
|
||||
@@ -147,14 +150,14 @@ class LTXVGemmaTokenizer:
|
||||
return out
|
||||
|
||||
|
||||
class GemmaFeaturesExtractorProjLinear(nn.Module):
    """Feature extractor for Gemma hidden states.

    Normalizes the stacked per-layer hidden states with a padding-aware
    mean/range scheme and projects the flattened features (3840 * 49) down
    to a single 3840-dim embedding via one bias-free linear layer.

    Attributes:
        aggregate_embed (nn.Linear): The shared projection layer.
    """

    def __init__(self) -> None:
        """Initialize the projection: 3840 * 49 inputs -> 3840 outputs."""
        super().__init__()
        # 49 stacked Gemma layers of width 3840, flattened per token.
        self.aggregate_embed = nn.Linear(3840 * 49, 3840, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        padding_side: str = "left",
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Normalize, flatten, and project the encoder hidden states.

        Args:
            hidden_states: Either a [B, T, D, L] tensor or a list/tuple of
                per-layer [B, T, D] tensors (stacked along a new last dim).
            attention_mask: [B, T] binary mask; its row sums give the valid
                sequence lengths.
            padding_side: Side on which padding tokens sit ("left"/"right").

        Returns:
            A (video_features, audio_features) pair; both entries are the
            same projected tensor since this extractor has a single head.
        """
        # Per-layer states may arrive as a sequence; stack to [B, T, D, L].
        if isinstance(hidden_states, (list, tuple)):
            stacked = torch.stack(hidden_states, dim=-1)
        else:
            stacked = hidden_states
        valid_counts = attention_mask.sum(dim=-1)
        flattened = _norm_and_concat_padded_batch(stacked, valid_counts, padding_side)
        projected = self.aggregate_embed(flattened.to(stacked.dtype))
        # Single shared projection: identical features serve both modalities.
        return projected, projected
|
||||
|
||||
|
||||
class _BasicTransformerBlock1D(torch.nn.Module):
|
||||
class GemmaSeperatedFeaturesExtractorProjLinear(nn.Module):
    """Dual-head feature extractor for the 22B model.

    Applies a per-token RMS norm across the hidden dimension, rescales each
    head's input by sqrt(out_features / embedding_dim), and projects the
    layer-concatenated features through separate video and audio heads.

    NOTE(review): "Seperated" in the class name is a typo kept as-is for
    checkpoint/caller compatibility.
    """

    def __init__(
        self,
        num_layers: int,
        embedding_dim: int,
        video_inner_dim: int,
        audio_inner_dim: int,
    ):
        """Build the two projection heads over the concatenated layers."""
        super().__init__()
        concat_dim = embedding_dim * num_layers
        self.video_aggregate_embed = nn.Linear(concat_dim, video_inner_dim, bias=True)
        self.audio_aggregate_embed = nn.Linear(concat_dim, audio_inner_dim, bias=True)
        self.embedding_dim = embedding_dim

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        padding_side: str = "left",  # noqa: ARG002
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Normalize then project into (video, audio) embedding spaces.

        Args:
            hidden_states: [B, T, D, L] tensor, or a list/tuple of per-layer
                [B, T, D] tensors which are stacked along the last dim.
            attention_mask: [B, T] binary mask; padded tokens are zeroed.
            padding_side: Unused; accepted for interface parity with the
                single-head extractor.

        Returns:
            (video, audio) projected tensors; audio is None only if its head
            is absent.
        """
        if isinstance(hidden_states, (list, tuple)):
            stacked = torch.stack(hidden_states, dim=-1)
        else:
            stacked = hidden_states
        normed = norm_and_concat_per_token_rms(stacked, attention_mask).to(stacked.dtype)

        video_out_dim = self.video_aggregate_embed.out_features
        video = self.video_aggregate_embed(_rescale_norm(normed, video_out_dim, self.embedding_dim))

        audio = None
        # Defensive: the audio head is always built in __init__, but keep the
        # guard so subclasses may disable it by setting the attribute to None.
        if self.audio_aggregate_embed is not None:
            audio_out_dim = self.audio_aggregate_embed.out_features
            audio = self.audio_aggregate_embed(_rescale_norm(normed, audio_out_dim, self.embedding_dim))
        return video, audio
|
||||
|
||||
|
||||
|
||||
class _BasicTransformerBlock1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
heads: int,
|
||||
dim_head: int,
|
||||
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
|
||||
apply_gated_attention: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -191,6 +233,7 @@ class _BasicTransformerBlock1D(torch.nn.Module):
|
||||
heads=heads,
|
||||
dim_head=dim_head,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
)
|
||||
|
||||
self.ff = FeedForward(
|
||||
@@ -231,7 +274,7 @@ class _BasicTransformerBlock1D(torch.nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Embeddings1DConnector(torch.nn.Module):
|
||||
class Embeddings1DConnector(nn.Module):
|
||||
"""
|
||||
Embeddings1DConnector applies a 1D transformer-based processing to sequential embeddings (e.g., for video, audio, or
|
||||
other modalities). It supports rotary positional encoding (rope), optional causal temporal positioning, and can
|
||||
@@ -263,6 +306,7 @@ class Embeddings1DConnector(torch.nn.Module):
|
||||
num_learnable_registers: int | None = 128,
|
||||
rope_type: LTXRopeType = LTXRopeType.SPLIT,
|
||||
double_precision_rope: bool = True,
|
||||
apply_gated_attention: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
@@ -274,13 +318,14 @@ class Embeddings1DConnector(torch.nn.Module):
|
||||
)
|
||||
self.rope_type = rope_type
|
||||
self.double_precision_rope = double_precision_rope
|
||||
self.transformer_1d_blocks = torch.nn.ModuleList(
|
||||
self.transformer_1d_blocks = nn.ModuleList(
|
||||
[
|
||||
_BasicTransformerBlock1D(
|
||||
dim=self.inner_dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
@@ -288,7 +333,7 @@ class Embeddings1DConnector(torch.nn.Module):
|
||||
|
||||
self.num_learnable_registers = num_learnable_registers
|
||||
if self.num_learnable_registers:
|
||||
self.learnable_registers = torch.nn.Parameter(
|
||||
self.learnable_registers = nn.Parameter(
|
||||
torch.rand(self.num_learnable_registers, self.inner_dim, dtype=torch.bfloat16) * 2.0 - 1.0
|
||||
)
|
||||
|
||||
@@ -307,7 +352,7 @@ class Embeddings1DConnector(torch.nn.Module):
|
||||
non_zero_hidden_states = hidden_states[:, attention_mask_binary.squeeze().bool(), :]
|
||||
non_zero_nums = non_zero_hidden_states.shape[1]
|
||||
pad_length = hidden_states.shape[1] - non_zero_nums
|
||||
adjusted_hidden_states = torch.nn.functional.pad(non_zero_hidden_states, pad=(0, 0, 0, pad_length), value=0)
|
||||
adjusted_hidden_states = nn.functional.pad(non_zero_hidden_states, pad=(0, 0, 0, pad_length), value=0)
|
||||
flipped_mask = torch.flip(attention_mask_binary, dims=[1])
|
||||
hidden_states = flipped_mask * adjusted_hidden_states + (1 - flipped_mask) * learnable_registers
|
||||
|
||||
@@ -358,9 +403,147 @@ class Embeddings1DConnector(torch.nn.Module):
|
||||
return hidden_states, attention_mask
|
||||
|
||||
|
||||
class LTX2TextEncoderPostModules(nn.Module):
    """Post-processing stack applied to Gemma text-encoder hidden states.

    Bundles the feature-extractor projection with the video and audio 1D
    embedding connectors. With ``seperated_audio_video=True`` (the LTX-2.3
    path) the two modalities get independently-sized projections and
    connectors; otherwise the default shared modules are used.

    NOTE(review): several parameter names carry typos (``seperated``,
    ``attetion``, ``connetor``); they are preserved because callers may pass
    them by keyword.
    """

    def __init__(
        self,
        seperated_audio_video: bool = False,
        embedding_dim_gemma: int = 3840,
        num_layers_gemma: int = 49,
        video_attetion_heads: int = 32,
        video_attention_head_dim: int = 128,
        audio_attention_heads: int = 32,
        audio_attention_head_dim: int = 64,
        num_connetor_layers: int = 2,
        apply_gated_attention: bool = False,
    ):
        super().__init__()
        if not seperated_audio_video:
            # Legacy path: single shared projection, default connectors.
            self.feature_extractor_linear = GemmaFeaturesExtractorProjLinear()
            self.embeddings_connector = Embeddings1DConnector()
            self.audio_embeddings_connector = Embeddings1DConnector()
        else:
            # LTX-2.3: separate video/audio projections and connectors.
            self.feature_extractor_linear = GemmaSeperatedFeaturesExtractorProjLinear(
                num_layers_gemma,
                embedding_dim_gemma,
                video_attetion_heads * video_attention_head_dim,
                audio_attention_heads * audio_attention_head_dim,
            )
            self.embeddings_connector = Embeddings1DConnector(
                attention_head_dim=video_attention_head_dim,
                num_attention_heads=video_attetion_heads,
                num_layers=num_connetor_layers,
                apply_gated_attention=apply_gated_attention,
            )
            self.audio_embeddings_connector = Embeddings1DConnector(
                attention_head_dim=audio_attention_head_dim,
                num_attention_heads=audio_attention_heads,
                num_layers=num_connetor_layers,
                apply_gated_attention=apply_gated_attention,
            )

    def create_embeddings(
        self,
        video_features: torch.Tensor,
        audio_features: torch.Tensor | None,
        additive_attention_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]:
        """Run both connectors and binarize the video-side mask.

        Returns:
            (video_encoded, audio_encoded, binary_mask). The audio
            connector's own mask is discarded; only the video binary mask is
            propagated.
        """
        video_encoded, video_mask = self.embeddings_connector(video_features, additive_attention_mask)
        video_encoded, binary_mask = _to_binary_mask(video_encoded, video_mask)
        audio_encoded, _ = self.audio_embeddings_connector(audio_features, additive_attention_mask)
        return video_encoded, audio_encoded, binary_mask

    def process_hidden_states(
        self,
        hidden_states: tuple[torch.Tensor, ...],
        attention_mask: torch.Tensor,
        padding_side: str = "left",
    ):
        """Project raw hidden states, then encode them via the connectors.

        Args:
            hidden_states: Per-layer encoder outputs.
            attention_mask: [B, T] binary mask over tokens.
            padding_side: Side on which padding tokens sit.

        Returns:
            (video_encoded, audio_encoded, binary_mask) as produced by
            create_embeddings.
        """
        video_feats, audio_feats = self.feature_extractor_linear(hidden_states, attention_mask, padding_side)
        additive_mask = _convert_to_additive_mask(attention_mask, video_feats.dtype)
        return self.create_embeddings(video_feats, audio_feats, additive_mask)
|
||||
|
||||
|
||||
def _norm_and_concat_padded_batch(
|
||||
encoded_text: torch.Tensor,
|
||||
sequence_lengths: torch.Tensor,
|
||||
padding_side: str = "right",
|
||||
) -> torch.Tensor:
|
||||
"""Normalize and flatten multi-layer hidden states, respecting padding.
|
||||
Performs per-batch, per-layer normalization using masked mean and range,
|
||||
then concatenates across the layer dimension.
|
||||
Args:
|
||||
encoded_text: Hidden states of shape [batch, seq_len, hidden_dim, num_layers].
|
||||
sequence_lengths: Number of valid (non-padded) tokens per batch item.
|
||||
padding_side: Whether padding is on "left" or "right".
|
||||
Returns:
|
||||
Normalized tensor of shape [batch, seq_len, hidden_dim * num_layers],
|
||||
with padded positions zeroed out.
|
||||
"""
|
||||
b, t, d, l = encoded_text.shape # noqa: E741
|
||||
device = encoded_text.device
|
||||
# Build mask: [B, T, 1, 1]
|
||||
token_indices = torch.arange(t, device=device)[None, :] # [1, T]
|
||||
if padding_side == "right":
|
||||
# For right padding, valid tokens are from 0 to sequence_length-1
|
||||
mask = token_indices < sequence_lengths[:, None] # [B, T]
|
||||
elif padding_side == "left":
|
||||
# For left padding, valid tokens are from (T - sequence_length) to T-1
|
||||
start_indices = t - sequence_lengths[:, None] # [B, 1]
|
||||
mask = token_indices >= start_indices # [B, T]
|
||||
else:
|
||||
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
|
||||
mask = rearrange(mask, "b t -> b t 1 1")
|
||||
eps = 1e-6
|
||||
# Compute masked mean: [B, 1, 1, L]
|
||||
masked = encoded_text.masked_fill(~mask, 0.0)
|
||||
denom = (sequence_lengths * d).view(b, 1, 1, 1)
|
||||
mean = masked.sum(dim=(1, 2), keepdim=True) / (denom + eps)
|
||||
# Compute masked min/max: [B, 1, 1, L]
|
||||
x_min = encoded_text.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
|
||||
x_max = encoded_text.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
|
||||
range_ = x_max - x_min
|
||||
# Normalize only the valid tokens
|
||||
normed = 8 * (encoded_text - mean) / (range_ + eps)
|
||||
# concat to be [Batch, T, D * L] - this preserves the original structure
|
||||
normed = normed.reshape(b, t, -1) # [B, T, D * L]
|
||||
# Apply mask to preserve original padding (set padded positions to 0)
|
||||
mask_flattened = rearrange(mask, "b t 1 1 -> b t 1").expand(-1, -1, d * l)
|
||||
normed = normed.masked_fill(~mask_flattened, 0.0)
|
||||
|
||||
return normed
|
||||
|
||||
|
||||
def _convert_to_additive_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
||||
return (attention_mask - 1).to(dtype).reshape(
|
||||
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(dtype).max
|
||||
|
||||
def _to_binary_mask(encoded: torch.Tensor, encoded_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Convert connector output mask to binary mask and apply to encoded tensor."""
|
||||
binary_mask = (encoded_mask < 0.000001).to(torch.int64)
|
||||
binary_mask = binary_mask.reshape([encoded.shape[0], encoded.shape[1], 1])
|
||||
encoded = encoded * binary_mask
|
||||
return encoded, binary_mask
|
||||
|
||||
|
||||
def norm_and_concat_per_token_rms(
    encoded_text: torch.Tensor,
    attention_mask: torch.Tensor,
) -> torch.Tensor:
    """Per-token RMS normalization for V2 models, flattened across layers.

    Each token is RMS-normalized over the hidden dimension independently per
    layer, then the layers are concatenated along the feature axis.

    Args:
        encoded_text: [B, T, D, L] stacked hidden states.
        attention_mask: [B, T] binary mask (1 = keep).

    Returns:
        [B, T, D * L] normalized tensor; padded tokens are zeroed.
    """
    batch, tokens, dim, layers = encoded_text.shape
    # Mean square over the hidden dimension: [B, T, 1, L].
    mean_square = encoded_text.pow(2).mean(dim=2, keepdim=True)
    normalized = encoded_text * torch.rsqrt(mean_square + 1e-6)
    flat = normalized.reshape(batch, tokens, dim * layers)
    keep = attention_mask.bool().unsqueeze(-1)  # [B, T, 1]
    return torch.where(keep, flat, torch.zeros_like(flat))
|
||||
|
||||
|
||||
def _rescale_norm(x: torch.Tensor, target_dim: int, source_dim: int) -> torch.Tensor:
|
||||
"""Rescale normalization: x * sqrt(target_dim / source_dim)."""
|
||||
return x * math.sqrt(target_dim / source_dim)
|
||||
|
||||
Reference in New Issue
Block a user