# Mirror of https://github.com/modelscope/DiffSynth-Studio.git
# (synced 2026-03-19 06:48:12 +00:00; 1305 lines, 57 KiB, Python)
# original code from: comfy/ldm/cosmos/predict2.py
|
|
|
|
import torch
|
|
from torch import nn
|
|
from einops import rearrange, repeat
|
|
from einops.layers.torch import Rearrange
|
|
import logging
|
|
from typing import Callable, Optional, Tuple, List
|
|
import math
|
|
from torchvision import transforms
|
|
from ..core.attention import attention_forward
|
|
from ..core.gradient import gradient_checkpoint_forward
|
|
|
|
|
|
class VideoPositionEmb(nn.Module):
    """Base class for positional embeddings of video tokens laid out as (B, T, H, W, C).

    Subclasses implement ``generate_embeddings``; ``forward`` only extracts the input's
    shape and delegates to it.
    """

    def forward(self, x_B_T_H_W_C: torch.Tensor, fps: Optional[torch.Tensor] = None, device=None, dtype=None) -> torch.Tensor:
        """
        Delegate embedding generation to ``generate_embeddings``.

        Only the shape of ``x_B_T_H_W_C`` is used, never its values.

        Args:
            x_B_T_H_W_C: Input tensor whose shape is forwarded as ``B_T_H_W_C``.
            fps: Optional frames-per-second value(s), forwarded unchanged. Defaults to None.
            device: Forwarded unchanged.
            dtype: Forwarded unchanged.

        Returns:
            Whatever the subclass's ``generate_embeddings`` returns.
        """
        B_T_H_W_C = x_B_T_H_W_C.shape
        return self.generate_embeddings(B_T_H_W_C, fps=fps, device=device, dtype=dtype)

    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps: Optional[torch.Tensor] = None, device=None, dtype=None):
        """Return embeddings for an input of shape ``B_T_H_W_C``; subclasses must override.

        Accepts the same keyword arguments that ``forward`` passes (including ``dtype``),
        so calling the base class raises NotImplementedError rather than a TypeError.
        """
        raise NotImplementedError
|
|
|
|
|
|
def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0) -> torch.Tensor:
|
|
"""
|
|
Normalizes the input tensor along specified dimensions such that the average square norm of elements is adjusted.
|
|
|
|
Args:
|
|
x (torch.Tensor): The input tensor to normalize.
|
|
dim (list, optional): The dimensions over which to normalize. If None, normalizes over all dimensions except the first.
|
|
eps (float, optional): A small constant to ensure numerical stability during division.
|
|
|
|
Returns:
|
|
torch.Tensor: The normalized tensor.
|
|
"""
|
|
if dim is None:
|
|
dim = list(range(1, x.ndim))
|
|
norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
|
|
norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
|
|
return x / norm.to(x.dtype)
|
|
|
|
|
|
class LearnablePosEmbAxis(VideoPositionEmb):
    """
    Learnable absolute positional embedding factorised over the T, H and W axes.

    One learnable table per axis (``pos_emb_t``, ``pos_emb_h``, ``pos_emb_w``, each of
    width ``model_channels``). The embedding of token (t, h, w) is the sum of the three
    corresponding rows, RMS-normalised over the channel dimension.
    """

    def __init__(
        self,
        *,  # enforce keyword arguments
        interpolation: str,
        model_channels: int,
        len_h: int,
        len_w: int,
        len_t: int,
        device=None,
        dtype=None,
        **kwargs,
    ):
        """
        Args:
            interpolation (str): We currently only support "crop"; when extrapolation
                capacity is needed, frequency adjustment or more advanced methods would be
                required. They are not implemented yet.
            model_channels (int): Width of each embedding row.
            len_h (int): Number of rows in the height table (maximum supported H).
            len_w (int): Number of rows in the width table (maximum supported W).
            len_t (int): Number of rows in the temporal table (maximum supported T).
            device: Device for the parameter tables.
            dtype: Dtype for the parameter tables.
            **kwargs: Accepted for compatibility and ignored.

        Note: the tables are allocated with ``torch.empty`` (uninitialised); presumably they
        are always filled from a checkpoint — TODO confirm.
        """
        del kwargs  # unused
        super().__init__()
        self.interpolation = interpolation
        assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"

        self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device, dtype=dtype))
        self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype))
        self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype))

    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps: Optional[torch.Tensor] = None, device=None, dtype=None) -> torch.Tensor:
        """
        Build the (B, T, H, W, model_channels) embedding for an input of shape ``B_T_H_W_C``.

        With "crop", the first T/H/W rows of each table are used, so T, H and W must not
        exceed ``len_t``/``len_h``/``len_w`` (otherwise the shape assertion fires).
        ``fps`` is accepted for interface compatibility and ignored.
        """
        B, T, H, W, _ = B_T_H_W_C
        if self.interpolation == "crop":
            emb_h_H = self.pos_emb_h[:H].to(device=device, dtype=dtype)
            emb_w_W = self.pos_emb_w[:W].to(device=device, dtype=dtype)
            emb_t_T = self.pos_emb_t[:T].to(device=device, dtype=dtype)
            emb = (
                repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W)
                + repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W)
                + repeat(emb_w_W, "w d-> b t h w d", b=B, t=T, h=H)
            )
            assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}"
        else:
            raise ValueError(f"Unknown interpolation method {self.interpolation}")

        return normalize(emb, dim=-1, eps=1e-6)
