mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 23:58:12 +00:00
Merge pull request #436 from mi804/hunyuanvideo_i2v
support hunyuanvideo-i2v
This commit is contained in:
@@ -4,6 +4,7 @@ from .utils import init_weights_on_device
|
||||
from einops import rearrange, repeat
|
||||
from tqdm import tqdm
|
||||
from typing import Union, Tuple, List
|
||||
from .utils import hash_state_dict_keys
|
||||
|
||||
|
||||
def HunyuanVideoRope(latents):
|
||||
@@ -236,7 +237,7 @@ class IndividualTokenRefinerBlock(torch.nn.Module):
|
||||
x = x + self.mlp(self.norm2(x)) * gate_mlp.unsqueeze(1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class SingleTokenRefiner(torch.nn.Module):
|
||||
def __init__(self, in_channels=4096, hidden_size=3072, depth=2):
|
||||
@@ -269,7 +270,7 @@ class SingleTokenRefiner(torch.nn.Module):
|
||||
x = block(x, c, mask)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class ModulateDiT(torch.nn.Module):
|
||||
def __init__(self, hidden_size, factor=6):
|
||||
@@ -279,9 +280,14 @@ class ModulateDiT(torch.nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(self.act(x))
|
||||
|
||||
|
||||
def modulate(x, shift=None, scale=None):
|
||||
|
||||
def modulate(x, shift=None, scale=None, tr_shift=None, tr_scale=None, tr_token=None):
|
||||
if tr_shift is not None:
|
||||
x_zero = x[:, :tr_token] * (1 + tr_scale.unsqueeze(1)) + tr_shift.unsqueeze(1)
|
||||
x_orig = x[:, tr_token:] * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
x = torch.concat((x_zero, x_orig), dim=1)
|
||||
return x
|
||||
if scale is None and shift is None:
|
||||
return x
|
||||
elif shift is None:
|
||||
@@ -290,7 +296,7 @@ def modulate(x, shift=None, scale=None):
|
||||
return x + shift.unsqueeze(1)
|
||||
else:
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
|
||||
def reshape_for_broadcast(
|
||||
freqs_cis,
|
||||
@@ -343,7 +349,7 @@ def rotate_half(x):
|
||||
x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
|
||||
) # [B, S, H, D//2]
|
||||
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
@@ -385,6 +391,15 @@ def attention(q, k, v):
|
||||
return x
|
||||
|
||||
|
||||
def apply_gate(x, gate, tr_gate=None, tr_token=None):
|
||||
if tr_gate is not None:
|
||||
x_zero = x[:, :tr_token] * tr_gate.unsqueeze(1)
|
||||
x_orig = x[:, tr_token:] * gate.unsqueeze(1)
|
||||
return torch.concat((x_zero, x_orig), dim=1)
|
||||
else:
|
||||
return x * gate.unsqueeze(1)
|
||||
|
||||
|
||||
class MMDoubleStreamBlockComponent(torch.nn.Module):
|
||||
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
||||
super().__init__()
|
||||
@@ -405,11 +420,17 @@ class MMDoubleStreamBlockComponent(torch.nn.Module):
|
||||
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size)
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, conditioning, freqs_cis=None):
|
||||
def forward(self, hidden_states, conditioning, freqs_cis=None, token_replace_vec=None, tr_token=None):
|
||||
mod1_shift, mod1_scale, mod1_gate, mod2_shift, mod2_scale, mod2_gate = self.mod(conditioning).chunk(6, dim=-1)
|
||||
if token_replace_vec is not None:
|
||||
assert tr_token is not None
|
||||
tr_mod1_shift, tr_mod1_scale, tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = self.mod(token_replace_vec).chunk(6, dim=-1)
|
||||
else:
|
||||
tr_mod1_shift, tr_mod1_scale, tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = None, None, None, None, None, None
|
||||
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale)
|
||||
norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale,
|
||||
tr_shift=tr_mod1_shift, tr_scale=tr_mod1_scale, tr_token=tr_token)
|
||||
qkv = self.to_qkv(norm_hidden_states)
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
||||
|
||||
@@ -418,15 +439,19 @@ class MMDoubleStreamBlockComponent(torch.nn.Module):
|
||||
|
||||
if freqs_cis is not None:
|
||||
q, k = apply_rotary_emb(q, k, freqs_cis, head_first=False)
|
||||
return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate), (tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate)
|
||||
|
||||
return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate)
|
||||
|
||||
def process_ff(self, hidden_states, attn_output, mod):
|
||||
def process_ff(self, hidden_states, attn_output, mod, mod_tr=None, tr_token=None):
|
||||
mod1_gate, mod2_shift, mod2_scale, mod2_gate = mod
|
||||
hidden_states = hidden_states + self.to_out(attn_output) * mod1_gate.unsqueeze(1)
|
||||
hidden_states = hidden_states + self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale)) * mod2_gate.unsqueeze(1)
|
||||
if mod_tr is not None:
|
||||
tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = mod_tr
|
||||
else:
|
||||
tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = None, None, None, None
|
||||
hidden_states = hidden_states + apply_gate(self.to_out(attn_output), mod1_gate, tr_mod1_gate, tr_token)
|
||||
x = self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale, tr_shift=tr_mod2_shift, tr_scale=tr_mod2_scale, tr_token=tr_token))
|
||||
hidden_states = hidden_states + apply_gate(x, mod2_gate, tr_mod2_gate, tr_token)
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
class MMDoubleStreamBlock(torch.nn.Module):
|
||||
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
||||
@@ -434,18 +459,18 @@ class MMDoubleStreamBlock(torch.nn.Module):
|
||||
self.component_a = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
|
||||
self.component_b = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis):
|
||||
(q_a, k_a, v_a), mod_a = self.component_a(hidden_states_a, conditioning, freqs_cis)
|
||||
(q_b, k_b, v_b), mod_b = self.component_b(hidden_states_b, conditioning, freqs_cis=None)
|
||||
def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis, token_replace_vec=None, tr_token=None, split_token=71):
|
||||
(q_a, k_a, v_a), mod_a, mod_tr = self.component_a(hidden_states_a, conditioning, freqs_cis, token_replace_vec, tr_token)
|
||||
(q_b, k_b, v_b), mod_b, _ = self.component_b(hidden_states_b, conditioning, freqs_cis=None)
|
||||
|
||||
q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous()
|
||||
k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous()
|
||||
v_a, v_b = torch.concat([v_a, v_b[:, :71]], dim=1), v_b[:, 71:].contiguous()
|
||||
q_a, q_b = torch.concat([q_a, q_b[:, :split_token]], dim=1), q_b[:, split_token:].contiguous()
|
||||
k_a, k_b = torch.concat([k_a, k_b[:, :split_token]], dim=1), k_b[:, split_token:].contiguous()
|
||||
v_a, v_b = torch.concat([v_a, v_b[:, :split_token]], dim=1), v_b[:, split_token:].contiguous()
|
||||
attn_output_a = attention(q_a, k_a, v_a)
|
||||
attn_output_b = attention(q_b, k_b, v_b)
|
||||
attn_output_a, attn_output_b = attn_output_a[:, :-71].contiguous(), torch.concat([attn_output_a[:, -71:], attn_output_b], dim=1)
|
||||
attn_output_a, attn_output_b = attn_output_a[:, :-split_token].contiguous(), torch.concat([attn_output_a[:, -split_token:], attn_output_b], dim=1)
|
||||
|
||||
hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a)
|
||||
hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a, mod_tr, tr_token)
|
||||
hidden_states_b = self.component_b.process_ff(hidden_states_b, attn_output_b, mod_b)
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
@@ -488,7 +513,7 @@ class MMSingleStreamBlockOriginal(torch.nn.Module):
|
||||
|
||||
output = self.linear2(torch.cat((attn_output, self.mlp_act(mlp)), 2))
|
||||
return x + output * mod_gate.unsqueeze(1)
|
||||
|
||||
|
||||
|
||||
class MMSingleStreamBlock(torch.nn.Module):
|
||||
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
||||
@@ -509,11 +534,17 @@ class MMSingleStreamBlock(torch.nn.Module):
|
||||
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size, bias=False)
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256):
|
||||
def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256, token_replace_vec=None, tr_token=None, split_token=71):
|
||||
mod_shift, mod_scale, mod_gate = self.mod(conditioning).chunk(3, dim=-1)
|
||||
if token_replace_vec is not None:
|
||||
assert tr_token is not None
|
||||
tr_mod_shift, tr_mod_scale, tr_mod_gate = self.mod(token_replace_vec).chunk(3, dim=-1)
|
||||
else:
|
||||
tr_mod_shift, tr_mod_scale, tr_mod_gate = None, None, None
|
||||
|
||||
norm_hidden_states = self.norm(hidden_states)
|
||||
norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale)
|
||||
norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale,
|
||||
tr_shift=tr_mod_shift, tr_scale=tr_mod_scale, tr_token=tr_token)
|
||||
qkv = self.to_qkv(norm_hidden_states)
|
||||
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
||||
@@ -525,16 +556,17 @@ class MMSingleStreamBlock(torch.nn.Module):
|
||||
k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
|
||||
q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False)
|
||||
|
||||
q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous()
|
||||
k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous()
|
||||
v_a, v_b = v[:, :-185].contiguous(), v[:, -185:].contiguous()
|
||||
v_len = txt_len - split_token
|
||||
q_a, q_b = torch.concat([q_a, q_b[:, :split_token]], dim=1), q_b[:, split_token:].contiguous()
|
||||
k_a, k_b = torch.concat([k_a, k_b[:, :split_token]], dim=1), k_b[:, split_token:].contiguous()
|
||||
v_a, v_b = v[:, :-v_len].contiguous(), v[:, -v_len:].contiguous()
|
||||
|
||||
attn_output_a = attention(q_a, k_a, v_a)
|
||||
attn_output_b = attention(q_b, k_b, v_b)
|
||||
attn_output = torch.concat([attn_output_a, attn_output_b], dim=1)
|
||||
|
||||
hidden_states = hidden_states + self.to_out(attn_output) * mod_gate.unsqueeze(1)
|
||||
hidden_states = hidden_states + self.ff(norm_hidden_states) * mod_gate.unsqueeze(1)
|
||||
hidden_states = hidden_states + apply_gate(self.to_out(attn_output), mod_gate, tr_mod_gate, tr_token)
|
||||
hidden_states = hidden_states + apply_gate(self.ff(norm_hidden_states), mod_gate, tr_mod_gate, tr_token)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -555,7 +587,7 @@ class FinalLayer(torch.nn.Module):
|
||||
|
||||
|
||||
class HunyuanVideoDiT(torch.nn.Module):
|
||||
def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40):
|
||||
def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40, guidance_embed=True):
|
||||
super().__init__()
|
||||
self.img_in = PatchEmbed(in_channels=in_channels, embed_dim=hidden_size)
|
||||
self.txt_in = SingleTokenRefiner(in_channels=text_dim, hidden_size=hidden_size)
|
||||
@@ -565,7 +597,7 @@ class HunyuanVideoDiT(torch.nn.Module):
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(hidden_size, hidden_size)
|
||||
)
|
||||
self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
|
||||
self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu") if guidance_embed else None
|
||||
self.double_blocks = torch.nn.ModuleList([MMDoubleStreamBlock(hidden_size) for _ in range(num_double_blocks)])
|
||||
self.single_blocks = torch.nn.ModuleList([MMSingleStreamBlock(hidden_size) for _ in range(num_single_blocks)])
|
||||
self.final_layer = FinalLayer(hidden_size)
|
||||
@@ -580,7 +612,7 @@ class HunyuanVideoDiT(torch.nn.Module):
|
||||
def unpatchify(self, x, T, H, W):
|
||||
x = rearrange(x, "B (T H W) (C pT pH pW) -> B C (T pT) (H pH) (W pW)", H=H, W=W, pT=1, pH=2, pW=2)
|
||||
return x
|
||||
|
||||
|
||||
def enable_block_wise_offload(self, warm_device="cuda", cold_device="cpu"):
|
||||
self.warm_device = warm_device
|
||||
self.cold_device = cold_device
|
||||
@@ -610,10 +642,12 @@ class HunyuanVideoDiT(torch.nn.Module):
|
||||
):
|
||||
B, C, T, H, W = x.shape
|
||||
|
||||
vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb) + self.guidance_in(guidance * 1000, dtype=torch.float32)
|
||||
vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb)
|
||||
if self.guidance_in is not None:
|
||||
vec += self.guidance_in(guidance * 1000, dtype=torch.float32)
|
||||
img = self.img_in(x)
|
||||
txt = self.txt_in(prompt_emb, t, text_mask)
|
||||
|
||||
|
||||
for block in tqdm(self.double_blocks, desc="Double stream blocks"):
|
||||
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))
|
||||
|
||||
@@ -625,7 +659,7 @@ class HunyuanVideoDiT(torch.nn.Module):
|
||||
img = self.final_layer(img, vec)
|
||||
img = self.unpatchify(img, T=T//1, H=H//2, W=W//2)
|
||||
return img
|
||||
|
||||
|
||||
|
||||
def enable_auto_offload(self, dtype=torch.bfloat16, device="cuda"):
|
||||
def cast_to(weight, dtype=None, device=None, copy=False):
|
||||
@@ -681,7 +715,7 @@ class HunyuanVideoDiT(torch.nn.Module):
|
||||
del x_, weight_, bias_
|
||||
torch.cuda.empty_cache()
|
||||
return y_
|
||||
|
||||
|
||||
def block_forward(self, x, **kwargs):
|
||||
# This feature can only reduce 2GB VRAM, so we disable it.
|
||||
y = torch.zeros(x.shape[:-1] + (self.out_features,), dtype=x.dtype, device=x.device)
|
||||
@@ -689,19 +723,19 @@ class HunyuanVideoDiT(torch.nn.Module):
|
||||
for j in range((self.out_features + self.block_size - 1) // self.block_size):
|
||||
y[..., j * self.block_size: (j + 1) * self.block_size] += self.block_forward_(x, i, j, dtype=x.dtype, device=x.device)
|
||||
return y
|
||||
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
|
||||
return torch.nn.functional.linear(x, weight, bias)
|
||||
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, module, dtype=torch.bfloat16, device="cuda"):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
|
||||
def forward(self, hidden_states, **kwargs):
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
|
||||
@@ -711,30 +745,30 @@ class HunyuanVideoDiT(torch.nn.Module):
|
||||
weight = cast_weight(self.module, hidden_states, dtype=torch.bfloat16, device="cuda")
|
||||
hidden_states = hidden_states * weight
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Conv3d(torch.nn.Conv3d):
|
||||
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
|
||||
return torch.nn.functional.conv3d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
|
||||
class LayerNorm(torch.nn.LayerNorm):
|
||||
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
if self.weight is not None and self.bias is not None:
|
||||
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
|
||||
return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps)
|
||||
else:
|
||||
return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||
|
||||
|
||||
def replace_layer(model, dtype=torch.bfloat16, device="cuda"):
|
||||
for name, module in model.named_children():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
@@ -777,12 +811,12 @@ class HunyuanVideoDiT(torch.nn.Module):
|
||||
return HunyuanVideoDiTStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class HunyuanVideoDiTStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
origin_hash_key = hash_state_dict_keys(state_dict, with_shape=True)
|
||||
if "module" in state_dict:
|
||||
state_dict = state_dict["module"]
|
||||
direct_dict = {
|
||||
@@ -882,4 +916,5 @@ class HunyuanVideoDiTStateDictConverter:
|
||||
state_dict_[name_] = param
|
||||
else:
|
||||
pass
|
||||
|
||||
return state_dict_
|
||||
|
||||
@@ -1,24 +1,18 @@
|
||||
from transformers import LlamaModel, LlamaConfig, DynamicCache
|
||||
from transformers import LlamaModel, LlamaConfig, DynamicCache, LlavaForConditionalGeneration
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
|
||||
|
||||
class HunyuanVideoLLMEncoder(LlamaModel):
|
||||
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super().__init__(config)
|
||||
self.auto_offload = False
|
||||
|
||||
|
||||
def enable_auto_offload(self, **kwargs):
|
||||
self.auto_offload = True
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
hidden_state_skip_layer=2
|
||||
):
|
||||
def forward(self, input_ids, attention_mask, hidden_state_skip_layer=2):
|
||||
embed_tokens = deepcopy(self.embed_tokens).to(input_ids.device) if self.auto_offload else self.embed_tokens
|
||||
inputs_embeds = embed_tokens(input_ids)
|
||||
|
||||
@@ -53,3 +47,22 @@ class HunyuanVideoLLMEncoder(LlamaModel):
|
||||
break
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoMLLMEncoder(LlavaForConditionalGeneration):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.auto_offload = False
|
||||
|
||||
def enable_auto_offload(self, **kwargs):
|
||||
self.auto_offload = True
|
||||
|
||||
# TODO: implement the low VRAM inference for MLLM.
|
||||
def forward(self, input_ids, pixel_values, attention_mask, hidden_state_skip_layer=2):
|
||||
outputs = super().forward(input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
pixel_values=pixel_values)
|
||||
hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
|
||||
return hidden_state
|
||||
|
||||
Reference in New Issue
Block a user