Merge pull request #290 from modelscope/dev

Dev
This commit is contained in:
Zhongjie Duan
2024-12-19 13:16:55 +08:00
committed by GitHub
27 changed files with 1353673 additions and 13 deletions

View File

@@ -43,9 +43,13 @@ from ..models.cog_dit import CogDiT
from ..models.omnigen import OmniGenTransformer
from ..models.hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
from ..models.hunyuan_video_vae_encoder import HunyuanVideoVAEEncoder
from ..extensions.RIFE import IFNet
from ..extensions.ESRGAN import RRDBNet
from ..models.hunyuan_video_dit import HunyuanVideoDiT
model_loader_configs = [
@@ -93,6 +97,10 @@ model_loader_configs = [
(None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
(None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
(None, "77ff18050dbc23f50382e45d51a779fe", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
(None, "5da81baee73198a7c19e6d2fe8b5148e", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
(None, "aeb82dce778a03dcb4d726cb03f3c43f", ["hunyuan_video_vae_decoder", "hunyuan_video_vae_encoder"], [HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder], "diffusers"),
(None, "b9588f02e78f5ccafc9d7c0294e46308", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
(None, "84ef4bd4757f60e906b54aa6a7815dc6", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
]
huggingface_model_loader_configs = [
# These configs are provided for detecting model type automatically.
@@ -101,10 +109,11 @@ huggingface_model_loader_configs = [
("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
("Qwen2ForCausalLM", "transformers.models.qwen2.modeling_qwen2", "qwen_prompt", None),
("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
# ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel")
("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder")
]
patch_model_loader_configs = [
# These configs are provided for detecting model type automatically.
@@ -627,6 +636,25 @@ preset_models_on_modelscope = {
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
],
"HunyuanVideo":{
"file_list": [
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideo/transformers")
],
"load_path": [
"models/HunyuanVideo/text_encoder/model.safetensors",
"models/HunyuanVideo/text_encoder_2",
"models/HunyuanVideo/vae/pytorch_model.pt",
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
],
},
}
Preset_model_id: TypeAlias = Literal[
"HunyuanDiT",
@@ -682,4 +710,5 @@ Preset_model_id: TypeAlias = Literal[
"Annotators:Openpose",
"StableDiffusion3.5-large",
"StableDiffusion3.5-medium",
"HunyuanVideo",
]

View File

@@ -0,0 +1,885 @@
import torch
from .sd3_dit import TimestepEmbeddings, RMSNorm
from .utils import init_weights_on_device
from einops import rearrange, repeat
from tqdm import tqdm
from typing import Union, Tuple, List
def HunyuanVideoRope(latents):
def _to_tuple(x, dim=2):
if isinstance(x, int):
return (x,) * dim
elif len(x) == dim:
return x
else:
raise ValueError(f"Expected length {dim} or int, but got {x}")
def get_meshgrid_nd(start, *args, dim=2):
"""
Get n-D meshgrid with start, stop and num.
Args:
start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
n-tuples.
*args: See above.
dim (int): Dimension of the meshgrid. Defaults to 2.
Returns:
grid (np.ndarray): [dim, ...]
"""
if len(args) == 0:
# start is grid_size
num = _to_tuple(start, dim=dim)
start = (0,) * dim
stop = num
elif len(args) == 1:
# start is start, args[0] is stop, step is 1
start = _to_tuple(start, dim=dim)
stop = _to_tuple(args[0], dim=dim)
num = [stop[i] - start[i] for i in range(dim)]
elif len(args) == 2:
# start is start, args[0] is stop, args[1] is num
start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
else:
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
# PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
axis_grid = []
for i in range(dim):
a, b, n = start[i], stop[i], num[i]
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
axis_grid.append(g)
grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
grid = torch.stack(grid, dim=0) # [dim, W, H, D]
return grid
def get_1d_rotary_pos_embed(
dim: int,
pos: Union[torch.FloatTensor, int],
theta: float = 10000.0,
use_real: bool = False,
theta_rescale_factor: float = 1.0,
interpolation_factor: float = 1.0,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Precompute the frequency tensor for complex exponential (cis) with given dimensions.
(Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
and the end index 'end'. The 'theta' parameter scales the frequencies.
The returned tensor contains complex values in complex64 data type.
Args:
dim (int): Dimension of the frequency tensor.
pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
use_real (bool, optional): If True, return real part and imaginary part separately.
Otherwise, return complex numbers.
theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
Returns:
freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
"""
if isinstance(pos, int):
pos = torch.arange(pos).float()
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
if theta_rescale_factor != 1.0:
theta *= theta_rescale_factor ** (dim / (dim - 2))
freqs = 1.0 / (
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
) # [D/2]
# assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
if use_real:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
return freqs_cos, freqs_sin
else:
freqs_cis = torch.polar(
torch.ones_like(freqs), freqs
) # complex64 # [S, D/2]
return freqs_cis
def get_nd_rotary_pos_embed(
rope_dim_list,
start,
*args,
theta=10000.0,
use_real=False,
theta_rescale_factor: Union[float, List[float]] = 1.0,
interpolation_factor: Union[float, List[float]] = 1.0,
):
"""
This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
Args:
rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
sum(rope_dim_list) should equal to head_dim of attention layer.
start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
*args: See above.
theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real
part and an imaginary part separately.
theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
Returns:
pos_embed (torch.Tensor): [HW, D/2]
"""
grid = get_meshgrid_nd(
start, *args, dim=len(rope_dim_list)
) # [3, W, H, D] / [2, W, H]
if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
assert len(theta_rescale_factor) == len(
rope_dim_list
), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
interpolation_factor = [interpolation_factor] * len(rope_dim_list)
elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
assert len(interpolation_factor) == len(
rope_dim_list
), "len(interpolation_factor) should equal to len(rope_dim_list)"
# use 1/ndim of dimensions to encode grid_axis
embs = []
for i in range(len(rope_dim_list)):
emb = get_1d_rotary_pos_embed(
rope_dim_list[i],
grid[i].reshape(-1),
theta,
use_real=use_real,
theta_rescale_factor=theta_rescale_factor[i],
interpolation_factor=interpolation_factor[i],
) # 2 x [WHD, rope_dim_list[i]]
embs.append(emb)
if use_real:
cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
return cos, sin
else:
emb = torch.cat(embs, dim=1) # (WHD, D/2)
return emb
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
[16, 56, 56],
[latents.shape[2], latents.shape[3] // 2, latents.shape[4] // 2],
theta=256,
use_real=True,
theta_rescale_factor=1,
)
return freqs_cos, freqs_sin
class PatchEmbed(torch.nn.Module):
def __init__(self, patch_size=(1, 2, 2), in_channels=16, embed_dim=3072):
super().__init__()
self.proj = torch.nn.Conv3d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
x = self.proj(x)
x = x.flatten(2).transpose(1, 2)
return x
class IndividualTokenRefinerBlock(torch.nn.Module):
def __init__(self, hidden_size=3072, num_heads=24):
super().__init__()
self.num_heads = num_heads
self.norm1 = torch.nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
self.self_attn_qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
self.self_attn_proj = torch.nn.Linear(hidden_size, hidden_size)
self.norm2 = torch.nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
self.mlp = torch.nn.Sequential(
torch.nn.Linear(hidden_size, hidden_size * 4),
torch.nn.SiLU(),
torch.nn.Linear(hidden_size * 4, hidden_size)
)
self.adaLN_modulation = torch.nn.Sequential(
torch.nn.SiLU(),
torch.nn.Linear(hidden_size, hidden_size * 2, device="cuda", dtype=torch.bfloat16),
)
def forward(self, x, c, attn_mask=None):
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
norm_x = self.norm1(x)
qkv = self.self_attn_qkv(norm_x)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
attn = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
attn = rearrange(attn, "B H L D -> B L (H D)")
x = x + self.self_attn_proj(attn) * gate_msa.unsqueeze(1)
x = x + self.mlp(self.norm2(x)) * gate_mlp.unsqueeze(1)
return x
class SingleTokenRefiner(torch.nn.Module):
def __init__(self, in_channels=4096, hidden_size=3072, depth=2):
super().__init__()
self.input_embedder = torch.nn.Linear(in_channels, hidden_size, bias=True)
self.t_embedder = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
self.c_embedder = torch.nn.Sequential(
torch.nn.Linear(in_channels, hidden_size),
torch.nn.SiLU(),
torch.nn.Linear(hidden_size, hidden_size)
)
self.blocks = torch.nn.ModuleList([IndividualTokenRefinerBlock(hidden_size=hidden_size) for _ in range(depth)])
def forward(self, x, t, mask=None):
timestep_aware_representations = self.t_embedder(t, dtype=torch.float32)
mask_float = mask.float().unsqueeze(-1)
context_aware_representations = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1)
context_aware_representations = self.c_embedder(context_aware_representations)
c = timestep_aware_representations + context_aware_representations
x = self.input_embedder(x)
mask = mask.to(device=x.device, dtype=torch.bool)
mask = repeat(mask, "B L -> B 1 D L", D=mask.shape[-1])
mask = mask & mask.transpose(2, 3)
mask[:, :, :, 0] = True
for block in self.blocks:
x = block(x, c, mask)
return x
class ModulateDiT(torch.nn.Module):
def __init__(self, hidden_size, factor=6):
super().__init__()
self.act = torch.nn.SiLU()
self.linear = torch.nn.Linear(hidden_size, factor * hidden_size)
def forward(self, x):
return self.linear(self.act(x))
def modulate(x, shift=None, scale=None):
if scale is None and shift is None:
return x
elif shift is None:
return x * (1 + scale.unsqueeze(1))
elif scale is None:
return x + shift.unsqueeze(1)
else:
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def reshape_for_broadcast(
freqs_cis,
x: torch.Tensor,
head_first=False,
):
ndim = x.ndim
assert 0 <= 1 < ndim
if isinstance(freqs_cis, tuple):
# freqs_cis: (cos, sin) in real space
if head_first:
assert freqs_cis[0].shape == (
x.shape[-2],
x.shape[-1],
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
shape = [
d if i == ndim - 2 or i == ndim - 1 else 1
for i, d in enumerate(x.shape)
]
else:
assert freqs_cis[0].shape == (
x.shape[1],
x.shape[-1],
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
else:
# freqs_cis: values in complex space
if head_first:
assert freqs_cis.shape == (
x.shape[-2],
x.shape[-1],
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
shape = [
d if i == ndim - 2 or i == ndim - 1 else 1
for i, d in enumerate(x.shape)
]
else:
assert freqs_cis.shape == (
x.shape[1],
x.shape[-1],
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def rotate_half(x):
x_real, x_imag = (
x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
) # [B, S, H, D//2]
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis,
head_first: bool = False,
):
xk_out = None
if isinstance(freqs_cis, tuple):
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
cos, sin = cos.to(xq.device), sin.to(xq.device)
# real * cos - imag * sin
# imag * cos + real * sin
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
else:
# view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
xq_ = torch.view_as_complex(
xq.float().reshape(*xq.shape[:-1], -1, 2)
) # [B, S, H, D//2]
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
xq.device
) # [S, D//2] --> [1, S, 1, D//2]
# (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
# view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
xk_ = torch.view_as_complex(
xk.float().reshape(*xk.shape[:-1], -1, 2)
) # [B, S, H, D//2]
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
return xq_out, xk_out
def attention(q, k, v):
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x = x.transpose(1, 2).flatten(2, 3)
return x
class MMDoubleStreamBlockComponent(torch.nn.Module):
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
super().__init__()
self.heads_num = heads_num
self.mod = ModulateDiT(hidden_size)
self.norm1 = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.to_qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
self.norm_q = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
self.norm_k = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
self.to_out = torch.nn.Linear(hidden_size, hidden_size)
self.norm2 = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.ff = torch.nn.Sequential(
torch.nn.Linear(hidden_size, hidden_size * mlp_width_ratio),
torch.nn.GELU(approximate="tanh"),
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size)
)
def forward(self, hidden_states, conditioning, freqs_cis=None):
mod1_shift, mod1_scale, mod1_gate, mod2_shift, mod2_scale, mod2_gate = self.mod(conditioning).chunk(6, dim=-1)
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale)
qkv = self.to_qkv(norm_hidden_states)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
q = self.norm_q(q)
k = self.norm_k(k)
if freqs_cis is not None:
q, k = apply_rotary_emb(q, k, freqs_cis, head_first=False)
return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate)
def process_ff(self, hidden_states, attn_output, mod):
mod1_gate, mod2_shift, mod2_scale, mod2_gate = mod
hidden_states = hidden_states + self.to_out(attn_output) * mod1_gate.unsqueeze(1)
hidden_states = hidden_states + self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale)) * mod2_gate.unsqueeze(1)
return hidden_states
class MMDoubleStreamBlock(torch.nn.Module):
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
super().__init__()
self.component_a = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
self.component_b = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis):
(q_a, k_a, v_a), mod_a = self.component_a(hidden_states_a, conditioning, freqs_cis)
(q_b, k_b, v_b), mod_b = self.component_b(hidden_states_b, conditioning, freqs_cis=None)
q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous()
k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous()
v_a, v_b = torch.concat([v_a, v_b[:, :71]], dim=1), v_b[:, 71:].contiguous()
attn_output_a = attention(q_a, k_a, v_a)
attn_output_b = attention(q_b, k_b, v_b)
attn_output_a, attn_output_b = attn_output_a[:, :-71].contiguous(), torch.concat([attn_output_a[:, -71:], attn_output_b], dim=1)
hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a)
hidden_states_b = self.component_b.process_ff(hidden_states_b, attn_output_b, mod_b)
return hidden_states_a, hidden_states_b
class MMSingleStreamBlockOriginal(torch.nn.Module):
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
super().__init__()
self.hidden_size = hidden_size
self.heads_num = heads_num
self.mlp_hidden_dim = hidden_size * mlp_width_ratio
self.linear1 = torch.nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
self.linear2 = torch.nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
self.q_norm = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
self.k_norm = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
self.pre_norm = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp_act = torch.nn.GELU(approximate="tanh")
self.modulation = ModulateDiT(hidden_size, factor=3)
def forward(self, x, vec, freqs_cis=None, txt_len=256):
mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
q = self.q_norm(q)
k = self.k_norm(k)
q_a, q_b = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False)
q = torch.cat((q_a, q_b), dim=1)
k = torch.cat((k_a, k_b), dim=1)
attn_output_a = attention(q[:, :-185].contiguous(), k[:, :-185].contiguous(), v[:, :-185].contiguous())
attn_output_b = attention(q[:, -185:].contiguous(), k[:, -185:].contiguous(), v[:, -185:].contiguous())
attn_output = torch.concat([attn_output_a, attn_output_b], dim=1)
output = self.linear2(torch.cat((attn_output, self.mlp_act(mlp)), 2))
return x + output * mod_gate.unsqueeze(1)
class MMSingleStreamBlock(torch.nn.Module):
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
super().__init__()
self.heads_num = heads_num
self.mod = ModulateDiT(hidden_size, factor=3)
self.norm = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.to_qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
self.norm_q = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
self.norm_k = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
self.to_out = torch.nn.Linear(hidden_size, hidden_size)
self.ff = torch.nn.Sequential(
torch.nn.Linear(hidden_size, hidden_size * mlp_width_ratio),
torch.nn.GELU(approximate="tanh"),
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size, bias=False)
)
def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256):
mod_shift, mod_scale, mod_gate = self.mod(conditioning).chunk(3, dim=-1)
norm_hidden_states = self.norm(hidden_states)
norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale)
qkv = self.to_qkv(norm_hidden_states)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
q = self.norm_q(q)
k = self.norm_k(k)
q_a, q_b = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False)
q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous()
k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous()
v_a, v_b = v[:, :-185].contiguous(), v[:, -185:].contiguous()
attn_output_a = attention(q_a, k_a, v_a)
attn_output_b = attention(q_b, k_b, v_b)
attn_output = torch.concat([attn_output_a, attn_output_b], dim=1)
hidden_states = hidden_states + self.to_out(attn_output) * mod_gate.unsqueeze(1)
hidden_states = hidden_states + self.ff(norm_hidden_states) * mod_gate.unsqueeze(1)
return hidden_states
class FinalLayer(torch.nn.Module):
def __init__(self, hidden_size=3072, patch_size=(1, 2, 2), out_channels=16):
super().__init__()
self.norm_final = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = torch.nn.Linear(hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels)
self.adaLN_modulation = torch.nn.Sequential(torch.nn.SiLU(), torch.nn.Linear(hidden_size, 2 * hidden_size))
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift=shift, scale=scale)
x = self.linear(x)
return x
class HunyuanVideoDiT(torch.nn.Module):
def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40):
super().__init__()
self.img_in = PatchEmbed(in_channels=in_channels, embed_dim=hidden_size)
self.txt_in = SingleTokenRefiner(in_channels=text_dim, hidden_size=hidden_size)
self.time_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
self.vector_in = torch.nn.Sequential(
torch.nn.Linear(768, hidden_size),
torch.nn.SiLU(),
torch.nn.Linear(hidden_size, hidden_size)
)
self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
self.double_blocks = torch.nn.ModuleList([MMDoubleStreamBlock(hidden_size) for _ in range(num_double_blocks)])
self.single_blocks = torch.nn.ModuleList([MMSingleStreamBlock(hidden_size) for _ in range(num_single_blocks)])
self.final_layer = FinalLayer(hidden_size)
# TODO: remove these parameters
self.dtype = torch.bfloat16
self.patch_size = [1, 2, 2]
self.hidden_size = 3072
self.heads_num = 24
self.rope_dim_list = [16, 56, 56]
def unpatchify(self, x, T, H, W):
x = rearrange(x, "B (T H W) (C pT pH pW) -> B C (T pT) (H pH) (W pW)", H=H, W=W, pT=1, pH=2, pW=2)
return x
def enable_block_wise_offload(self, warm_device="cuda", cold_device="cpu"):
self.warm_device = warm_device
self.cold_device = cold_device
self.to(self.cold_device)
def load_models_to_device(self, loadmodel_names=[], device="cpu"):
for model_name in loadmodel_names:
model = getattr(self, model_name)
if model is not None:
model.to(device)
torch.cuda.empty_cache()
def prepare_freqs(self, latents):
return HunyuanVideoRope(latents)
def forward(
self,
x: torch.Tensor,
t: torch.Tensor,
prompt_emb: torch.Tensor = None,
text_mask: torch.Tensor = None,
pooled_prompt_emb: torch.Tensor = None,
freqs_cos: torch.Tensor = None,
freqs_sin: torch.Tensor = None,
guidance: torch.Tensor = None,
**kwargs
):
B, C, T, H, W = x.shape
vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb) + self.guidance_in(guidance * 1000, dtype=torch.float32)
img = self.img_in(x)
txt = self.txt_in(prompt_emb, t, text_mask)
for block in tqdm(self.double_blocks, desc="Double stream blocks"):
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))
x = torch.concat([img, txt], dim=1)
for block in tqdm(self.single_blocks, desc="Single stream blocks"):
x = block(x, vec, (freqs_cos, freqs_sin))
img = x[:, :-256]
img = self.final_layer(img, vec)
img = self.unpatchify(img, T=T//1, H=H//2, W=W//2)
return img
def enable_auto_offload(self, dtype=torch.bfloat16, device="cuda"):
def cast_to(weight, dtype=None, device=None, copy=False):
if device is None or weight.device == device:
if not copy:
if dtype is None or weight.dtype == dtype:
return weight
return weight.to(dtype=dtype, copy=copy)
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight)
return r
def cast_weight(s, input=None, dtype=None, device=None):
if input is not None:
if dtype is None:
dtype = input.dtype
if device is None:
device = input.device
weight = cast_to(s.weight, dtype, device)
return weight
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
if input is not None:
if dtype is None:
dtype = input.dtype
if bias_dtype is None:
bias_dtype = dtype
if device is None:
device = input.device
weight = cast_to(s.weight, dtype, device)
bias = cast_to(s.bias, bias_dtype, device) if s.bias is not None else None
return weight, bias
class quantized_layer:
class Linear(torch.nn.Linear):
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
super().__init__(*args, **kwargs)
self.dtype = dtype
self.device = device
def block_forward_(self, x, i, j, dtype, device):
weight_ = cast_to(
self.weight[j * self.block_size: (j + 1) * self.block_size, i * self.block_size: (i + 1) * self.block_size],
dtype=dtype, device=device
)
if self.bias is None or i > 0:
bias_ = None
else:
bias_ = cast_to(self.bias[j * self.block_size: (j + 1) * self.block_size], dtype=dtype, device=device)
x_ = x[..., i * self.block_size: (i + 1) * self.block_size]
y_ = torch.nn.functional.linear(x_, weight_, bias_)
del x_, weight_, bias_
torch.cuda.empty_cache()
return y_
def block_forward(self, x, **kwargs):
# This feature can only reduce 2GB VRAM, so we disable it.
y = torch.zeros(x.shape[:-1] + (self.out_features,), dtype=x.dtype, device=x.device)
for i in range((self.in_features + self.block_size - 1) // self.block_size):
for j in range((self.out_features + self.block_size - 1) // self.block_size):
y[..., j * self.block_size: (j + 1) * self.block_size] += self.block_forward_(x, i, j, dtype=x.dtype, device=x.device)
return y
def forward(self, x, **kwargs):
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
return torch.nn.functional.linear(x, weight, bias)
class RMSNorm(torch.nn.Module):
def __init__(self, module, dtype=torch.bfloat16, device="cuda"):
super().__init__()
self.module = module
self.dtype = dtype
self.device = device
def forward(self, hidden_states, **kwargs):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
hidden_states = hidden_states.to(input_dtype)
if self.module.weight is not None:
weight = cast_weight(self.module, hidden_states, dtype=torch.bfloat16, device="cuda")
hidden_states = hidden_states * weight
return hidden_states
class Conv3d(torch.nn.Conv3d):
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
super().__init__(*args, **kwargs)
self.dtype = dtype
self.device = device
def forward(self, x):
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
return torch.nn.functional.conv3d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)
class LayerNorm(torch.nn.LayerNorm):
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
super().__init__(*args, **kwargs)
self.dtype = dtype
self.device = device
def forward(self, x):
if self.weight is not None and self.bias is not None:
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps)
else:
return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
def replace_layer(model, dtype=torch.bfloat16, device="cuda"):
for name, module in model.named_children():
if isinstance(module, torch.nn.Linear):
with init_weights_on_device():
new_layer = quantized_layer.Linear(
module.in_features, module.out_features, bias=module.bias is not None,
dtype=dtype, device=device
)
new_layer.load_state_dict(module.state_dict(), assign=True)
setattr(model, name, new_layer)
elif isinstance(module, torch.nn.Conv3d):
with init_weights_on_device():
new_layer = quantized_layer.Conv3d(
module.in_channels, module.out_channels, kernel_size=module.kernel_size, stride=module.stride,
dtype=dtype, device=device
)
new_layer.load_state_dict(module.state_dict(), assign=True)
setattr(model, name, new_layer)
elif isinstance(module, RMSNorm):
new_layer = quantized_layer.RMSNorm(
module,
dtype=dtype, device=device
)
setattr(model, name, new_layer)
elif isinstance(module, torch.nn.LayerNorm):
with init_weights_on_device():
new_layer = quantized_layer.LayerNorm(
module.normalized_shape, elementwise_affine=module.elementwise_affine, eps=module.eps,
dtype=dtype, device=device
)
new_layer.load_state_dict(module.state_dict(), assign=True)
setattr(model, name, new_layer)
else:
replace_layer(module, dtype=dtype, device=device)
replace_layer(self, dtype=dtype, device=device)
@staticmethod
def state_dict_converter():
return HunyuanVideoDiTStateDictConverter()
class HunyuanVideoDiTStateDictConverter:
def __init__(self):
pass
def from_civitai(self, state_dict):
if "module" in state_dict:
state_dict = state_dict["module"]
direct_dict = {
"img_in.proj": "img_in.proj",
"time_in.mlp.0": "time_in.timestep_embedder.0",
"time_in.mlp.2": "time_in.timestep_embedder.2",
"vector_in.in_layer": "vector_in.0",
"vector_in.out_layer": "vector_in.2",
"guidance_in.mlp.0": "guidance_in.timestep_embedder.0",
"guidance_in.mlp.2": "guidance_in.timestep_embedder.2",
"txt_in.input_embedder": "txt_in.input_embedder",
"txt_in.t_embedder.mlp.0": "txt_in.t_embedder.timestep_embedder.0",
"txt_in.t_embedder.mlp.2": "txt_in.t_embedder.timestep_embedder.2",
"txt_in.c_embedder.linear_1": "txt_in.c_embedder.0",
"txt_in.c_embedder.linear_2": "txt_in.c_embedder.2",
"final_layer.linear": "final_layer.linear",
"final_layer.adaLN_modulation.1": "final_layer.adaLN_modulation.1",
}
txt_suffix_dict = {
"norm1": "norm1",
"self_attn_qkv": "self_attn_qkv",
"self_attn_proj": "self_attn_proj",
"norm2": "norm2",
"mlp.fc1": "mlp.0",
"mlp.fc2": "mlp.2",
"adaLN_modulation.1": "adaLN_modulation.1",
}
double_suffix_dict = {
"img_mod.linear": "component_a.mod.linear",
"img_attn_qkv": "component_a.to_qkv",
"img_attn_q_norm": "component_a.norm_q",
"img_attn_k_norm": "component_a.norm_k",
"img_attn_proj": "component_a.to_out",
"img_mlp.fc1": "component_a.ff.0",
"img_mlp.fc2": "component_a.ff.2",
"txt_mod.linear": "component_b.mod.linear",
"txt_attn_qkv": "component_b.to_qkv",
"txt_attn_q_norm": "component_b.norm_q",
"txt_attn_k_norm": "component_b.norm_k",
"txt_attn_proj": "component_b.to_out",
"txt_mlp.fc1": "component_b.ff.0",
"txt_mlp.fc2": "component_b.ff.2",
}
single_suffix_dict = {
"linear1": ["to_qkv", "ff.0"],
"linear2": ["to_out", "ff.2"],
"q_norm": "norm_q",
"k_norm": "norm_k",
"modulation.linear": "mod.linear",
}
# single_suffix_dict = {
# "linear1": "linear1",
# "linear2": "linear2",
# "q_norm": "q_norm",
# "k_norm": "k_norm",
# "modulation.linear": "modulation.linear",
# }
state_dict_ = {}
for name, param in state_dict.items():
names = name.split(".")
direct_name = ".".join(names[:-1])
if direct_name in direct_dict:
name_ = direct_dict[direct_name] + "." + names[-1]
state_dict_[name_] = param
elif names[0] == "double_blocks":
prefix = ".".join(names[:2])
suffix = ".".join(names[2:-1])
name_ = prefix + "." + double_suffix_dict[suffix] + "." + names[-1]
state_dict_[name_] = param
elif names[0] == "single_blocks":
prefix = ".".join(names[:2])
suffix = ".".join(names[2:-1])
if isinstance(single_suffix_dict[suffix], list):
if suffix == "linear1":
name_a, name_b = single_suffix_dict[suffix]
param_a, param_b = torch.split(param, (3072*3, 3072*4), dim=0)
state_dict_[prefix + "." + name_a + "." + names[-1]] = param_a
state_dict_[prefix + "." + name_b + "." + names[-1]] = param_b
elif suffix == "linear2":
if names[-1] == "weight":
name_a, name_b = single_suffix_dict[suffix]
param_a, param_b = torch.split(param, (3072*1, 3072*4), dim=-1)
state_dict_[prefix + "." + name_a + "." + names[-1]] = param_a
state_dict_[prefix + "." + name_b + "." + names[-1]] = param_b
else:
name_a, name_b = single_suffix_dict[suffix]
state_dict_[prefix + "." + name_a + "." + names[-1]] = param
else:
pass
else:
name_ = prefix + "." + single_suffix_dict[suffix] + "." + names[-1]
state_dict_[name_] = param
elif names[0] == "txt_in":
prefix = ".".join(names[:4]).replace(".individual_token_refiner.", ".")
suffix = ".".join(names[4:-1])
name_ = prefix + "." + txt_suffix_dict[suffix] + "." + names[-1]
state_dict_[name_] = param
else:
pass
return state_dict_

View File

@@ -0,0 +1,55 @@
from transformers import LlamaModel, LlamaConfig, DynamicCache
from copy import deepcopy
import torch
class HunyuanVideoLLMEncoder(LlamaModel):
def __init__(self, config: LlamaConfig):
super().__init__(config)
self.auto_offload = False
def enable_auto_offload(self, **kwargs):
self.auto_offload = True
def forward(
self,
input_ids,
attention_mask,
hidden_state_skip_layer=2
):
embed_tokens = deepcopy(self.embed_tokens).to(input_ids.device) if self.auto_offload else self.embed_tokens
inputs_embeds = embed_tokens(input_ids)
past_key_values = DynamicCache()
cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, None, False)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
rotary_emb = deepcopy(self.rotary_emb).to(input_ids.device) if self.auto_offload else self.rotary_emb
position_embeddings = rotary_emb(hidden_states, position_ids)
# decoder layers
for layer_id, decoder_layer in enumerate(self.layers):
if self.auto_offload:
decoder_layer = deepcopy(decoder_layer).to(hidden_states.device)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=False,
use_cache=True,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
if layer_id + hidden_state_skip_layer + 1 >= len(self.layers):
break
return hidden_states

View File

@@ -0,0 +1,507 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import numpy as np
from tqdm import tqdm
from einops import repeat
class CausalConv3d(nn.Module):
def __init__(self, in_channel, out_channel, kernel_size, stride=1, dilation=1, pad_mode='replicate', **kwargs):
super().__init__()
self.pad_mode = pad_mode
self.time_causal_padding = (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size - 1, 0
) # W, H, T
self.conv = nn.Conv3d(in_channel, out_channel, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, x):
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
return self.conv(x)
class UpsampleCausal3D(nn.Module):
def __init__(self, channels, use_conv=False, out_channels=None, kernel_size=None, bias=True, upsample_factor=(2, 2, 2)):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.upsample_factor = upsample_factor
self.conv = None
if use_conv:
kernel_size = 3 if kernel_size is None else kernel_size
self.conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, bias=bias)
def forward(self, hidden_states):
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
dtype = hidden_states.dtype
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(torch.float32)
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
if hidden_states.shape[0] >= 64:
hidden_states = hidden_states.contiguous()
# interpolate
B, C, T, H, W = hidden_states.shape
first_h, other_h = hidden_states.split((1, T - 1), dim=2)
if T > 1:
other_h = F.interpolate(other_h, scale_factor=self.upsample_factor, mode="nearest")
first_h = F.interpolate(first_h.squeeze(2), scale_factor=self.upsample_factor[1:], mode="nearest").unsqueeze(2)
hidden_states = torch.cat((first_h, other_h), dim=2) if T > 1 else first_h
# If the input is bfloat16, we cast back to bfloat16
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(dtype)
if self.conv:
hidden_states = self.conv(hidden_states)
return hidden_states
class ResnetBlockCausal3D(nn.Module):
def __init__(self, in_channels, out_channels=None, dropout=0.0, groups=32, eps=1e-6, conv_shortcut_bias=True):
super().__init__()
self.pre_norm = True
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1)
self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, stride=1)
self.dropout = nn.Dropout(dropout)
self.nonlinearity = nn.SiLU()
self.conv_shortcut = None
if in_channels != out_channels:
self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, bias=conv_shortcut_bias)
def forward(self, input_tensor):
hidden_states = input_tensor
# conv1
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
# conv2
hidden_states = self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
# shortcut
if self.conv_shortcut is not None:
input_tensor = (self.conv_shortcut(input_tensor))
# shortcut and scale
output_tensor = input_tensor + hidden_states
return output_tensor
def prepare_causal_attention_mask(n_frame, n_hw, dtype, device, batch_size=None):
seq_len = n_frame * n_hw
mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
for i in range(seq_len):
i_frame = i // n_hw
mask[i, :(i_frame + 1) * n_hw] = 0
if batch_size is not None:
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
return mask
class Attention(nn.Module):
def __init__(self,
in_channels,
num_heads,
head_dim,
num_groups=32,
dropout=0.0,
eps=1e-6,
bias=True,
residual_connection=True):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.residual_connection = residual_connection
dim_inner = head_dim * num_heads
self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=eps, affine=True)
self.to_q = nn.Linear(in_channels, dim_inner, bias=bias)
self.to_k = nn.Linear(in_channels, dim_inner, bias=bias)
self.to_v = nn.Linear(in_channels, dim_inner, bias=bias)
self.to_out = nn.Sequential(nn.Linear(dim_inner, in_channels, bias=bias), nn.Dropout(dropout))
def forward(self, input_tensor, attn_mask=None):
hidden_states = self.group_norm(input_tensor.transpose(1, 2)).transpose(1, 2)
batch_size = hidden_states.shape[0]
q = self.to_q(hidden_states)
k = self.to_k(hidden_states)
v = self.to_v(hidden_states)
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
if attn_mask is not None:
attn_mask = attn_mask.view(batch_size, self.num_heads, -1, attn_mask.shape[-1])
hidden_states = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = self.to_out(hidden_states)
if self.residual_connection:
output_tensor = input_tensor + hidden_states
return output_tensor
class UNetMidBlockCausal3D(nn.Module):
def __init__(self, in_channels, dropout=0.0, num_layers=1, eps=1e-6, num_groups=32, attention_head_dim=None):
super().__init__()
resnets = [
ResnetBlockCausal3D(
in_channels=in_channels,
out_channels=in_channels,
dropout=dropout,
groups=num_groups,
eps=eps,
)
]
attentions = []
attention_head_dim = attention_head_dim or in_channels
for _ in range(num_layers):
attentions.append(
Attention(
in_channels,
num_heads=in_channels // attention_head_dim,
head_dim=attention_head_dim,
num_groups=num_groups,
dropout=dropout,
eps=eps,
bias=True,
residual_connection=True,
))
resnets.append(
ResnetBlockCausal3D(
in_channels=in_channels,
out_channels=in_channels,
dropout=dropout,
groups=num_groups,
eps=eps,
))
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states):
hidden_states = self.resnets[0](hidden_states)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
B, C, T, H, W = hidden_states.shape
hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
attn_mask = prepare_causal_attention_mask(T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B)
hidden_states = attn(hidden_states, attn_mask=attn_mask)
hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W)
hidden_states = resnet(hidden_states)
return hidden_states
class UpDecoderBlockCausal3D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
dropout=0.0,
num_layers=1,
eps=1e-6,
num_groups=32,
add_upsample=True,
upsample_scale_factor=(2, 2, 2),
):
super().__init__()
resnets = []
for i in range(num_layers):
cur_in_channel = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlockCausal3D(
in_channels=cur_in_channel,
out_channels=out_channels,
groups=num_groups,
dropout=dropout,
eps=eps,
))
self.resnets = nn.ModuleList(resnets)
self.upsamplers = None
if add_upsample:
self.upsamplers = nn.ModuleList([
UpsampleCausal3D(
out_channels,
use_conv=True,
out_channels=out_channels,
upsample_factor=upsample_scale_factor,
)
])
def forward(self, hidden_states):
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states
class DecoderCausal3D(nn.Module):
def __init__(
self,
in_channels=16,
out_channels=3,
eps=1e-6,
dropout=0.0,
block_out_channels=[128, 256, 512, 512],
layers_per_block=2,
num_groups=32,
time_compression_ratio=4,
spatial_compression_ratio=8,
gradient_checkpointing=False,
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
self.up_blocks = nn.ModuleList([])
# mid
self.mid_block = UNetMidBlockCausal3D(
in_channels=block_out_channels[-1],
dropout=dropout,
eps=eps,
num_groups=num_groups,
attention_head_dim=block_out_channels[-1],
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i in range(len(block_out_channels)):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
num_time_upsample_layers = int(np.log2(time_compression_ratio))
add_spatial_upsample = bool(i < num_spatial_upsample_layers)
add_time_upsample = bool(i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block)
upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
up_block = UpDecoderBlockCausal3D(
in_channels=prev_output_channel,
out_channels=output_channel,
dropout=dropout,
num_layers=layers_per_block + 1,
eps=eps,
num_groups=num_groups,
add_upsample=bool(add_spatial_upsample or add_time_upsample),
upsample_scale_factor=upsample_scale_factor,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups, eps=eps)
self.conv_act = nn.SiLU()
self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
self.gradient_checkpointing = gradient_checkpointing
def forward(self, hidden_states):
hidden_states = self.conv_in(hidden_states)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# middle
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
hidden_states,
use_reentrant=False,
)
# up
for up_block in self.up_blocks:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block),
hidden_states,
use_reentrant=False,
)
else:
# middle
hidden_states = self.mid_block(hidden_states)
# up
for up_block in self.up_blocks:
hidden_states = up_block(hidden_states)
# post-process
hidden_states = self.conv_norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class HunyuanVideoVAEDecoder(nn.Module):
def __init__(
self,
in_channels=16,
out_channels=3,
eps=1e-6,
dropout=0.0,
block_out_channels=[128, 256, 512, 512],
layers_per_block=2,
num_groups=32,
time_compression_ratio=4,
spatial_compression_ratio=8,
gradient_checkpointing=False,
):
super().__init__()
self.decoder = DecoderCausal3D(
in_channels=in_channels,
out_channels=out_channels,
eps=eps,
dropout=dropout,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
num_groups=num_groups,
time_compression_ratio=time_compression_ratio,
spatial_compression_ratio=spatial_compression_ratio,
gradient_checkpointing=gradient_checkpointing,
)
self.post_quant_conv = nn.Conv3d(in_channels, in_channels, kernel_size=1)
self.scaling_factor = 0.476986
def forward(self, latents):
latents = latents / self.scaling_factor
latents = self.post_quant_conv(latents)
dec = self.decoder(latents)
return dec
def build_1d_mask(self, length, left_bound, right_bound, border_width):
x = torch.ones((length,))
if not left_bound:
x[:border_width] = (torch.arange(border_width) + 1) / border_width
if not right_bound:
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
return x
def build_mask(self, data, is_bound, border_width):
_, _, T, H, W = data.shape
t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
h = self.build_1d_mask(H, is_bound[2], is_bound[3], border_width[1])
w = self.build_1d_mask(W, is_bound[4], is_bound[5], border_width[2])
t = repeat(t, "T -> T H W", T=T, H=H, W=W)
h = repeat(h, "H -> T H W", T=T, H=H, W=W)
w = repeat(w, "W -> T H W", T=T, H=H, W=W)
mask = torch.stack([t, h, w]).min(dim=0).values
mask = rearrange(mask, "T H W -> 1 1 T H W")
return mask
def tile_forward(self, hidden_states, tile_size, tile_stride):
B, C, T, H, W = hidden_states.shape
size_t, size_h, size_w = tile_size
stride_t, stride_h, stride_w = tile_stride
# Split tasks
tasks = []
for t in range(0, T, stride_t):
if (t-stride_t >= 0 and t-stride_t+size_t >= T): continue
for h in range(0, H, stride_h):
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
for w in range(0, W, stride_w):
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
t_, h_, w_ = t + size_t, h + size_h, w + size_w
tasks.append((t, t_, h, h_, w, w_))
# Run
torch_dtype = self.post_quant_conv.weight.dtype
data_device = hidden_states.device
computation_device = self.post_quant_conv.weight.device
weight = torch.zeros((1, 1, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)
values = torch.zeros((B, 3, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)
for t, t_, h, h_, w, w_ in tqdm(tasks):
hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device)
hidden_states_batch = self.forward(hidden_states_batch).to(data_device)
if t > 0:
hidden_states_batch = hidden_states_batch[:, :, 1:]
mask = self.build_mask(
hidden_states_batch,
is_bound=(t==0, t_>=T, h==0, h_>=H, w==0, w_>=W),
border_width=((size_t - stride_t) * 4, (size_h - stride_h) * 8, (size_w - stride_w) * 8)
).to(dtype=torch_dtype, device=data_device)
target_t = 0 if t==0 else t * 4 + 1
target_h = h * 8
target_w = w * 8
values[
:,
:,
target_t: target_t + hidden_states_batch.shape[2],
target_h: target_h + hidden_states_batch.shape[3],
target_w: target_w + hidden_states_batch.shape[4],
] += hidden_states_batch * mask
weight[
:,
:,
target_t: target_t + hidden_states_batch.shape[2],
target_h: target_h + hidden_states_batch.shape[3],
target_w: target_w + hidden_states_batch.shape[4],
] += mask
return values / weight
def decode_video(self, latents, tile_size=(17, 32, 32), tile_stride=(12, 24, 24)):
latents = latents.to(self.post_quant_conv.weight.dtype)
return self.tile_forward(latents, tile_size=tile_size, tile_stride=tile_stride)
@staticmethod
def state_dict_converter():
return HunyuanVideoVAEDecoderStateDictConverter()
class HunyuanVideoVAEDecoderStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
state_dict_ = {}
for name in state_dict:
if name.startswith('decoder.') or name.startswith('post_quant_conv.'):
state_dict_[name] = state_dict[name]
return state_dict_

View File

@@ -0,0 +1,217 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import numpy as np
from .hunyuan_video_vae_decoder import CausalConv3d, ResnetBlockCausal3D, UNetMidBlockCausal3D
class DownsampleCausal3D(nn.Module):
def __init__(self, channels, out_channels, kernel_size=3, bias=True, stride=2):
super().__init__()
self.conv = CausalConv3d(channels, out_channels, kernel_size, stride=stride, bias=bias)
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
return hidden_states
class DownEncoderBlockCausal3D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
dropout=0.0,
num_layers=1,
eps=1e-6,
num_groups=32,
add_downsample=True,
downsample_stride=2,
):
super().__init__()
resnets = []
for i in range(num_layers):
cur_in_channel = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlockCausal3D(
in_channels=cur_in_channel,
out_channels=out_channels,
groups=num_groups,
dropout=dropout,
eps=eps,
))
self.resnets = nn.ModuleList(resnets)
self.downsamplers = None
if add_downsample:
self.downsamplers = nn.ModuleList([DownsampleCausal3D(
out_channels,
out_channels,
stride=downsample_stride,
)])
def forward(self, hidden_states):
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
return hidden_states
class EncoderCausal3D(nn.Module):
def __init__(
self,
in_channels: int = 3,
out_channels: int = 16,
eps=1e-6,
dropout=0.0,
block_out_channels=[128, 256, 512, 512],
layers_per_block=2,
num_groups=32,
time_compression_ratio: int = 4,
spatial_compression_ratio: int = 8,
gradient_checkpointing=False,
):
super().__init__()
self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
self.down_blocks = nn.ModuleList([])
# down
output_channel = block_out_channels[0]
for i in range(len(block_out_channels)):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
num_time_downsample_layers = int(np.log2(time_compression_ratio))
add_spatial_downsample = bool(i < num_spatial_downsample_layers)
add_time_downsample = bool(i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block)
downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
downsample_stride_T = (2,) if add_time_downsample else (1,)
downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
down_block = DownEncoderBlockCausal3D(
in_channels=input_channel,
out_channels=output_channel,
dropout=dropout,
num_layers=layers_per_block,
eps=eps,
num_groups=num_groups,
add_downsample=bool(add_spatial_downsample or add_time_downsample),
downsample_stride=downsample_stride,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlockCausal3D(
in_channels=block_out_channels[-1],
dropout=dropout,
eps=eps,
num_groups=num_groups,
attention_head_dim=block_out_channels[-1],
)
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups, eps=eps)
self.conv_act = nn.SiLU()
self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3)
self.gradient_checkpointing = gradient_checkpointing
def forward(self, hidden_states):
hidden_states = self.conv_in(hidden_states)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# down
for down_block in self.down_blocks:
torch.utils.checkpoint.checkpoint(
create_custom_forward(down_block),
hidden_states,
use_reentrant=False,
)
# middle
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
hidden_states,
use_reentrant=False,
)
else:
# down
for down_block in self.down_blocks:
hidden_states = down_block(hidden_states)
# middle
hidden_states = self.mid_block(hidden_states)
# post-process
hidden_states = self.conv_norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class HunyuanVideoVAEEncoder(nn.Module):
def __init__(
self,
in_channels=3,
out_channels=16,
eps=1e-6,
dropout=0.0,
block_out_channels=[128, 256, 512, 512],
layers_per_block=2,
num_groups=32,
time_compression_ratio=4,
spatial_compression_ratio=8,
gradient_checkpointing=False,
):
super().__init__()
self.encoder = EncoderCausal3D(
in_channels=in_channels,
out_channels=out_channels,
eps=eps,
dropout=dropout,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
num_groups=num_groups,
time_compression_ratio=time_compression_ratio,
spatial_compression_ratio=spatial_compression_ratio,
gradient_checkpointing=gradient_checkpointing,
)
self.quant_conv = nn.Conv3d(2 * out_channels, 2 * out_channels, kernel_size=1)
def forward(self, images):
latents = self.encoder(images)
latents = self.quant_conv(latents)
# latents: (B C T H W)
return latents
@staticmethod
def state_dict_converter():
return HunyuanVideoVAEEncoderStateDictConverter()
class HunyuanVideoVAEEncoderStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
state_dict_ = {}
for name in state_dict:
if name.startswith('encoder.') or name.startswith('quant_conv.'):
state_dict_[name] = state_dict[name]
return state_dict_

View File

@@ -7,6 +7,7 @@ from .sd3_dit import SD3DiT
from .flux_dit import FluxDiT
from .hunyuan_dit import HunyuanDiT
from .cog_dit import CogDiT
from .hunyuan_video_dit import HunyuanVideoDiT
@@ -259,6 +260,14 @@ class GeneralLoRAFromPeft:
return None
class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai):
def __init__(self):
super().__init__()
self.supported_model_classes = [HunyuanVideoDiT]
self.lora_prefix = ["diffusion_model."]
self.special_keys = {}
class FluxLoRAConverter:
def __init__(self):
pass
@@ -355,4 +364,4 @@ class FluxLoRAConverter:
def get_lora_loaders():
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), GeneralLoRAFromPeft()]
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()]

View File

@@ -35,6 +35,8 @@ from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from .hunyuan_dit import HunyuanDiT
from .hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
from .hunyuan_video_vae_encoder import HunyuanVideoVAEEncoder
from .flux_dit import FluxDiT
from .flux_text_encoder import FluxTextEncoder2

View File

@@ -52,9 +52,9 @@ class PatchEmbed(torch.nn.Module):
class TimestepEmbeddings(torch.nn.Module):
def __init__(self, dim_in, dim_out):
def __init__(self, dim_in, dim_out, computation_device=None):
super().__init__()
self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0)
self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device)
self.timestep_embedder = torch.nn.Sequential(
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
)

View File

@@ -2,15 +2,17 @@ import torch
from transformers import T5EncoderModel, T5Config
from .sd_text_encoder import SDTextEncoder
from .sdxl_text_encoder import SDXLTextEncoder2, SDXLTextEncoder2StateDictConverter
class SD3TextEncoder1(SDTextEncoder):
def __init__(self, vocab_size=49408):
super().__init__(vocab_size=vocab_size)
def forward(self, input_ids, clip_skip=2):
def forward(self, input_ids, clip_skip=2, extra_mask=None):
embeds = self.token_embedding(input_ids) + self.position_embeds
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
if extra_mask is not None:
attn_mask[:, extra_mask[0]==0] = float("-inf")
for encoder_id, encoder in enumerate(self.encoders):
embeds = encoder(embeds, attn_mask=attn_mask)
if encoder_id + clip_skip == len(self.encoders):

View File

@@ -44,6 +44,7 @@ def get_timestep_embedding(
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
computation_device = None,
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
@@ -57,11 +58,11 @@ def get_timestep_embedding(
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device if computation_device is None else computation_device
)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent)
emb = torch.exp(exponent).to(timesteps.device)
emb = timesteps[:, None].float() * emb[None, :]
# scale embeddings
@@ -81,11 +82,12 @@ def get_timestep_embedding(
class TemporalTimesteps(torch.nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, computation_device = None):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
self.computation_device = computation_device
def forward(self, timesteps):
t_emb = get_timestep_embedding(
@@ -93,6 +95,7 @@ class TemporalTimesteps(torch.nn.Module):
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
computation_device=self.computation_device,
)
return t_emb

View File

@@ -80,7 +80,7 @@ def load_state_dict_from_safetensors(file_path, torch_dtype=None):
def load_state_dict_from_bin(file_path, torch_dtype=None):
state_dict = torch.load(file_path, map_location="cpu")
state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
if torch_dtype is not None:
for i in state_dict:
if isinstance(state_dict[i], torch.Tensor):

View File

@@ -9,4 +9,5 @@ from .flux_image import FluxImagePipeline
from .cog_video import CogVideoPipeline
from .omnigen_image import OmnigenImagePipeline
from .pipeline_runner import SDVideoPipelineRunner
from .hunyuan_video import HunyuanVideoPipeline
KolorsImagePipeline = SDXLImagePipeline

View File

@@ -0,0 +1,139 @@
from ..models import ModelManager, SD3TextEncoder1, HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder
from ..models.hunyuan_video_dit import HunyuanVideoDiT
from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder
from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline
from ..prompters import HunyuanVideoPrompter
import torch
from einops import rearrange
import numpy as np
from PIL import Image
class HunyuanVideoPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = FlowMatchScheduler(shift=7.0, sigma_min=0.0, extra_one_step=True)
self.prompter = HunyuanVideoPrompter()
self.text_encoder_1: SD3TextEncoder1 = None
self.text_encoder_2: HunyuanVideoLLMEncoder = None
self.dit: HunyuanVideoDiT = None
self.vae_decoder: HunyuanVideoVAEDecoder = None
self.vae_encoder: HunyuanVideoVAEEncoder = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder']
self.vram_management = False
def enable_vram_management(self):
self.vram_management = True
self.enable_cpu_offload()
self.text_encoder_2.enable_auto_offload(dtype=self.torch_dtype, device=self.device)
self.dit.enable_auto_offload(dtype=self.torch_dtype, device=self.device)
def fetch_models(self, model_manager: ModelManager):
self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
self.text_encoder_2 = model_manager.fetch_model("hunyuan_video_text_encoder_2")
self.dit = model_manager.fetch_model("hunyuan_video_dit")
self.vae_decoder = model_manager.fetch_model("hunyuan_video_vae_decoder")
self.vae_encoder = model_manager.fetch_model("hunyuan_video_vae_encoder")
self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
@staticmethod
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, enable_vram_management=True):
if device is None: device = model_manager.device
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
pipe = HunyuanVideoPipeline(device=device, torch_dtype=torch_dtype)
pipe.fetch_models(model_manager)
if enable_vram_management:
pipe.enable_vram_management()
return pipe
def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256):
prompt_emb, pooled_prompt_emb, text_mask = self.prompter.encode_prompt(
prompt, device=self.device, positive=positive, clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length
)
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_mask": text_mask}
def prepare_extra_input(self, latents=None, guidance=1.0):
freqs_cos, freqs_sin = self.dit.prepare_freqs(latents)
guidance = torch.Tensor([guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype)
return {"freqs_cos": freqs_cos, "freqs_sin": freqs_sin, "guidance": guidance}
def tensor2video(self, frames):
frames = rearrange(frames, "C T H W -> T H W C")
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
frames = [Image.fromarray(frame) for frame in frames]
return frames
def encode_video(self, frames):
# frames : (B, C, T, H, W)
latents = self.vae_encoder(frames)
return latents
@torch.no_grad()
def __call__(
self,
prompt,
negative_prompt="",
seed=None,
height=720,
width=1280,
num_frames=129,
embedded_guidance=6.0,
cfg_scale=1.0,
num_inference_steps=30,
tile_size=(17, 30, 30),
tile_stride=(12, 20, 20),
progress_bar_cmd=lambda x: x,
progress_bar_st=None,
):
# Initialize noise
latents = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
# Encode prompts
self.load_models_to_device(["text_encoder_1"] if self.vram_management else ["text_encoder_1", "text_encoder_2"])
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
if cfg_scale != 1.0:
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
# Extra input
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
# Scheduler
self.scheduler.set_timesteps(num_inference_steps)
# Denoise
self.load_models_to_device([] if self.vram_management else ["dit"])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")
# Inference
with torch.autocast(device_type=self.device, dtype=self.torch_dtype):
noise_pred_posi = self.dit(latents, timestep, **prompt_emb_posi, **extra_input)
if cfg_scale != 1.0:
noise_pred_nega = self.dit(latents, timestep, **prompt_emb_nega, **extra_input)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
# Scheduler
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
# Tiler parameters
tiler_kwargs = {"tile_size": tile_size, "tile_stride": tile_stride}
# Decode
self.load_models_to_device(['vae_decoder'])
frames = self.vae_decoder.decode_video(latents, **tiler_kwargs)
self.load_models_to_device([])
frames = self.tensor2video(frames[0])
return frames

View File

@@ -7,3 +7,4 @@ from .kolors_prompter import KolorsPrompter
from .flux_prompter import FluxPrompter
from .omost import OmostPromter
from .cog_prompter import CogPrompter
from .hunyuan_video_prompter import HunyuanVideoPrompter

View File

@@ -0,0 +1,143 @@
from .base_prompter import BasePrompter
from ..models.sd3_text_encoder import SD3TextEncoder1
from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder
from transformers import CLIPTokenizer, LlamaTokenizerFast
import os, torch
PROMPT_TEMPLATE_ENCODE = (
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
"quantity, text, spatial relationships of the objects and background:<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>")
PROMPT_TEMPLATE_ENCODE_VIDEO = (
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
"1. The main content and theme of the video."
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
"4. background environment, light, style and atmosphere."
"5. camera angles, movements, and transitions used in the video:<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>")
PROMPT_TEMPLATE = {
"dit-llm-encode": {
"template": PROMPT_TEMPLATE_ENCODE,
"crop_start": 36,
},
"dit-llm-encode-video": {
"template": PROMPT_TEMPLATE_ENCODE_VIDEO,
"crop_start": 95,
},
}
NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
class HunyuanVideoPrompter(BasePrompter):
def __init__(
self,
tokenizer_1_path=None,
tokenizer_2_path=None,
):
if tokenizer_1_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_1_path = os.path.join(
base_path, "tokenizer_configs/hunyuan_video/tokenizer_1")
if tokenizer_2_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_2_path = os.path.join(
base_path, "tokenizer_configs/hunyuan_video/tokenizer_2")
super().__init__()
self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
self.tokenizer_2 = LlamaTokenizerFast.from_pretrained(tokenizer_2_path, padding_side='right')
self.text_encoder_1: SD3TextEncoder1 = None
self.text_encoder_2: HunyuanVideoLLMEncoder = None
self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode']
self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video']
def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: HunyuanVideoLLMEncoder = None):
self.text_encoder_1 = text_encoder_1
self.text_encoder_2 = text_encoder_2
def apply_text_to_template(self, text, template):
assert isinstance(template, str)
if isinstance(text, list):
return [self.apply_text_to_template(text_) for text_ in text]
elif isinstance(text, str):
# Will send string to tokenizer. Used for llm
return template.format(text)
else:
raise TypeError(f"Unsupported prompt type: {type(text)}")
def encode_prompt_using_clip(self, prompt, max_length, device):
tokenized_result = self.tokenizer_1(
prompt,
return_tensors="pt",
padding="max_length",
max_length=max_length,
truncation=True,
return_attention_mask=True
)
input_ids = tokenized_result.input_ids.to(device)
attention_mask = tokenized_result.attention_mask.to(device)
return self.text_encoder_1(input_ids=input_ids, extra_mask=attention_mask)[0]
def encode_prompt_using_llm(self,
prompt,
max_length,
device,
crop_start,
hidden_state_skip_layer=2,
use_attention_mask=True):
max_length += crop_start
inputs = self.tokenizer_2(prompt,
return_tensors="pt",
padding="max_length",
max_length=max_length,
truncation=True)
input_ids = inputs.input_ids.to(device)
attention_mask = inputs.attention_mask.to(device)
last_hidden_state = self.text_encoder_2(input_ids, attention_mask, hidden_state_skip_layer)
# crop out
if crop_start > 0:
last_hidden_state = last_hidden_state[:, crop_start:]
attention_mask = (attention_mask[:, crop_start:] if use_attention_mask else None)
return last_hidden_state, attention_mask
def encode_prompt(self,
prompt,
positive=True,
device="cuda",
clip_sequence_length=77,
llm_sequence_length=256,
data_type='video',
use_template=True,
hidden_state_skip_layer=2,
use_attention_mask=True):
prompt = self.process_prompt(prompt, positive=positive)
# apply template
if use_template:
template = self.prompt_template_video if data_type == 'video' else self.prompt_template
prompt_formated = self.apply_text_to_template(prompt, template['template'])
else:
prompt_formated = prompt
# Text encoder
if data_type == 'video':
crop_start = self.prompt_template_video.get("crop_start", 0)
else:
crop_start = self.prompt_template.get("crop_start", 0)
# CLIP
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, clip_sequence_length, device)
# LLM
prompt_emb, attention_mask = self.encode_prompt_using_llm(
prompt_formated, llm_sequence_length, device, crop_start,
hidden_state_skip_layer, use_attention_mask)
return prompt_emb, pooled_prompt_emb, attention_mask

View File

@@ -4,18 +4,22 @@ import torch
class FlowMatchScheduler():
def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False):
def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False, extra_one_step=False):
self.num_train_timesteps = num_train_timesteps
self.shift = shift
self.sigma_max = sigma_max
self.sigma_min = sigma_min
self.inverse_timesteps = inverse_timesteps
self.extra_one_step = extra_one_step
self.set_timesteps(num_inference_steps)
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False):
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
if self.extra_one_step:
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
else:
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
if self.inverse_timesteps:
self.sigmas = torch.flip(self.sigmas, dims=[0])
self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,30 @@
{
"bos_token": {
"content": "<|startoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"eos_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"pad_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"unk_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
}

View File

@@ -0,0 +1,30 @@
{
"add_prefix_space": false,
"added_tokens_decoder": {
"49406": {
"content": "<|startoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": true
},
"49407": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"bos_token": "<|startoftext|>",
"clean_up_tokenization_spaces": true,
"do_lower_case": true,
"eos_token": "<|endoftext|>",
"errors": "replace",
"model_max_length": 77,
"pad_token": "<|endoftext|>",
"tokenizer_class": "CLIPTokenizer",
"unk_token": "<|endoftext|>"
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,30 @@
{
"bos_token": {
"content": "<|begin_of_text|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"eos_token": {
"content": "<|end_of_text|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"pad_token": {
"content": "<pad>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"unk_token": {
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,19 @@
# HunyuanVideo
HunyuanVideo is a video generation model trained by Tencent. We provide advanced VRAM management for this model, including three stages:
|VRAM required|Example script|Frames|Resolution|Note|
|-|-|-|-|-|
|80G|[hunyuanvideo_80G.py](hunyuanvideo_80G.py)|129|720*1280|No VRAM management.|
|24G|[hunyuanvideo_24G.py](hunyuanvideo_24G.py)|129|720*1280|The video is consistent with the original implementation, but it requires 5%~10% more time than [hunyuanvideo_80G.py](hunyuanvideo_80G.py)|
|6G|[hunyuanvideo_6G.py](hunyuanvideo_6G.py)|129|512*384|The base model doesn't support low resolutions. We recommend users to use some LoRA ([example](https://civitai.com/models/1032126/walking-animation-hunyuan-video)) trained using low resolutions.|
## Gallery
Video generated by [hunyuanvideo_80G.py](hunyuanvideo_80G.py) and [hunyuanvideo_24G.py](hunyuanvideo_24G.py):
https://github.com/user-attachments/assets/48dd24bb-0cc6-40d2-88c3-10feed3267e9
Video generated by [hunyuanvideo_6G.py](hunyuanvideo_6G.py) using [this LoRA](https://civitai.com/models/1032126/walking-animation-hunyuan-video):
https://github.com/user-attachments/assets/2997f107-d02d-4ecb-89bb-5ce1a7f93817

View File

@@ -0,0 +1,42 @@
import torch
torch.cuda.set_per_process_memory_fraction(1.0, 0)
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
download_models(["HunyuanVideo"])
model_manager = ModelManager()
# The DiT model is loaded in bfloat16.
model_manager.load_models(
[
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
],
torch_dtype=torch.bfloat16,
device="cpu"
)
# The other modules are loaded in float16.
model_manager.load_models(
[
"models/HunyuanVideo/text_encoder/model.safetensors",
"models/HunyuanVideo/text_encoder_2",
"models/HunyuanVideo/vae/pytorch_model.pt",
],
torch_dtype=torch.float16,
device="cpu"
)
# We support LoRA inference. You can use the following code to load your LoRA model.
# model_manager.load_lora("models/lora/xxx.safetensors", lora_alpha=1.0)
# The computation device is "cuda".
pipe = HunyuanVideoPipeline.from_model_manager(
model_manager,
torch_dtype=torch.bfloat16,
device="cuda"
)
# Enjoy!
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
video = pipe(prompt, seed=0)
save_video(video, "video_girl.mp4", fps=30, quality=6)

View File

@@ -0,0 +1,47 @@
import torch
torch.cuda.set_per_process_memory_fraction(1.0, 0)
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video, FlowMatchScheduler
download_models(["HunyuanVideo"])
model_manager = ModelManager()
# The DiT model is loaded in bfloat16.
model_manager.load_models(
[
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
],
torch_dtype=torch.bfloat16,
device="cpu"
)
# The other modules are loaded in float16.
model_manager.load_models(
[
"models/HunyuanVideo/text_encoder/model.safetensors",
"models/HunyuanVideo/text_encoder_2",
"models/HunyuanVideo/vae/pytorch_model.pt",
],
torch_dtype=torch.float16,
device="cpu"
)
# We support LoRA inference. You can use the following code to load your LoRA model.
# Example LoRA: https://civitai.com/models/1032126/walking-animation-hunyuan-video
model_manager.load_lora("models/lora/kxsr_walking_anim_v1-5.safetensors", lora_alpha=1.0)
# The computation device is "cuda".
pipe = HunyuanVideoPipeline.from_model_manager(
model_manager,
torch_dtype=torch.bfloat16,
device="cuda"
)
# This LoRA requires shift=9.0.
pipe.scheduler = FlowMatchScheduler(shift=9.0, sigma_min=0.0, extra_one_step=True)
# Enjoy!
for clothes_up in ["white t-shirt", "black t-shirt", "orange t-shirt"]:
for clothes_down in ["blue sports skirt", "red sports skirt", "white sports skirt"]:
prompt = f"kxsr, full body, no crop, A 3D-rendered CG animation video featuring a Gorgeous, mature, curvaceous, fair-skinned female girl with long silver hair and blue eyes. She wears a {clothes_up} and a {clothes_down}, walking offering a sense of fluid movement and vivid animation."
video = pipe(prompt, seed=0, height=512, width=384, num_frames=129, num_inference_steps=18, tile_size=(17, 16, 16), tile_stride=(12, 12, 12))
save_video(video, f"video-{clothes_up}-{clothes_down}.mp4", fps=30, quality=6)

View File

@@ -0,0 +1,45 @@
import torch
torch.cuda.set_per_process_memory_fraction(1.0, 0)
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
download_models(["HunyuanVideo"])
model_manager = ModelManager()
# The DiT model is loaded in bfloat16.
model_manager.load_models(
[
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
],
torch_dtype=torch.bfloat16,
device="cuda"
)
# The other modules are loaded in float16.
model_manager.load_models(
[
"models/HunyuanVideo/text_encoder/model.safetensors",
"models/HunyuanVideo/text_encoder_2",
"models/HunyuanVideo/vae/pytorch_model.pt",
],
torch_dtype=torch.float16,
device="cuda"
)
# We support LoRA inference. You can use the following code to load your LoRA model.
# model_manager.load_lora("models/lora/xxx.safetensors", lora_alpha=1.0)
# The computation device is "cuda".
pipe = HunyuanVideoPipeline.from_model_manager(
model_manager,
torch_dtype=torch.bfloat16,
device="cuda",
enable_vram_management=False
)
# Although you have enough VRAM, we still recommend you to enable offload.
pipe.enable_cpu_offload()
# Enjoy!
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
video = pipe(prompt, seed=0)
save_video(video, "video.mp4", fps=30, quality=6)