|
|
|
|
|
|
class VideoRopePosition3DEmb(VideoPositionEmb):
    """
    3D rotary position embedding (RoPE) for video tokens.

    The per-head channel dimension ``head_dim`` is split into a height part and a width
    part (``head_dim // 6 * 2`` channels each) and a temporal part (the remainder). Each
    part has its own frequency band. ``generate_embeddings`` returns one 2x2 rotation
    matrix per token and channel pair.
    """

    def __init__(
        self,
        *,  # enforce keyword arguments
        head_dim: int,
        len_h: int,
        len_w: int,
        len_t: int,
        base_fps: int = 24,
        h_extrapolation_ratio: float = 1.0,
        w_extrapolation_ratio: float = 1.0,
        t_extrapolation_ratio: float = 1.0,
        enable_fps_modulation: bool = True,
        device=None,
        **kwargs,  # used for compatibility with other positional embeddings; unused in this class
    ):
        del kwargs
        super().__init__()
        self.base_fps = base_fps
        self.max_h = len_h
        self.max_w = len_w
        self.enable_fps_modulation = enable_fps_modulation

        dim = head_dim
        dim_h = dim // 6 * 2
        dim_w = dim_h
        dim_t = dim - 2 * dim_h
        assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}"
        # Exponents 0, 2/d, 4/d, ... (d//2 of them) for theta ** (-exponent) frequencies.
        self.register_buffer(
            "dim_spatial_range",
            torch.arange(0, dim_h, 2, device=device)[: (dim_h // 2)].float() / dim_h,
            persistent=False,
        )
        self.register_buffer(
            "dim_temporal_range",
            torch.arange(0, dim_t, 2, device=device)[: (dim_t // 2)].float() / dim_t,
            persistent=False,
        )

        # theta is later multiplied by ratio ** (d / (d - 2)) (NTK-style RoPE scaling).
        self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2))
        self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2))
        self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2))

    def generate_embeddings(
        self,
        B_T_H_W_C: torch.Size,
        fps: Optional[torch.Tensor] = None,
        h_ntk_factor: Optional[float] = None,
        w_ntk_factor: Optional[float] = None,
        t_ntk_factor: Optional[float] = None,
        device=None,
        dtype=None,
    ):
        """
        Generate rotary embeddings for the given input size.

        Args:
            B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels).
            fps (Optional[torch.Tensor | int | float]): Frames per second. When given and
                fps modulation is enabled, temporal positions are scaled by
                ``base_fps / fps``. A tensor must hold a single value, identical values,
                or (with B == 1 or T == 1) per-sample values. Defaults to None.
            h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor.
            w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor.
            t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor.
            device: Device on which the embedding is built.
            dtype: Accepted for interface compatibility; the result is always float32.

        Returns:
            torch.Tensor: float32 tensor of shape ``(T*H*W, D, 2, 2)`` where
            ``D = dim_t//2 + dim_h//2 + dim_w//2``. Tokens are ordered t-major, then h,
            then w. The channel pairs are the temporal ones first, then height, then
            width. Each entry is ``[[cos, -sin], [sin, cos]]`` of that pair's angle.
        """
        h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor
        w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor
        t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor

        h_theta = 10000.0 * h_ntk_factor
        w_theta = 10000.0 * w_ntk_factor
        t_theta = 10000.0 * t_ntk_factor

        h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range.to(device=device))
        w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range.to(device=device))
        temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range.to(device=device))

        B, T, H, W, _ = B_T_H_W_C
        seq = torch.arange(max(H, W, T), dtype=torch.float, device=device)
        uniform_fps = (fps is None) or isinstance(fps, (int, float)) or (fps.min() == fps.max())
        assert (
            uniform_fps or B == 1 or T == 1
        ), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1"
        half_emb_h = torch.outer(seq[:H].to(device=device), h_spatial_freqs)
        half_emb_w = torch.outer(seq[:W].to(device=device), w_spatial_freqs)

        # apply sequence scaling in temporal dimension
        if fps is None or self.enable_fps_modulation is False:  # image case
            half_emb_t = torch.outer(seq[:T].to(device=device), temporal_freqs)
        else:
            if isinstance(fps, torch.Tensor) and (uniform_fps or T == 1):
                # All samples share the same temporal positions here: either every fps
                # value is equal, or T == 1 and the only position is 0. Collapse to a
                # 0-dim value so that a per-sample (B,) tensor broadcasts against the (T,)
                # positions instead of failing (or silently pairing frame t with sample t).
                fps = fps.reshape(-1)[0]
            half_emb_t = torch.outer(seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs)

        # Flattened 2x2 rotation matrices [[cos, -sin], [sin, cos]] per position/frequency.
        half_emb_h = torch.stack([torch.cos(half_emb_h), -torch.sin(half_emb_h), torch.sin(half_emb_h), torch.cos(half_emb_h)], dim=-1)
        half_emb_w = torch.stack([torch.cos(half_emb_w), -torch.sin(half_emb_w), torch.sin(half_emb_w), torch.cos(half_emb_w)], dim=-1)
        half_emb_t = torch.stack([torch.cos(half_emb_t), -torch.sin(half_emb_t), torch.sin(half_emb_t), torch.cos(half_emb_t)], dim=-1)

        em_T_H_W_D = torch.cat(
            [
                repeat(half_emb_t, "t d x -> t h w d x", h=H, w=W),
                repeat(half_emb_h, "h d x -> t h w d x", t=T, w=W),
                repeat(half_emb_w, "w d x -> t h w d x", t=T, h=H),
            ],
            dim=-2,
        )

        return rearrange(em_T_H_W_D, "t h w d (i j) -> (t h w) d i j", i=2, j=2).float()
|
|
|
|
|
|
def apply_rotary_pos_emb(
    t: torch.Tensor,
    freqs: torch.Tensor,
) -> torch.Tensor:
    """
    Rotate channel pairs of ``t`` by the 2x2 matrices stored in ``freqs``.

    The last dimension of ``t`` (size D) is split into two halves. Channel ``i`` of the
    first half is paired with channel ``i`` of the second half, and the pair ``v`` is
    replaced by ``M @ v``, where ``M`` is the matrix in ``freqs``'s last two axes.
    ``freqs[..., 0]`` (shape ``(..., D // 2, 2)``) must broadcast against
    ``(*t.shape[:-1], D // 2, 1)``. The math is done in float32 and the result is cast
    back to ``t``'s dtype.
    """
    orig_shape = t.shape
    # (..., D) -> (..., 2, D/2) -> (..., D/2, 2) -> (..., D/2, 1, 2): one row vector per pair.
    pairs = t.reshape(*orig_shape[:-1], 2, -1)
    pairs = pairs.movedim(-2, -1).unsqueeze(-2).float()
    first, second = pairs[..., 0], pairs[..., 1]
    # Matrix-vector product written out: out[i] = M[i, 0] * v[0] + M[i, 1] * v[1].
    rotated = freqs[..., 0] * first + freqs[..., 1] * second
    return rotated.movedim(-1, -2).reshape(*orig_shape).type_as(t)
|
|
|
|
|
|
# ---------------------- Feed Forward Network -----------------------
|
|
class GPT2FeedForward(nn.Module):
    """Two-layer bias-free MLP: ``d_model -> d_ff -> d_model`` with a GELU in between.

    ``operations`` must provide a ``Linear`` class with the ``nn.Linear`` signature.
    """

    def __init__(self, d_model: int, d_ff: int, device=None, dtype=None, operations=None) -> None:
        super().__init__()
        linear_cls = operations.Linear
        self.activation = nn.GELU()
        self.layer1 = linear_cls(d_model, d_ff, bias=False, device=device, dtype=dtype)
        self.layer2 = linear_cls(d_ff, d_model, bias=False, device=device, dtype=dtype)

        # Bookkeeping attributes; not read anywhere inside this class.
        self._layer_id = None
        self._dim, self._hidden_dim = d_model, d_ff

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``layer2(gelu(layer1(x)))`` over the last dimension of ``x``."""
        hidden = self.activation(self.layer1(x))
        return self.layer2(hidden)
|
|
|
|
|
|
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor, transformer_options: Optional[dict] = None) -> torch.Tensor:
    """Compute multi-head attention by delegating to ``attention_forward``.

    The inputs are laid out as (batch, seq..., heads, head_dim). They are transposed to
    (batch, heads, seq, head_dim), flattening any extra sequence axes, and handed to
    ``attention_forward``. That function requests the output as ``"b s (n d)"``, with the
    heads merged back into the channel axis. Which attention backend runs is decided by
    ``attention_forward``, not here.

    Dimension conventions:

    - B: batch size
    - S: sequence length (one or more axes)
    - H: number of attention heads
    - D: head dimension

    Args:
        q_B_S_H_D: Query tensor with shape (batch, seq_len, n_heads, head_dim)
        k_B_S_H_D: Key tensor with shape (batch, seq_len, n_heads, head_dim)
        v_B_S_H_D: Value tensor with shape (batch, seq_len, n_heads, head_dim)
        transformer_options: Accepted for API compatibility; not used by this function.
            Defaults to None (previously a shared mutable ``{}``).

    Returns:
        Attention output tensor with shape (batch, seq_len, n_heads * head_dim)
    """
    in_q_shape = q_B_S_H_D.shape
    in_k_shape = k_B_S_H_D.shape
    # Move heads to axis 1 and collapse any sequence axes into one.
    q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
    k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
    v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
    return attention_forward(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, out_pattern="b s (n d)")
|
|
|
|
|
|
class Attention(nn.Module):
    """
    Multi-head attention usable as self-attention or cross-attention.

    The mode is fixed at construction time. With ``context_dim=None``, keys and values are
    projected from the query input (self-attention). Otherwise they are projected from a
    separate context tensor of width ``context_dim`` (cross-attention).

    All linear projections are bias-free. Queries and keys are RMS-normalised per head,
    and rotary position embeddings are applied only in self-attention mode. Dropout is
    applied to the projected output when ``dropout > 1e-4``.

    Args:
        query_dim (int): Width of the query input and of the module output.
        context_dim (int, optional): Width of the key/value context. ``None`` selects
            self-attention, in which case ``query_dim`` is used. Default: None
        n_heads (int, optional): Number of attention heads. Default: 8
        head_dim (int, optional): Width of each attention head. Default: 64
        dropout (float, optional): Dropout probability applied to the output. Default: 0.0
        device, dtype: Forwarded to the parameter-holding sub-layers.
        operations: Namespace providing ``Linear`` and ``RMSNorm`` layer classes (required).

    Examples:
        >>> # ops: any namespace exposing Linear and RMSNorm classes
        >>> # Self-attention with 512 dimensions and 8 heads of width 64
        >>> self_attn = Attention(query_dim=512, operations=ops)
        >>> x = torch.randn(32, 16, 512)  # (batch_size, seq_len, dim)
        >>> out = self_attn(x)  # (32, 16, 512)

        >>> # Cross-attention
        >>> cross_attn = Attention(query_dim=512, context_dim=256, operations=ops)
        >>> query = torch.randn(32, 16, 512)
        >>> context = torch.randn(32, 8, 256)
        >>> out = cross_attn(query, context)  # (32, 16, 512)
    """

    def __init__(
        self,
        query_dim: int,
        context_dim: Optional[int] = None,
        n_heads: int = 8,
        head_dim: int = 64,
        dropout: float = 0.0,
        device=None,
        dtype=None,
        operations=None,
    ) -> None:
        super().__init__()
        logging.debug(
            f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
            f"{n_heads} heads with a dimension of {head_dim}."
        )
        self.is_selfattn = context_dim is None  # self attention

        context_dim = query_dim if context_dim is None else context_dim
        inner_dim = head_dim * n_heads

        self.n_heads = n_heads
        self.head_dim = head_dim
        self.query_dim = query_dim
        self.context_dim = context_dim

        self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype)
        self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)

        self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
        self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)

        self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
        self.v_norm = nn.Identity()

        self.output_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype)
        self.output_dropout = nn.Dropout(dropout) if dropout > 1e-4 else nn.Identity()

        self.attn_op = torch_attention_op

        self._query_dim = query_dim
        self._context_dim = context_dim
        self._inner_dim = inner_dim

    def compute_qkv(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        rope_emb: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Project to per-head q, k, v (``b ... h d``), RMS-normalise q and k, and apply the
        rotary embedding to q and k in self-attention mode only.

        Args:
            x: Query input of width ``query_dim``.
            context: Key/value input of width ``context_dim``; ``x`` is used when None.
            rope_emb: Rotation matrices for ``apply_rotary_pos_emb``; ignored for
                cross-attention or when None.
        """
        q = self.q_proj(x)
        context = x if context is None else context
        k = self.k_proj(context)
        v = self.v_proj(context)
        q, k, v = (
            rearrange(t, "b ... (h d) -> b ... h d", h=self.n_heads, d=self.head_dim)
            for t in (q, k, v)
        )

        q = self.q_norm(q)
        k = self.k_norm(k)
        v = self.v_norm(v)
        if self.is_selfattn and rope_emb is not None:  # only apply to self-attention!
            q = apply_rotary_pos_emb(q, rope_emb)
            k = apply_rotary_pos_emb(k, rope_emb)
        return q, k, v

    def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, transformer_options: Optional[dict] = None) -> torch.Tensor:
        """Run ``attn_op`` on per-head q/k/v, then apply the output projection and dropout."""
        if transformer_options is None:
            transformer_options = {}  # fresh dict per call instead of a shared mutable default
        result = self.attn_op(q, k, v, transformer_options=transformer_options)  # [B, S, H * D]
        return self.output_dropout(self.output_proj(result))

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        rope_emb: Optional[torch.Tensor] = None,
        transformer_options: Optional[dict] = None,
    ) -> torch.Tensor:
        """
        Args:
            x (Tensor): The query tensor of shape [B, Mq, K]
            context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
            rope_emb (Optional[Tensor]): Rotary embedding, applied in self-attention mode only.
            transformer_options (Optional[dict]): Passed through to ``attn_op``; None means ``{}``.

        Returns:
            Tensor of shape [B, Mq, query_dim].
        """
        if transformer_options is None:
            transformer_options = {}
        q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
        return self.compute_attention(q, k, v, transformer_options=transformer_options)
