Merge pull request #436 from mi804/hunyuanvideo_i2v

support hunyuanvideo-i2v
This commit is contained in:
Zhongjie Duan
2025-03-13 19:38:11 +08:00
committed by GitHub
10 changed files with 555 additions and 80 deletions

View File

@@ -4,6 +4,7 @@ from .utils import init_weights_on_device
from einops import rearrange, repeat
from tqdm import tqdm
from typing import Union, Tuple, List
from .utils import hash_state_dict_keys
def HunyuanVideoRope(latents):
@@ -236,7 +237,7 @@ class IndividualTokenRefinerBlock(torch.nn.Module):
x = x + self.mlp(self.norm2(x)) * gate_mlp.unsqueeze(1)
return x
class SingleTokenRefiner(torch.nn.Module):
def __init__(self, in_channels=4096, hidden_size=3072, depth=2):
@@ -269,7 +270,7 @@ class SingleTokenRefiner(torch.nn.Module):
x = block(x, c, mask)
return x
class ModulateDiT(torch.nn.Module):
def __init__(self, hidden_size, factor=6):
@@ -279,9 +280,14 @@ class ModulateDiT(torch.nn.Module):
def forward(self, x):
return self.linear(self.act(x))
def modulate(x, shift=None, scale=None):
def modulate(x, shift=None, scale=None, tr_shift=None, tr_scale=None, tr_token=None):
if tr_shift is not None:
x_zero = x[:, :tr_token] * (1 + tr_scale.unsqueeze(1)) + tr_shift.unsqueeze(1)
x_orig = x[:, tr_token:] * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
x = torch.concat((x_zero, x_orig), dim=1)
return x
if scale is None and shift is None:
return x
elif shift is None:
@@ -290,7 +296,7 @@ def modulate(x, shift=None, scale=None):
return x + shift.unsqueeze(1)
else:
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def reshape_for_broadcast(
freqs_cis,
@@ -343,7 +349,7 @@ def rotate_half(x):
x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
) # [B, S, H, D//2]
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
def apply_rotary_emb(
xq: torch.Tensor,
@@ -385,6 +391,15 @@ def attention(q, k, v):
return x
def apply_gate(x, gate, tr_gate=None, tr_token=None):
if tr_gate is not None:
x_zero = x[:, :tr_token] * tr_gate.unsqueeze(1)
x_orig = x[:, tr_token:] * gate.unsqueeze(1)
return torch.concat((x_zero, x_orig), dim=1)
else:
return x * gate.unsqueeze(1)
class MMDoubleStreamBlockComponent(torch.nn.Module):
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
super().__init__()
@@ -405,11 +420,17 @@ class MMDoubleStreamBlockComponent(torch.nn.Module):
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size)
)
def forward(self, hidden_states, conditioning, freqs_cis=None):
def forward(self, hidden_states, conditioning, freqs_cis=None, token_replace_vec=None, tr_token=None):
mod1_shift, mod1_scale, mod1_gate, mod2_shift, mod2_scale, mod2_gate = self.mod(conditioning).chunk(6, dim=-1)
if token_replace_vec is not None:
assert tr_token is not None
tr_mod1_shift, tr_mod1_scale, tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = self.mod(token_replace_vec).chunk(6, dim=-1)
else:
tr_mod1_shift, tr_mod1_scale, tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = None, None, None, None, None, None
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale)
norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale,
tr_shift=tr_mod1_shift, tr_scale=tr_mod1_scale, tr_token=tr_token)
qkv = self.to_qkv(norm_hidden_states)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
@@ -418,15 +439,19 @@ class MMDoubleStreamBlockComponent(torch.nn.Module):
if freqs_cis is not None:
q, k = apply_rotary_emb(q, k, freqs_cis, head_first=False)
return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate), (tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate)
return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate)
def process_ff(self, hidden_states, attn_output, mod):
def process_ff(self, hidden_states, attn_output, mod, mod_tr=None, tr_token=None):
mod1_gate, mod2_shift, mod2_scale, mod2_gate = mod
hidden_states = hidden_states + self.to_out(attn_output) * mod1_gate.unsqueeze(1)
hidden_states = hidden_states + self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale)) * mod2_gate.unsqueeze(1)
if mod_tr is not None:
tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = mod_tr
else:
tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = None, None, None, None
hidden_states = hidden_states + apply_gate(self.to_out(attn_output), mod1_gate, tr_mod1_gate, tr_token)
x = self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale, tr_shift=tr_mod2_shift, tr_scale=tr_mod2_scale, tr_token=tr_token))
hidden_states = hidden_states + apply_gate(x, mod2_gate, tr_mod2_gate, tr_token)
return hidden_states
class MMDoubleStreamBlock(torch.nn.Module):
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
@@ -434,18 +459,18 @@ class MMDoubleStreamBlock(torch.nn.Module):
self.component_a = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
self.component_b = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis):
(q_a, k_a, v_a), mod_a = self.component_a(hidden_states_a, conditioning, freqs_cis)
(q_b, k_b, v_b), mod_b = self.component_b(hidden_states_b, conditioning, freqs_cis=None)
def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis, token_replace_vec=None, tr_token=None, split_token=71):
(q_a, k_a, v_a), mod_a, mod_tr = self.component_a(hidden_states_a, conditioning, freqs_cis, token_replace_vec, tr_token)
(q_b, k_b, v_b), mod_b, _ = self.component_b(hidden_states_b, conditioning, freqs_cis=None)
q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous()
k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous()
v_a, v_b = torch.concat([v_a, v_b[:, :71]], dim=1), v_b[:, 71:].contiguous()
q_a, q_b = torch.concat([q_a, q_b[:, :split_token]], dim=1), q_b[:, split_token:].contiguous()
k_a, k_b = torch.concat([k_a, k_b[:, :split_token]], dim=1), k_b[:, split_token:].contiguous()
v_a, v_b = torch.concat([v_a, v_b[:, :split_token]], dim=1), v_b[:, split_token:].contiguous()
attn_output_a = attention(q_a, k_a, v_a)
attn_output_b = attention(q_b, k_b, v_b)
attn_output_a, attn_output_b = attn_output_a[:, :-71].contiguous(), torch.concat([attn_output_a[:, -71:], attn_output_b], dim=1)
attn_output_a, attn_output_b = attn_output_a[:, :-split_token].contiguous(), torch.concat([attn_output_a[:, -split_token:], attn_output_b], dim=1)
hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a)
hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a, mod_tr, tr_token)
hidden_states_b = self.component_b.process_ff(hidden_states_b, attn_output_b, mod_b)
return hidden_states_a, hidden_states_b
@@ -488,7 +513,7 @@ class MMSingleStreamBlockOriginal(torch.nn.Module):
output = self.linear2(torch.cat((attn_output, self.mlp_act(mlp)), 2))
return x + output * mod_gate.unsqueeze(1)
class MMSingleStreamBlock(torch.nn.Module):
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
@@ -509,11 +534,17 @@ class MMSingleStreamBlock(torch.nn.Module):
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size, bias=False)
)
def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256):
def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256, token_replace_vec=None, tr_token=None, split_token=71):
mod_shift, mod_scale, mod_gate = self.mod(conditioning).chunk(3, dim=-1)
if token_replace_vec is not None:
assert tr_token is not None
tr_mod_shift, tr_mod_scale, tr_mod_gate = self.mod(token_replace_vec).chunk(3, dim=-1)
else:
tr_mod_shift, tr_mod_scale, tr_mod_gate = None, None, None
norm_hidden_states = self.norm(hidden_states)
norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale)
norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale,
tr_shift=tr_mod_shift, tr_scale=tr_mod_scale, tr_token=tr_token)
qkv = self.to_qkv(norm_hidden_states)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
@@ -525,16 +556,17 @@ class MMSingleStreamBlock(torch.nn.Module):
k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False)
q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous()
k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous()
v_a, v_b = v[:, :-185].contiguous(), v[:, -185:].contiguous()
v_len = txt_len - split_token
q_a, q_b = torch.concat([q_a, q_b[:, :split_token]], dim=1), q_b[:, split_token:].contiguous()
k_a, k_b = torch.concat([k_a, k_b[:, :split_token]], dim=1), k_b[:, split_token:].contiguous()
v_a, v_b = v[:, :-v_len].contiguous(), v[:, -v_len:].contiguous()
attn_output_a = attention(q_a, k_a, v_a)
attn_output_b = attention(q_b, k_b, v_b)
attn_output = torch.concat([attn_output_a, attn_output_b], dim=1)
hidden_states = hidden_states + self.to_out(attn_output) * mod_gate.unsqueeze(1)
hidden_states = hidden_states + self.ff(norm_hidden_states) * mod_gate.unsqueeze(1)
hidden_states = hidden_states + apply_gate(self.to_out(attn_output), mod_gate, tr_mod_gate, tr_token)
hidden_states = hidden_states + apply_gate(self.ff(norm_hidden_states), mod_gate, tr_mod_gate, tr_token)
return hidden_states
@@ -555,7 +587,7 @@ class FinalLayer(torch.nn.Module):
class HunyuanVideoDiT(torch.nn.Module):
def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40):
def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40, guidance_embed=True):
super().__init__()
self.img_in = PatchEmbed(in_channels=in_channels, embed_dim=hidden_size)
self.txt_in = SingleTokenRefiner(in_channels=text_dim, hidden_size=hidden_size)
@@ -565,7 +597,7 @@ class HunyuanVideoDiT(torch.nn.Module):
torch.nn.SiLU(),
torch.nn.Linear(hidden_size, hidden_size)
)
self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu") if guidance_embed else None
self.double_blocks = torch.nn.ModuleList([MMDoubleStreamBlock(hidden_size) for _ in range(num_double_blocks)])
self.single_blocks = torch.nn.ModuleList([MMSingleStreamBlock(hidden_size) for _ in range(num_single_blocks)])
self.final_layer = FinalLayer(hidden_size)
@@ -580,7 +612,7 @@ class HunyuanVideoDiT(torch.nn.Module):
def unpatchify(self, x, T, H, W):
x = rearrange(x, "B (T H W) (C pT pH pW) -> B C (T pT) (H pH) (W pW)", H=H, W=W, pT=1, pH=2, pW=2)
return x
def enable_block_wise_offload(self, warm_device="cuda", cold_device="cpu"):
self.warm_device = warm_device
self.cold_device = cold_device
@@ -610,10 +642,12 @@ class HunyuanVideoDiT(torch.nn.Module):
):
B, C, T, H, W = x.shape
vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb) + self.guidance_in(guidance * 1000, dtype=torch.float32)
vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb)
if self.guidance_in is not None:
vec += self.guidance_in(guidance * 1000, dtype=torch.float32)
img = self.img_in(x)
txt = self.txt_in(prompt_emb, t, text_mask)
for block in tqdm(self.double_blocks, desc="Double stream blocks"):
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))
@@ -625,7 +659,7 @@ class HunyuanVideoDiT(torch.nn.Module):
img = self.final_layer(img, vec)
img = self.unpatchify(img, T=T//1, H=H//2, W=W//2)
return img
def enable_auto_offload(self, dtype=torch.bfloat16, device="cuda"):
def cast_to(weight, dtype=None, device=None, copy=False):
@@ -681,7 +715,7 @@ class HunyuanVideoDiT(torch.nn.Module):
del x_, weight_, bias_
torch.cuda.empty_cache()
return y_
def block_forward(self, x, **kwargs):
# This feature can only reduce 2GB VRAM, so we disable it.
y = torch.zeros(x.shape[:-1] + (self.out_features,), dtype=x.dtype, device=x.device)
@@ -689,19 +723,19 @@ class HunyuanVideoDiT(torch.nn.Module):
for j in range((self.out_features + self.block_size - 1) // self.block_size):
y[..., j * self.block_size: (j + 1) * self.block_size] += self.block_forward_(x, i, j, dtype=x.dtype, device=x.device)
return y
def forward(self, x, **kwargs):
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
return torch.nn.functional.linear(x, weight, bias)
class RMSNorm(torch.nn.Module):
def __init__(self, module, dtype=torch.bfloat16, device="cuda"):
super().__init__()
self.module = module
self.dtype = dtype
self.device = device
def forward(self, hidden_states, **kwargs):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
@@ -711,30 +745,30 @@ class HunyuanVideoDiT(torch.nn.Module):
weight = cast_weight(self.module, hidden_states, dtype=torch.bfloat16, device="cuda")
hidden_states = hidden_states * weight
return hidden_states
class Conv3d(torch.nn.Conv3d):
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
super().__init__(*args, **kwargs)
self.dtype = dtype
self.device = device
def forward(self, x):
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
return torch.nn.functional.conv3d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)
class LayerNorm(torch.nn.LayerNorm):
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
super().__init__(*args, **kwargs)
self.dtype = dtype
self.device = device
def forward(self, x):
if self.weight is not None and self.bias is not None:
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps)
else:
return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
def replace_layer(model, dtype=torch.bfloat16, device="cuda"):
for name, module in model.named_children():
if isinstance(module, torch.nn.Linear):
@@ -777,12 +811,12 @@ class HunyuanVideoDiT(torch.nn.Module):
return HunyuanVideoDiTStateDictConverter()
class HunyuanVideoDiTStateDictConverter:
def __init__(self):
pass
def from_civitai(self, state_dict):
origin_hash_key = hash_state_dict_keys(state_dict, with_shape=True)
if "module" in state_dict:
state_dict = state_dict["module"]
direct_dict = {
@@ -882,4 +916,5 @@ class HunyuanVideoDiTStateDictConverter:
state_dict_[name_] = param
else:
pass
return state_dict_

View File

@@ -1,24 +1,18 @@
from transformers import LlamaModel, LlamaConfig, DynamicCache
from transformers import LlamaModel, LlamaConfig, DynamicCache, LlavaForConditionalGeneration
from copy import deepcopy
import torch
class HunyuanVideoLLMEncoder(LlamaModel):
def __init__(self, config: LlamaConfig):
super().__init__(config)
self.auto_offload = False
def enable_auto_offload(self, **kwargs):
self.auto_offload = True
def forward(
self,
input_ids,
attention_mask,
hidden_state_skip_layer=2
):
def forward(self, input_ids, attention_mask, hidden_state_skip_layer=2):
embed_tokens = deepcopy(self.embed_tokens).to(input_ids.device) if self.auto_offload else self.embed_tokens
inputs_embeds = embed_tokens(input_ids)
@@ -53,3 +47,22 @@ class HunyuanVideoLLMEncoder(LlamaModel):
break
return hidden_states
class HunyuanVideoMLLMEncoder(LlavaForConditionalGeneration):
def __init__(self, config):
super().__init__(config)
self.auto_offload = False
def enable_auto_offload(self, **kwargs):
self.auto_offload = True
# TODO: implement the low VRAM inference for MLLM.
def forward(self, input_ids, pixel_values, attention_mask, hidden_state_skip_layer=2):
outputs = super().forward(input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
pixel_values=pixel_values)
hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
return hidden_state