mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
hunyuanvideo pipeline
This commit is contained in:
@@ -3,6 +3,193 @@ from .sd3_dit import TimestepEmbeddings, RMSNorm
|
||||
from .utils import init_weights_on_device
|
||||
from einops import rearrange, repeat
|
||||
from tqdm import tqdm
|
||||
from typing import Union, Tuple, List
|
||||
|
||||
|
||||
def HunyuanVideoRope(latents):
|
||||
def _to_tuple(x, dim=2):
|
||||
if isinstance(x, int):
|
||||
return (x,) * dim
|
||||
elif len(x) == dim:
|
||||
return x
|
||||
else:
|
||||
raise ValueError(f"Expected length {dim} or int, but got {x}")
|
||||
|
||||
|
||||
def get_meshgrid_nd(start, *args, dim=2):
|
||||
"""
|
||||
Get n-D meshgrid with start, stop and num.
|
||||
|
||||
Args:
|
||||
start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
|
||||
step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
|
||||
should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
|
||||
n-tuples.
|
||||
*args: See above.
|
||||
dim (int): Dimension of the meshgrid. Defaults to 2.
|
||||
|
||||
Returns:
|
||||
grid (np.ndarray): [dim, ...]
|
||||
"""
|
||||
if len(args) == 0:
|
||||
# start is grid_size
|
||||
num = _to_tuple(start, dim=dim)
|
||||
start = (0,) * dim
|
||||
stop = num
|
||||
elif len(args) == 1:
|
||||
# start is start, args[0] is stop, step is 1
|
||||
start = _to_tuple(start, dim=dim)
|
||||
stop = _to_tuple(args[0], dim=dim)
|
||||
num = [stop[i] - start[i] for i in range(dim)]
|
||||
elif len(args) == 2:
|
||||
# start is start, args[0] is stop, args[1] is num
|
||||
start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
|
||||
stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
|
||||
num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
|
||||
else:
|
||||
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
|
||||
|
||||
# PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
|
||||
axis_grid = []
|
||||
for i in range(dim):
|
||||
a, b, n = start[i], stop[i], num[i]
|
||||
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
|
||||
axis_grid.append(g)
|
||||
grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
|
||||
grid = torch.stack(grid, dim=0) # [dim, W, H, D]
|
||||
|
||||
return grid
|
||||
|
||||
|
||||
def get_1d_rotary_pos_embed(
|
||||
dim: int,
|
||||
pos: Union[torch.FloatTensor, int],
|
||||
theta: float = 10000.0,
|
||||
use_real: bool = False,
|
||||
theta_rescale_factor: float = 1.0,
|
||||
interpolation_factor: float = 1.0,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Precompute the frequency tensor for complex exponential (cis) with given dimensions.
|
||||
(Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
|
||||
|
||||
This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
|
||||
and the end index 'end'. The 'theta' parameter scales the frequencies.
|
||||
The returned tensor contains complex values in complex64 data type.
|
||||
|
||||
Args:
|
||||
dim (int): Dimension of the frequency tensor.
|
||||
pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
|
||||
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
|
||||
use_real (bool, optional): If True, return real part and imaginary part separately.
|
||||
Otherwise, return complex numbers.
|
||||
theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
|
||||
|
||||
Returns:
|
||||
freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
|
||||
freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
|
||||
"""
|
||||
if isinstance(pos, int):
|
||||
pos = torch.arange(pos).float()
|
||||
|
||||
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
||||
# has some connection to NTK literature
|
||||
if theta_rescale_factor != 1.0:
|
||||
theta *= theta_rescale_factor ** (dim / (dim - 2))
|
||||
|
||||
freqs = 1.0 / (
|
||||
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
|
||||
) # [D/2]
|
||||
# assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
|
||||
freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
|
||||
if use_real:
|
||||
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
|
||||
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
|
||||
return freqs_cos, freqs_sin
|
||||
else:
|
||||
freqs_cis = torch.polar(
|
||||
torch.ones_like(freqs), freqs
|
||||
) # complex64 # [S, D/2]
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def get_nd_rotary_pos_embed(
|
||||
rope_dim_list,
|
||||
start,
|
||||
*args,
|
||||
theta=10000.0,
|
||||
use_real=False,
|
||||
theta_rescale_factor: Union[float, List[float]] = 1.0,
|
||||
interpolation_factor: Union[float, List[float]] = 1.0,
|
||||
):
|
||||
"""
|
||||
This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
|
||||
|
||||
Args:
|
||||
rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
|
||||
sum(rope_dim_list) should equal to head_dim of attention layer.
|
||||
start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
|
||||
args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
|
||||
*args: See above.
|
||||
theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
|
||||
use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
||||
Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real
|
||||
part and an imaginary part separately.
|
||||
theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
|
||||
|
||||
Returns:
|
||||
pos_embed (torch.Tensor): [HW, D/2]
|
||||
"""
|
||||
|
||||
grid = get_meshgrid_nd(
|
||||
start, *args, dim=len(rope_dim_list)
|
||||
) # [3, W, H, D] / [2, W, H]
|
||||
|
||||
if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
|
||||
theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
|
||||
elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
|
||||
theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
|
||||
assert len(theta_rescale_factor) == len(
|
||||
rope_dim_list
|
||||
), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
|
||||
|
||||
if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
|
||||
interpolation_factor = [interpolation_factor] * len(rope_dim_list)
|
||||
elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
|
||||
interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
|
||||
assert len(interpolation_factor) == len(
|
||||
rope_dim_list
|
||||
), "len(interpolation_factor) should equal to len(rope_dim_list)"
|
||||
|
||||
# use 1/ndim of dimensions to encode grid_axis
|
||||
embs = []
|
||||
for i in range(len(rope_dim_list)):
|
||||
emb = get_1d_rotary_pos_embed(
|
||||
rope_dim_list[i],
|
||||
grid[i].reshape(-1),
|
||||
theta,
|
||||
use_real=use_real,
|
||||
theta_rescale_factor=theta_rescale_factor[i],
|
||||
interpolation_factor=interpolation_factor[i],
|
||||
) # 2 x [WHD, rope_dim_list[i]]
|
||||
embs.append(emb)
|
||||
|
||||
if use_real:
|
||||
cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
|
||||
sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
|
||||
return cos, sin
|
||||
else:
|
||||
emb = torch.cat(embs, dim=1) # (WHD, D/2)
|
||||
return emb
|
||||
|
||||
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
|
||||
[16, 56, 56],
|
||||
[latents.shape[2], latents.shape[3] // 2, latents.shape[4] // 2],
|
||||
theta=256,
|
||||
use_real=True,
|
||||
theta_rescale_factor=1,
|
||||
)
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
class PatchEmbed(torch.nn.Module):
|
||||
@@ -406,13 +593,16 @@ class HunyuanVideoDiT(torch.nn.Module):
|
||||
model.to(device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def prepare_freqs(self, latents):
|
||||
return HunyuanVideoRope(latents)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
text_states: torch.Tensor = None,
|
||||
prompt_emb: torch.Tensor = None,
|
||||
text_mask: torch.Tensor = None,
|
||||
text_states_2: torch.Tensor = None,
|
||||
pooled_prompt_emb: torch.Tensor = None,
|
||||
freqs_cos: torch.Tensor = None,
|
||||
freqs_sin: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
@@ -420,9 +610,9 @@ class HunyuanVideoDiT(torch.nn.Module):
|
||||
):
|
||||
B, C, T, H, W = x.shape
|
||||
|
||||
vec = self.time_in(t, dtype=torch.float32) + self.vector_in(text_states_2) + self.guidance_in(guidance, dtype=torch.float32)
|
||||
vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb) + self.guidance_in(guidance * 1000, dtype=torch.float32)
|
||||
img = self.img_in(x)
|
||||
txt = self.txt_in(text_states, t, text_mask)
|
||||
txt = self.txt_in(prompt_emb, t, text_mask)
|
||||
|
||||
for block in tqdm(self.double_blocks, desc="Double stream blocks"):
|
||||
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))
|
||||
|
||||
@@ -2,15 +2,17 @@ import torch
|
||||
from transformers import T5EncoderModel, T5Config
|
||||
from .sd_text_encoder import SDTextEncoder
|
||||
from .sdxl_text_encoder import SDXLTextEncoder2, SDXLTextEncoder2StateDictConverter
|
||||
|
||||
|
||||
|
||||
class SD3TextEncoder1(SDTextEncoder):
|
||||
def __init__(self, vocab_size=49408):
|
||||
super().__init__(vocab_size=vocab_size)
|
||||
|
||||
def forward(self, input_ids, clip_skip=2):
|
||||
def forward(self, input_ids, clip_skip=2, extra_mask=None):
|
||||
embeds = self.token_embedding(input_ids) + self.position_embeds
|
||||
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
|
||||
if extra_mask is not None:
|
||||
attn_mask[:, extra_mask[0]==0] = float("-inf")
|
||||
for encoder_id, encoder in enumerate(self.encoders):
|
||||
embeds = encoder(embeds, attn_mask=attn_mask)
|
||||
if encoder_id + clip_skip == len(self.encoders):
|
||||
|
||||
Reference in New Issue
Block a user