|
|
|
|
|
|
class Timesteps(nn.Module):
    """Sinusoidal timestep embedding producing ``num_channels`` features per timestep.

    For ``half = num_channels // 2`` frequencies ``f_k = 10000 ** (-k / half)``, the output
    is ``[cos(t * f_0..f_{half-1}), sin(t * f_0..f_{half-1})]``. For odd
    ``num_channels`` a trailing zero column is appended so that the width always equals
    ``num_channels``.
    """

    def __init__(self, num_channels: int):
        super().__init__()
        self.num_channels = num_channels

    def forward(self, timesteps_B_T: torch.Tensor) -> torch.Tensor:
        """
        Args:
            timesteps_B_T: 2D tensor of timesteps with shape (B, T).

        Returns:
            float32 tensor of shape (B, T, num_channels).
        """
        assert timesteps_B_T.ndim == 2, f"Expected 2D input, got {timesteps_B_T.ndim}"
        B, T = timesteps_B_T.shape
        timesteps = timesteps_B_T.flatten().float()
        half_dim = self.num_channels // 2
        exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
        exponent = exponent / (half_dim - 0.0)

        freqs = torch.exp(exponent)
        args = timesteps[:, None] * freqs[None, :]
        emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.num_channels % 2 == 1:
            # cos/sin halves only cover 2 * (num_channels // 2) columns; pad the last one.
            emb = nn.functional.pad(emb, (0, 1))

        # Rows are in (b t) order from flatten(); explicit width avoids -1 ambiguity when B*T == 0.
        return emb.reshape(B, T, emb.shape[-1])
|
|
|
|
|
|
class TimestepEmbedding(nn.Module):
    """MLP over timestep features: ``linear_1 -> SiLU -> linear_2``.

    Without AdaLN-LoRA, ``linear_1`` has a bias and the MLP output is returned as the
    embedding, with no LoRA tensor. With AdaLN-LoRA, ``linear_1`` is bias-free and
    ``linear_2`` widens to ``3 * out_features``. The input ``sample`` is then returned
    unchanged as the embedding, and the MLP output is returned as the LoRA modulation.
    """

    def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, device=None, dtype=None, operations=None):
        super().__init__()
        logging.debug(
            f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
        )
        self.in_dim = in_features
        self.out_dim = out_features
        self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, device=device, dtype=dtype)
        self.activation = nn.SiLU()
        self.use_adaln_lora = use_adaln_lora
        width_multiplier = 3 if use_adaln_lora else 1
        self.linear_2 = operations.Linear(out_features, width_multiplier * out_features, bias=False, device=device, dtype=dtype)

    def forward(self, sample: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(emb_B_T_D, adaln_lora_B_T_3D)``; the second item is None without AdaLN-LoRA."""
        projected = self.linear_2(self.activation(self.linear_1(sample)))
        if not self.use_adaln_lora:
            return projected, None
        return sample, projected
|
|
|
|
|
|
class PatchEmbed(nn.Module):
    """
    Embed non-overlapping spatio-temporal patches of a video tensor.

    An einops ``Rearrange`` cuts the input ``(B, C, T, H, W)`` into patches of
    ``temporal_patch_size x spatial_patch_size x spatial_patch_size``. Each patch is
    flattened to ``in_channels * temporal_patch_size * spatial_patch_size**2`` values and
    projected to ``out_channels`` by a bias-free linear layer; no convolution is used.
    The output is laid out as ``(B, T', H', W', out_channels)``.

    Parameters:
        - spatial_patch_size (int): Height and width of each patch.
        - temporal_patch_size (int): Number of frames in each patch.
        - in_channels (int): Number of input channels. Default: 3.
        - out_channels (int): The dimension of the embedding vector for each patch. Default: 768.
        - device, dtype: Passed to the projection layer.
        - operations: Namespace providing the ``Linear`` layer class.
    """

    def __init__(
        self,
        spatial_patch_size: int,
        temporal_patch_size: int,
        in_channels: int = 3,
        out_channels: int = 768,
        device=None, dtype=None, operations=None
    ):
        super().__init__()
        self.spatial_patch_size = spatial_patch_size
        self.temporal_patch_size = temporal_patch_size

        patch_dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size
        # Kept as Sequential(Rearrange, Linear) so the weight's state_dict key stays "proj.1.weight".
        self.proj = nn.Sequential(
            Rearrange(
                "b c (t r) (h m) (w n) -> b t h w (c r m n)",
                r=temporal_patch_size,
                m=spatial_patch_size,
                n=spatial_patch_size,
            ),
            operations.Linear(patch_dim, out_channels, bias=False, device=device, dtype=dtype),
        )
        self.dim = patch_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the PatchEmbed module.

        Parameters:
        - x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where
            B is the batch size,
            C is the number of channels,
            T is the temporal dimension,
            H is the height, and
            W is the width of the input.
          H and W must be divisible by ``spatial_patch_size`` and T by
          ``temporal_patch_size``.

        Returns:
        - torch.Tensor: The embedded patches as a tensor, with shape b t h w c.
        """
        assert x.dim() == 5
        _, _, T, H, W = x.shape
        assert (
            H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0
        ), f"H,W {(H, W)} should be divisible by spatial_patch_size {self.spatial_patch_size}"
        assert (
            T % self.temporal_patch_size == 0
        ), f"T {T} should be divisible by temporal_patch_size {self.temporal_patch_size}"
        return self.proj(x)
|
|
|
|
|
|
class FinalLayer(nn.Module):
    """
    The final layer of video DiT.

    It applies an AdaLN-modulated, non-affine LayerNorm, then a bias-free linear
    projection from ``hidden_size`` to
    ``spatial_patch_size**2 * temporal_patch_size * out_channels`` per token. The
    modulation head produces ``(shift, scale)`` from the conditioning embedding. With
    AdaLN-LoRA, the first ``2 * hidden_size`` channels of ``adaln_lora_B_T_3D`` are added
    to it.
    """

    def __init__(
        self,
        hidden_size: int,
        spatial_patch_size: int,
        temporal_patch_size: int,
        out_channels: int,
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
        device=None, dtype=None, operations=None
    ):
        super().__init__()
        # device/dtype forwarded for consistency with the LayerNorms built in Block.
        self.layer_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
        self.linear = operations.Linear(
            hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype
        )
        self.hidden_size = hidden_size
        self.n_adaln_chunks = 2
        self.use_adaln_lora = use_adaln_lora
        self.adaln_lora_dim = adaln_lora_dim
        if use_adaln_lora:
            self.adaln_modulation = nn.Sequential(
                nn.SiLU(),
                operations.Linear(hidden_size, adaln_lora_dim, bias=False, device=device, dtype=dtype),
                operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype),
            )
        else:
            self.adaln_modulation = nn.Sequential(
                nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype)
            )

    def forward(
        self,
        x_B_T_H_W_D: torch.Tensor,
        emb_B_T_D: torch.Tensor,
        adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            x_B_T_H_W_D: Token features of shape (B, T, H, W, hidden_size).
            emb_B_T_D: Conditioning embedding of shape (B, T, hidden_size).
            adaln_lora_B_T_3D: Required when ``use_adaln_lora``; only its first
                ``2 * hidden_size`` channels are used.

        Returns:
            Tensor of shape (B, T, H, W, patch_volume * out_channels).
        """
        if self.use_adaln_lora:
            assert adaln_lora_B_T_3D is not None, "adaln_lora_B_T_3D is required when use_adaln_lora=True"
            modulation_B_T_2D = self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]
        else:
            modulation_B_T_2D = self.adaln_modulation(emb_B_T_D)
        shift_B_T_D, scale_B_T_D = modulation_B_T_2D.chunk(2, dim=-1)

        # (B, T, D) -> (B, T, 1, 1, D) so the modulation broadcasts over H and W.
        shift_B_T_1_1_D = shift_B_T_D[:, :, None, None, :]
        scale_B_T_1_1_D = scale_B_T_D[:, :, None, None, :]

        x_B_T_H_W_D = self.layer_norm(x_B_T_H_W_D) * (1 + scale_B_T_1_1_D) + shift_B_T_1_1_D
        return self.linear(x_B_T_H_W_D)
|
|
|
|
|
|
class Block(nn.Module):
    """
    A transformer block that combines self-attention, cross-attention and MLP layers with AdaLN modulation.
    Each component (self-attention, cross-attention, MLP) has its own non-affine layer
    normalization and its own AdaLN modulation head producing (shift, scale, gate).

    Parameters:
        x_dim (int): Dimension of input features
        context_dim (int): Dimension of context features for cross-attention
        num_heads (int): Number of attention heads (head width is ``x_dim // num_heads``)
        mlp_ratio (float): Multiplier for MLP hidden dimension. Default: 4.0
        use_adaln_lora (bool): Whether to use AdaLN-LoRA modulation. Default: False
        adaln_lora_dim (int): Hidden dimension for AdaLN-LoRA layers. Default: 256

    The block applies the following sequence:
    1. Self-attention with AdaLN modulation
    2. Cross-attention with AdaLN modulation
    3. MLP with AdaLN modulation

    Each component output is gated and added back to the residual stream
    ``x + gate * f(norm(x) * (1 + scale) + shift)``.
    """

    def __init__(
        self,
        x_dim: int,
        context_dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        self.x_dim = x_dim
        self.layer_norm_self_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
        self.self_attn = Attention(x_dim, None, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations)

        self.layer_norm_cross_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
        self.cross_attn = Attention(
            x_dim, context_dim, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations
        )

        self.layer_norm_mlp = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
        self.mlp = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), device=device, dtype=dtype, operations=operations)

        self.use_adaln_lora = use_adaln_lora
        modulation_args = (x_dim, use_adaln_lora, adaln_lora_dim, device, dtype, operations)
        self.adaln_modulation_self_attn = self._make_adaln_modulation(*modulation_args)
        self.adaln_modulation_cross_attn = self._make_adaln_modulation(*modulation_args)
        self.adaln_modulation_mlp = self._make_adaln_modulation(*modulation_args)

    @staticmethod
    def _make_adaln_modulation(x_dim, use_adaln_lora, adaln_lora_dim, device, dtype, operations) -> nn.Sequential:
        """Build SiLU -> bias-free Linear(s) emitting 3 * x_dim values (LoRA-factorised if requested)."""
        if use_adaln_lora:
            return nn.Sequential(
                nn.SiLU(),
                operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
                operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
            )
        return nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))

    def _adaln_params(self, modulation: nn.Module, emb_B_T_D: torch.Tensor, adaln_lora_B_T_3D: Optional[torch.Tensor]):
        """Return (shift, scale, gate), each reshaped from (B, T, D) to (B, T, 1, 1, D) for broadcasting."""
        params_B_T_3D = modulation(emb_B_T_D)
        if self.use_adaln_lora:
            params_B_T_3D = params_B_T_3D + adaln_lora_B_T_3D
        return tuple(p[:, :, None, None, :] for p in params_B_T_3D.chunk(3, dim=-1))

    @staticmethod
    def _modulate(norm_layer: nn.Module, x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
        """AdaLN: ``norm(x) * (1 + scale) + shift``."""
        return norm_layer(x) * (1 + scale) + shift

    def forward(
        self,
        x_B_T_H_W_D: torch.Tensor,
        emb_B_T_D: torch.Tensor,
        crossattn_emb: torch.Tensor,
        rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
        adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
        extra_per_block_pos_emb: Optional[torch.Tensor] = None,
        transformer_options: Optional[dict] = None,
    ) -> torch.Tensor:
        """
        Args:
            x_B_T_H_W_D: Residual stream of shape (B, T, H, W, D).
            emb_B_T_D: Conditioning embedding; its dtype is used as the compute dtype
                for the attention/MLP inputs.
            crossattn_emb: Context tensor for cross-attention.
            rope_emb_L_1_1_D: Rotary embedding (used by self-attention only).
            adaln_lora_B_T_3D: Added to every modulation head when ``use_adaln_lora``.
            extra_per_block_pos_emb: Optional tensor added to the input before anything else.
            transformer_options: Passed to both attention layers; None means ``{}``.

        Returns:
            Updated residual stream, same shape and dtype as ``x_B_T_H_W_D``.
        """
        if transformer_options is None:
            transformer_options = {}  # fresh dict per call instead of a shared mutable default
        residual_dtype = x_B_T_H_W_D.dtype
        compute_dtype = emb_B_T_D.dtype
        if extra_per_block_pos_emb is not None:
            x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb

        shift_sa, scale_sa, gate_sa = self._adaln_params(self.adaln_modulation_self_attn, emb_B_T_D, adaln_lora_B_T_3D)
        shift_ca, scale_ca, gate_ca = self._adaln_params(self.adaln_modulation_cross_attn, emb_B_T_D, adaln_lora_B_T_3D)
        shift_mlp, scale_mlp, gate_mlp = self._adaln_params(self.adaln_modulation_mlp, emb_B_T_D, adaln_lora_B_T_3D)

        _, T, H, W, _ = x_B_T_H_W_D.shape

        # 1. Self-attention over all T*H*W tokens (flattened "b t h w d -> b (t h w) d").
        normalized_x_B_T_H_W_D = self._modulate(self.layer_norm_self_attn, x_B_T_H_W_D, scale_sa, shift_sa)
        result_B_T_H_W_D = self.self_attn(
            normalized_x_B_T_H_W_D.to(compute_dtype).flatten(1, 3),
            None,
            rope_emb=rope_emb_L_1_1_D,
            transformer_options=transformer_options,
        ).unflatten(1, (T, H, W))
        x_B_T_H_W_D = x_B_T_H_W_D + gate_sa.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)

        # 2. Cross-attention to crossattn_emb (rope_emb is passed but ignored by cross-attention).
        normalized_x_B_T_H_W_D = self._modulate(self.layer_norm_cross_attn, x_B_T_H_W_D, scale_ca, shift_ca)
        result_B_T_H_W_D = self.cross_attn(
            normalized_x_B_T_H_W_D.to(compute_dtype).flatten(1, 3),
            crossattn_emb,
            rope_emb=rope_emb_L_1_1_D,
            transformer_options=transformer_options,
        ).unflatten(1, (T, H, W))
        x_B_T_H_W_D = x_B_T_H_W_D + gate_ca.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)

        # 3. Position-wise MLP.
        normalized_x_B_T_H_W_D = self._modulate(self.layer_norm_mlp, x_B_T_H_W_D, scale_mlp, shift_mlp)
        result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype))
        x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
        return x_B_T_H_W_D
|
|
|
|
|
|
class MiniTrainDIT(nn.Module):
|
|
"""
|
|
A clean impl of DIT that can load and reproduce the training results of the original DIT model in~(cosmos 1)
|
|
A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing.
|
|
|
|
Args:
|
|
max_img_h (int): Maximum height of the input images.
|
|
max_img_w (int): Maximum width of the input images.
|
|
max_frames (int): Maximum number of frames in the video sequence.
|
|
in_channels (int): Number of input channels (e.g., RGB channels for color images).
|
|
out_channels (int): Number of output channels.
|
|
patch_spatial (tuple): Spatial resolution of patches for input processing.
|
|
patch_temporal (int): Temporal resolution of patches for input processing.
|
|
concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding.
|
|
model_channels (int): Base number of channels used throughout the model.
|
|
num_blocks (int): Number of transformer blocks.
|
|
num_heads (int): Number of heads in the multi-head attention layers.
|
|
mlp_ratio (float): Expansion ratio for MLP blocks.
|
|
crossattn_emb_channels (int): Number of embedding channels for cross-attention.
|
|
pos_emb_cls (str): Type of positional embeddings.
|
|
pos_emb_learnable (bool): Whether positional embeddings are learnable.
|
|
pos_emb_interpolation (str): Method for interpolating positional embeddings.
|
|
min_fps (int): Minimum frames per second.
|
|
max_fps (int): Maximum frames per second.
|
|
use_adaln_lora (bool): Whether to use AdaLN-LoRA.
|
|
adaln_lora_dim (int): Dimension for AdaLN-LoRA.
|
|
rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE.
|
|
rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE.
|
|
rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE.
|
|
extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings.
|
|
extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings.
|
|
extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings.
|
|
extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings.
|
|
"""
|
|
|
|
def __init__(
    self,
    max_img_h: int,
    max_img_w: int,
    max_frames: int,
    in_channels: int,
    out_channels: int,
    patch_spatial: int,  # tuple,
    patch_temporal: int,
    concat_padding_mask: bool = True,
    # attention settings
    model_channels: int = 768,
    num_blocks: int = 10,
    num_heads: int = 16,
    mlp_ratio: float = 4.0,
    # cross attention settings
    crossattn_emb_channels: int = 1024,
    # positional embedding settings
    pos_emb_cls: str = "sincos",
    pos_emb_learnable: bool = False,
    pos_emb_interpolation: str = "crop",
    min_fps: int = 1,
    max_fps: int = 30,
    use_adaln_lora: bool = False,
    adaln_lora_dim: int = 256,
    rope_h_extrapolation_ratio: float = 1.0,
    rope_w_extrapolation_ratio: float = 1.0,
    rope_t_extrapolation_ratio: float = 1.0,
    extra_per_block_abs_pos_emb: bool = False,
    extra_h_extrapolation_ratio: float = 1.0,
    extra_w_extrapolation_ratio: float = 1.0,
    extra_t_extrapolation_ratio: float = 1.0,
    rope_enable_fps_modulation: bool = True,
    image_model=None,
    device=None,
    dtype=None,
    operations=None,
) -> None:
    """Store the configuration and build every submodule of the DiT.

    Submodules are created in a fixed order (positional embedder, timestep
    embedder, patch embedder, transformer blocks, final layer, timestep norm)
    so that parameter registration and initialisation order stay stable.
    ``image_model`` is accepted but not read here.
    """
    super().__init__()
    self.dtype = dtype

    # Input geometry and core sizes.
    self.max_img_h, self.max_img_w, self.max_frames = max_img_h, max_img_w, max_frames
    self.in_channels, self.out_channels = in_channels, out_channels
    self.patch_spatial, self.patch_temporal = patch_spatial, patch_temporal
    self.num_heads, self.num_blocks = num_heads, num_blocks
    self.model_channels = model_channels
    self.concat_padding_mask = concat_padding_mask

    # Positional-embedding configuration; build_pos_embed() reads these attributes.
    self.pos_emb_cls = pos_emb_cls
    self.pos_emb_learnable = pos_emb_learnable
    self.pos_emb_interpolation = pos_emb_interpolation
    self.min_fps, self.max_fps = min_fps, max_fps
    self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio
    self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio
    self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio
    self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb
    self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio
    self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio
    self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio
    self.rope_enable_fps_modulation = rope_enable_fps_modulation

    self.build_pos_embed(device=device, dtype=dtype)

    self.use_adaln_lora = use_adaln_lora
    self.adaln_lora_dim = adaln_lora_dim

    # Keyword arguments shared by every submodule constructor below.
    factory = dict(device=device, dtype=dtype, operations=operations)

    self.t_embedder = nn.Sequential(
        Timesteps(model_channels),
        TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, **factory),
    )

    # The padding mask, when concatenated, adds one extra input channel.
    embed_in_channels = in_channels + (1 if concat_padding_mask else 0)
    self.x_embedder = PatchEmbed(
        spatial_patch_size=patch_spatial,
        temporal_patch_size=patch_temporal,
        in_channels=embed_in_channels,
        out_channels=model_channels,
        **factory,
    )

    self.blocks = nn.ModuleList([
        Block(
            x_dim=model_channels,
            context_dim=crossattn_emb_channels,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            use_adaln_lora=use_adaln_lora,
            adaln_lora_dim=adaln_lora_dim,
            **factory,
        )
        for _ in range(num_blocks)
    ])

    self.final_layer = FinalLayer(
        hidden_size=self.model_channels,
        spatial_patch_size=self.patch_spatial,
        temporal_patch_size=self.patch_temporal,
        out_channels=self.out_channels,
        use_adaln_lora=self.use_adaln_lora,
        adaln_lora_dim=self.adaln_lora_dim,
        **factory,
    )

    self.t_embedding_norm = operations.RMSNorm(model_channels, eps=1e-6, device=device, dtype=dtype)
|
|
|
|
def build_pos_embed(self, device=None, dtype=None) -> None:
    """Create the positional embedder(s) described by the instance configuration.

    Only ``self.pos_emb_cls == "rope3d"`` is supported; it yields a
    ``VideoRopePosition3DEmb`` stored as ``self.pos_embedder``. When
    ``self.extra_per_block_abs_pos_emb`` is set, a ``LearnablePosEmbAxis`` built
    from the same settings but the ``extra_*`` extrapolation ratios (and
    ``dtype``) is stored as ``self.extra_pos_embedder``.

    Raises:
        ValueError: if ``self.pos_emb_cls`` is anything other than "rope3d".
    """
    if self.pos_emb_cls != "rope3d":
        raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")
    cls_type = VideoRopePosition3DEmb

    logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")

    spatial, temporal = self.patch_spatial, self.patch_temporal
    rope_settings = dict(
        model_channels=self.model_channels,
        len_h=self.max_img_h // spatial,
        len_w=self.max_img_w // spatial,
        len_t=self.max_frames // temporal,
        max_fps=self.max_fps,
        min_fps=self.min_fps,
        is_learnable=self.pos_emb_learnable,
        interpolation=self.pos_emb_interpolation,
        head_dim=self.model_channels // self.num_heads,
        h_extrapolation_ratio=self.rope_h_extrapolation_ratio,
        w_extrapolation_ratio=self.rope_w_extrapolation_ratio,
        t_extrapolation_ratio=self.rope_t_extrapolation_ratio,
        enable_fps_modulation=self.rope_enable_fps_modulation,
        device=device,
    )
    self.pos_embedder = cls_type(**rope_settings)  # type: ignore

    if not self.extra_per_block_abs_pos_emb:
        return

    # Same settings as above, with the "extra" ratios swapped in and dtype added.
    extra_settings = dict(
        rope_settings,
        h_extrapolation_ratio=self.extra_h_extrapolation_ratio,
        w_extrapolation_ratio=self.extra_w_extrapolation_ratio,
        t_extrapolation_ratio=self.extra_t_extrapolation_ratio,
        dtype=dtype,
    )
    self.extra_pos_embedder = LearnablePosEmbAxis(**extra_settings)  # type: ignore
|
|
|
|
def prepare_embedded_sequence(
    self,
    x_B_C_T_H_W: torch.Tensor,
    fps: Optional[torch.Tensor] = None,
    padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Patch-embed a video and compute its positional embeddings.

    Args:
        x_B_C_T_H_W: Input video of shape (B, C, T, H, W).
        fps: Optional frames-per-second tensor, forwarded to the RoPE embedder and
            to the extra per-block embedder.
        padding_mask: Only used when ``self.concat_padding_mask`` is set. It must be
            4-D with a singleton second axis, i.e. (B, 1, H', W'), because it is
            broadcast over T and concatenated as one extra channel. It is resized
            (nearest) to (H, W) and cast to the input's dtype and device. When None,
            an all-zero mask is used.

    Returns:
        A 3-tuple ``(x_B_T_H_W_D, rope_emb, extra_pos_emb)``:
        - the output of ``self.x_embedder``. When ``self.pos_emb_cls`` does not
          contain "rope", the absolute positional embedding has already been
          added to it.
        - the RoPE embedding from ``self.pos_embedder`` when ``self.pos_emb_cls``
          contains "rope" (case-insensitive), otherwise None.
        - the extra per-block positional embedding when
          ``self.extra_per_block_abs_pos_emb`` is set, otherwise None.
    """
    if self.concat_padding_mask:
        if padding_mask is None:
            padding_mask = torch.zeros(
                x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[3], x_B_C_T_H_W.shape[4],
                dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device,
            )
        else:
            padding_mask = transforms.functional.resize(
                padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
            )
            # Match the video so torch.cat neither promotes the dtype (e.g. bf16 video +
            # fp32 mask -> fp32, which then mismatches the embedder weights) nor fails
            # on a device mismatch.
            padding_mask = padding_mask.to(dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device)
        # (B, 1, H, W) -> (B, 1, T, H, W), appended as an extra channel.
        x_B_C_T_H_W = torch.cat(
            [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
        )
    x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)

    if self.extra_per_block_abs_pos_emb:
        extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype)
    else:
        extra_pos_emb = None

    if "rope" in self.pos_emb_cls.lower():
        # RoPE is applied inside attention, so it is returned rather than added here.
        return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb

    x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device)  # [B, T, H, W, D]

    return x_B_T_H_W_D, None, extra_pos_emb
|
|
|
|
def unpatchify(self, x_B_T_H_W_M: torch.Tensor) -> torch.Tensor:
    """Fold per-token patch features back into a (B, C, T*t, H*p, W*p) video.

    The last axis of the (B, T, H, W, M) input is read as (p1, p2, t, C) with
    p1 = p2 = ``self.patch_spatial`` and t = ``self.patch_temporal`` (C varies
    fastest). Each patch axis is then interleaved after its matching grid axis.
    """
    ps, pt = self.patch_spatial, self.patch_temporal
    B, T, H, W, _ = x_B_T_H_W_M.shape
    patches = x_B_T_H_W_M.reshape(B, T, H, W, ps, ps, pt, -1)
    # (B, T, H, W, p1, p2, t, C) -> (B, C, T, t, H, p1, W, p2)
    patches = patches.permute(0, 7, 1, 6, 2, 4, 3, 5)
    channels = patches.shape[1]
    x_B_C_Tt_Hp_Wp = patches.reshape(B, channels, T * pt, H * ps, W * ps)
    return x_B_C_Tt_Hp_Wp
|
|
|
|
def pad_to_patch_size(self, img, patch_size=(2, 2), padding_mode="circular"):
    """Pad the trailing dims of ``img`` up to multiples of ``patch_size``.

    ``patch_size[i]`` applies to ``img`` dim ``i + 2``; dims 0 and 1 are left
    untouched and padding is only added at the end of each padded dim.
    While TorchScript tracing or scripting, "circular" is replaced by "reflect".
    """
    if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
        padding_mode = "reflect"

    amounts = []
    # F.pad takes (before, after) pairs starting from the LAST dimension.
    for dim in reversed(range(img.ndim - 2)):
        patch = patch_size[dim]
        amounts += [0, (patch - img.shape[dim + 2] % patch) % patch]

    return torch.nn.functional.pad(img, tuple(amounts), mode=padding_mode)
|
|
|
|
def forward(
    self,
    x: torch.Tensor,
    timesteps: torch.Tensor,
    context: torch.Tensor,
    fps: Optional[torch.Tensor] = None,
    padding_mask: Optional[torch.Tensor] = None,
    use_gradient_checkpointing=False,
    use_gradient_checkpointing_offload=False,
    **kwargs,
):
    """Run the DiT on a video latent.

    Args:
        x: (B, C, T, H, W) tensor of spatio-temporal inputs. T, H and W are padded
            up to multiples of the patch sizes, and the output is cropped back.
        timesteps: Timestep tensor. A 1-D (B,) tensor gets a trailing singleton
            axis; other inputs are presumably already (B, T).
        context: (B, N, D) tensor of cross-attention embeddings.
        fps: Optional frames-per-second tensor forwarded to the positional embedders.
        padding_mask: Optional padding mask forwarded to ``prepare_embedded_sequence``.
        use_gradient_checkpointing: Forwarded to ``gradient_checkpoint_forward``.
        use_gradient_checkpointing_offload: Forwarded to ``gradient_checkpoint_forward``.
        **kwargs: Only ``transformer_options`` is read; it is passed to every block.

    Returns:
        The unpatchified output of ``self.final_layer``, cropped to the input's
        original T, H and W.
    """
    orig_shape = list(x.shape)
    x_B_C_T_H_W = self.pad_to_patch_size(x, (self.patch_temporal, self.patch_spatial, self.patch_spatial))
    crossattn_emb = context

    x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence(
        x_B_C_T_H_W,
        fps=fps,
        padding_mask=padding_mask,
    )

    timesteps_B_T = timesteps.unsqueeze(1) if timesteps.ndim == 1 else timesteps
    t_embedding_B_T_D, adaln_lora_B_T_3D = self.t_embedder[1](self.t_embedder[0](timesteps_B_T).to(x_B_T_H_W_D.dtype))
    t_embedding_B_T_D = self.t_embedding_norm(t_embedding_B_T_D)

    # Stored on the module for logging / inspection. The attribute names
    # (including the "affline" spelling) are kept because other code may read them.
    self.affline_scale_log_info = {"t_embedding_B_T_D": t_embedding_B_T_D.detach()}
    self.affline_emb = t_embedding_B_T_D
    self.crossattn_emb = crossattn_emb

    if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
        assert (
            x_B_T_H_W_D.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
        ), f"{x_B_T_H_W_D.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape}"

    # Guard: prepare_embedded_sequence returns None for the RoPE slot when the
    # positional embedding is not RoPE-based.
    if rope_emb_L_1_1_D is not None:
        rope_emb_L_1_1_D = rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0)

    block_kwargs = {
        "rope_emb_L_1_1_D": rope_emb_L_1_1_D,
        "adaln_lora_B_T_3D": adaln_lora_B_T_3D,
        "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
        "transformer_options": kwargs.get("transformer_options", {}),
    }

    # The residual stream for this model has large values. To make fp16 compute_dtype work, we keep the residual stream
    # in fp32, but run attention and MLP modules in fp16.
    # An alternate method that clamps fp16 values "works" in the sense that it makes coherent images, but there is noticeable
    # quality degradation and visual artifacts.
    if x_B_T_H_W_D.dtype == torch.float16:
        x_B_T_H_W_D = x_B_T_H_W_D.float()

    for block in self.blocks:
        x_B_T_H_W_D = gradient_checkpoint_forward(
            block,
            use_gradient_checkpointing=use_gradient_checkpointing,
            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
            x_B_T_H_W_D=x_B_T_H_W_D,
            emb_B_T_D=t_embedding_B_T_D,
            crossattn_emb=crossattn_emb,
            **block_kwargs,
        )

    x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D.to(crossattn_emb.dtype), t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
    # Undo the padding added by pad_to_patch_size.
    x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)[:, :, :orig_shape[-3], :orig_shape[-2], :orig_shape[-1]]
    return x_B_C_Tt_Hp_Wp
|
|
|
|
|
|
def rotate_half(x):
    """Map the last axis ``[a, b]`` to ``[-b, a]``.

    The first half holds ``size // 2`` elements, so for an odd size it is the
    shorter one.
    """
    size = x.shape[-1]
    front, back = x.split([size // 2, size - size // 2], dim=-1)
    return torch.cat((back.neg(), front), dim=-1)
|
|
|
|
|
|
def apply_rotary_pos_emb2(x, cos, sin, unsqueeze_dim=1):
    """Apply rotary position embedding: ``x * cos + rotate_half(x) * sin``.

    ``cos`` and ``sin`` get a singleton axis inserted at ``unsqueeze_dim`` so they
    broadcast against ``x`` (in LLMAdapterAttention this is the head axis).
    """
    cos_b = torch.unsqueeze(cos, unsqueeze_dim)
    sin_b = torch.unsqueeze(sin, unsqueeze_dim)
    rotated = rotate_half(x)
    return x * cos_b + rotated * sin_b
|
|
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Produce RoPE cos/sin tables for integer positions.

    Args:
        head_dim: Per-head feature size. Inverse frequencies are built for every
            second channel (``arange(0, head_dim, 2) / head_dim``) and duplicated
            in ``forward``.
        rope_theta: RoPE base. Defaults to 10000, the value previously hard-coded,
            so existing callers are unaffected.
    """

    def __init__(self, head_dim, rope_theta=10000):
        super().__init__()
        self.rope_theta = rope_theta
        inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.int64).to(dtype=torch.float) / head_dim))
        # Non-persistent: derived from the config, so it is not stored in state_dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for ``position_ids`` of shape (B, L).

        Each table has shape (B, L, 2 * len(inv_freq)). ``x`` is only used for its
        device and dtype: the tables are computed in float32 with autocast
        disabled, then cast to ``x.dtype``.
        """
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            # (B, F, 1) @ (B, 1, L) -> (B, F, L) -> (B, L, F)
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
|
|
|
|
class LLMAdapterAttention(nn.Module):
    """Multi-head attention with per-head normalised queries/keys and optional RoPE.

    When ``context`` is omitted in ``forward`` this is self-attention; otherwise
    queries come from ``x`` and keys/values from ``context``.

    Args:
        query_dim: Feature size of ``x`` (and of the output).
        context_dim: Feature size of ``context``.
        n_heads: Number of heads.
        head_dim: Per-head size; projections use ``n_heads * head_dim`` features.
        device, dtype: Forwarded to the layer constructors.
        operations: Namespace providing ``Linear`` and ``RMSNorm``.
    """

    def __init__(self, query_dim, context_dim, n_heads, head_dim, device=None, dtype=None, operations=None):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.query_dim = query_dim
        self.context_dim = context_dim

        inner_dim = head_dim * n_heads
        factory = {"device": device, "dtype": dtype}
        # Creation order (q, k, v, o) is kept so parameter registration is unchanged.
        self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, **factory)
        self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, **factory)
        self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, **factory)
        self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, **factory)
        self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, **factory)
        self.o_proj = operations.Linear(inner_dim, query_dim, bias=False, **factory)

    def forward(self, x, mask=None, context=None, position_embeddings=None, position_embeddings_context=None):
        """Attend from ``x`` to ``context`` (or to ``x`` itself).

        ``mask`` is passed as ``attn_mask`` to ``scaled_dot_product_attention``.
        ``position_embeddings`` / ``position_embeddings_context`` are ``(cos, sin)``
        pairs applied to queries / keys. The second must be given whenever the
        first is.
        """
        source = x if context is None else context

        def heads(t, leading):
            # (..., n_heads * head_dim) -> (..., n_heads, head_dim)
            return t.view(*leading, self.n_heads, self.head_dim)

        q_leading = x.shape[:-1]
        kv_leading = source.shape[:-1]
        q = self.q_norm(heads(self.q_proj(x), q_leading)).transpose(1, 2)
        k = self.k_norm(heads(self.k_proj(source), kv_leading)).transpose(1, 2)
        v = heads(self.v_proj(source), kv_leading).transpose(1, 2)

        if position_embeddings is not None:
            assert position_embeddings_context is not None
            q_cos, q_sin = position_embeddings
            q = apply_rotary_pos_emb2(q, q_cos, q_sin)
            k_cos, k_sin = position_embeddings_context
            k = apply_rotary_pos_emb2(k, k_cos, k_sin)

        attended = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        merged = attended.transpose(1, 2).reshape(*q_leading, -1).contiguous()
        return self.o_proj(merged)

    def init_weights(self):
        """Zero the output projection so the layer initially outputs zeros."""
        torch.nn.init.zeros_(self.o_proj.weight)
|
|
|
|
|
|
class LLMAdapterTransformerBlock(nn.Module):
    """Pre-norm block: optional self-attention, cross-attention, then an MLP, each residual.

    Norms are ``operations.LayerNorm`` when ``layer_norm`` is true, otherwise
    ``operations.RMSNorm`` with ``eps=1e-6``. Attention heads have size
    ``model_dim // num_heads``; the MLP hidden size is ``int(model_dim * mlp_ratio)``.
    """

    def __init__(self, source_dim, model_dim, num_heads=16, mlp_ratio=4.0, use_self_attn=False, layer_norm=False, device=None, dtype=None, operations=None):
        super().__init__()
        self.use_self_attn = use_self_attn
        head_dim = model_dim // num_heads
        hidden_dim = int(model_dim * mlp_ratio)

        def make_norm():
            if layer_norm:
                return operations.LayerNorm(model_dim, device=device, dtype=dtype)
            return operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)

        def make_attention(context_dim):
            return LLMAdapterAttention(
                query_dim=model_dim,
                context_dim=context_dim,
                n_heads=num_heads,
                head_dim=head_dim,
                device=device,
                dtype=dtype,
                operations=operations,
            )

        # Creation order matches the original layout: self-attn, cross-attn, MLP.
        if self.use_self_attn:
            self.norm_self_attn = make_norm()
            self.self_attn = make_attention(model_dim)

        self.norm_cross_attn = make_norm()
        self.cross_attn = make_attention(source_dim)

        self.norm_mlp = make_norm()
        self.mlp = nn.Sequential(
            operations.Linear(model_dim, hidden_dim, device=device, dtype=dtype),
            nn.GELU(),
            operations.Linear(hidden_dim, model_dim, device=device, dtype=dtype),
        )

    def forward(self, x, context, target_attention_mask=None, source_attention_mask=None, position_embeddings=None, position_embeddings_context=None):
        """Update ``x`` with self-attention (optional), cross-attention to ``context``, and the MLP."""
        if self.use_self_attn:
            # Self-attention: queries and keys share the target-side RoPE tables.
            x = x + self.self_attn(
                self.norm_self_attn(x),
                mask=target_attention_mask,
                position_embeddings=position_embeddings,
                position_embeddings_context=position_embeddings,
            )

        x = x + self.cross_attn(
            self.norm_cross_attn(x),
            mask=source_attention_mask,
            context=context,
            position_embeddings=position_embeddings,
            position_embeddings_context=position_embeddings_context,
        )

        return x + self.mlp(self.norm_mlp(x))

    def init_weights(self):
        """Zero the second MLP linear's weight and the cross-attention output projection.

        The self-attention branch (if any) and the MLP bias are left untouched.
        """
        torch.nn.init.zeros_(self.mlp[2].weight)
        self.cross_attn.init_weights()
|
|
|
|
|
|
class LLMAdapter(nn.Module):
    """Turn token ids into embeddings refined by cross-attending to source hidden states.

    Pipeline: embed ``target_input_ids`` -> ``in_proj`` to ``model_dim`` (identity
    when ``model_dim == target_dim``) -> ``num_layers`` LLMAdapterTransformerBlocks
    that cross-attend to ``source_hidden_states`` with RoPE on both sides ->
    ``out_proj`` back to ``target_dim`` -> RMSNorm.

    Args:
        source_dim: Feature size of ``source_hidden_states``.
        target_dim: Token-embedding and output feature size.
        model_dim: Internal width of the transformer blocks.
        num_layers: Number of transformer blocks.
        num_heads: Attention heads per block (head size ``model_dim // num_heads``).
        use_self_attn: Whether blocks also self-attend over the target tokens.
        layer_norm: Use LayerNorm instead of RMSNorm inside the blocks.
        device, dtype: Forwarded to the layer constructors.
        operations: Namespace providing ``Embedding``, ``Linear``, ``RMSNorm``
            (and ``LayerNorm`` when ``layer_norm`` is set).
        vocab_size: Rows of the token-embedding table. Defaults to 32128, the
            previously hard-coded value.
        mlp_ratio: MLP expansion ratio of each block. Defaults to 4.0, the value
            the blocks used implicitly before.
    """

    def __init__(
        self,
        source_dim=1024,
        target_dim=1024,
        model_dim=1024,
        num_layers=6,
        num_heads=16,
        use_self_attn=True,
        layer_norm=False,
        device=None,
        dtype=None,
        operations=None,
        vocab_size=32128,
        mlp_ratio=4.0,
    ):
        super().__init__()

        self.embed = operations.Embedding(vocab_size, target_dim, device=device, dtype=dtype)
        if model_dim != target_dim:
            self.in_proj = operations.Linear(target_dim, model_dim, device=device, dtype=dtype)
        else:
            self.in_proj = nn.Identity()
        self.rotary_emb = RotaryEmbedding(model_dim // num_heads)
        self.blocks = nn.ModuleList([
            LLMAdapterTransformerBlock(
                source_dim,
                model_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                use_self_attn=use_self_attn,
                layer_norm=layer_norm,
                device=device,
                dtype=dtype,
                operations=operations,
            )
            for _ in range(num_layers)
        ])
        self.out_proj = operations.Linear(model_dim, target_dim, device=device, dtype=dtype)
        self.norm = operations.RMSNorm(target_dim, eps=1e-6, device=device, dtype=dtype)

    @staticmethod
    def _as_attention_mask(mask):
        """Cast a mask to bool and turn 2-D (B, L) masks into (B, 1, 1, L) for broadcasting."""
        if mask is None:
            return None
        mask = mask.to(torch.bool)
        if mask.ndim == 2:
            mask = mask.unsqueeze(1).unsqueeze(1)
        return mask

    def forward(self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None):
        """Adapt ``target_input_ids`` (B, L) using ``source_hidden_states`` (B, S, source_dim).

        Masks are boolean ``attn_mask`` values for scaled_dot_product_attention
        (2-D masks are expanded over heads and queries). Returns (B, L, target_dim).
        """
        target_attention_mask = self._as_attention_mask(target_attention_mask)
        source_attention_mask = self._as_attention_mask(source_attention_mask)

        context = source_hidden_states
        x = self.in_proj(self.embed(target_input_ids).to(context.dtype))

        # Independent RoPE positions for the target tokens and for the source sequence.
        position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
        position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
        position_embeddings = self.rotary_emb(x, position_ids)
        position_embeddings_context = self.rotary_emb(x, position_ids_context)

        for block in self.blocks:
            x = block(
                x,
                context,
                target_attention_mask=target_attention_mask,
                source_attention_mask=source_attention_mask,
                position_embeddings=position_embeddings,
                position_embeddings_context=position_embeddings_context,
            )
        return self.norm(self.out_proj(x))
|
|
|
|
|
|
class AnimaDiT(MiniTrainDIT):
    """MiniTrainDIT with the fixed Anima configuration plus an LLMAdapter for text conditioning."""

    def __init__(self):
        config = {
            'image_model': 'anima',
            'max_img_h': 240,
            'max_img_w': 240,
            'max_frames': 128,
            'in_channels': 16,
            'out_channels': 16,
            'patch_spatial': 2,
            'patch_temporal': 1,
            'model_channels': 2048,
            'concat_padding_mask': True,
            'crossattn_emb_channels': 1024,
            'pos_emb_cls': 'rope3d',
            'pos_emb_learnable': True,
            'pos_emb_interpolation': 'crop',
            'min_fps': 1,
            'max_fps': 30,
            'use_adaln_lora': True,
            'adaln_lora_dim': 256,
            'num_blocks': 28,
            'num_heads': 16,
            'extra_per_block_abs_pos_emb': False,
            'rope_h_extrapolation_ratio': 4.0,
            'rope_w_extrapolation_ratio': 4.0,
            'rope_t_extrapolation_ratio': 1.0,
            'extra_h_extrapolation_ratio': 1.0,
            'extra_w_extrapolation_ratio': 1.0,
            'extra_t_extrapolation_ratio': 1.0,
            'rope_enable_fps_modulation': False,
            'dtype': torch.bfloat16,
            'device': None,
            'operations': torch.nn,
        }
        super().__init__(**config)
        # Built after the DiT so the parent's submodules are registered first.
        self.llm_adapter = LLMAdapter(
            device=config.get("device"),
            dtype=config.get("dtype"),
            operations=config.get("operations"),
        )

    def preprocess_text_embeds(self, text_embeds, text_ids, t5xxl_weights=None):
        """Run the LLM adapter when token ids are given; otherwise return ``text_embeds`` unchanged.

        The adapter output is optionally multiplied by ``t5xxl_weights`` and
        zero-padded along axis 1 to at least 512 positions.
        """
        if text_ids is None:
            return text_embeds

        adapted = self.llm_adapter(text_embeds, text_ids)
        if t5xxl_weights is not None:
            adapted = adapted * t5xxl_weights

        shortfall = 512 - adapted.shape[1]
        if shortfall > 0:
            adapted = torch.nn.functional.pad(adapted, (0, 0, 0, shortfall))
        return adapted

    def forward(
        self,
        x, timesteps, context,
        use_gradient_checkpointing=False,
        use_gradient_checkpointing_offload=False,
        **kwargs
    ):
        """Optionally adapt ``context`` with ``t5xxl_ids`` / ``t5xxl_weights``, then run the base DiT."""
        token_ids = kwargs.pop("t5xxl_ids", None)
        if token_ids is not None:
            token_weights = kwargs.pop("t5xxl_weights", None)
            context = self.preprocess_text_embeds(context, token_ids, t5xxl_weights=token_weights)

        return super().forward(
            x, timesteps, context,
            use_gradient_checkpointing=use_gradient_checkpointing,
            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
            **kwargs
        )
|