Merge pull request #436 from mi804/hunyuanvideo_i2v
support hunyuanvideo-i2v
@@ -19,7 +19,7 @@ Until now, DiffSynth Studio has supported the following models:
* [Wan-Video](https://github.com/Wan-Video/Wan2.1)
* [StepVideo](https://github.com/stepfun-ai/Step-Video-T2V)
* [HunyuanVideo](https://github.com/Tencent/HunyuanVideo)
* [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [HunyuanVideo-I2V](https://github.com/Tencent/HunyuanVideo-I2V)
* [CogVideoX](https://huggingface.co/THUDM/CogVideoX-5b)
* [FLUX](https://huggingface.co/black-forest-labs/FLUX.1-dev)
* [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
@@ -36,6 +36,7 @@ Until now, DiffSynth Studio has supported the following models:
* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)

## News

- **March 25, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of HunyuanVideo open-sourced by Tencent. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.
- **February 25, 2025** We support Wan-Video, a collection of SOTA video synthesis models open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).
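
For a quick start, here is a minimal usage sketch condensed from the example scripts added in this PR (`examples/HunyuanVideo/hunyuanvideo_i2v_80G.py` / `hunyuanvideo_i2v_24G.py`); the prompt and image path are placeholders:

```python
import torch
from PIL import Image
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video

download_models(["HunyuanVideoI2V"])
model_manager = ModelManager()
# DiT in bfloat16, the other modules in float16, as in the example scripts.
model_manager.load_models(
    ["models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"],
    torch_dtype=torch.bfloat16, device="cuda"
)
model_manager.load_models(
    [
        "models/HunyuanVideoI2V/text_encoder/model.safetensors",
        "models/HunyuanVideoI2V/text_encoder_2",
        "models/HunyuanVideoI2V/vae/pytorch_model.pt",
    ],
    torch_dtype=torch.float16, device="cuda"
)
pipe = HunyuanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")

image = Image.open("input.jpg").convert("RGB")  # placeholder reference image
video = pipe("your prompt here", input_images=[image], num_inference_steps=50, seed=0, i2v_resolution="720p")
save_video(video, "video.mp4", fps=30, quality=6)
```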
@@ -135,6 +135,7 @@ huggingface_model_loader_configs = [
    ("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
    ("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
    ("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
    ("LlavaForConditionalGeneration", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoMLLMEncoder"),
    ("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
]
patch_model_loader_configs = [
@@ -677,6 +678,25 @@ preset_models_on_modelscope = {
            "models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
        ],
    },
    "HunyuanVideoI2V":{
        "file_list": [
            ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideoI2V/text_encoder"),
            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00001-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00002-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00003-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00004-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "config.json", "models/HunyuanVideoI2V/text_encoder_2"),
            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model.safetensors.index.json", "models/HunyuanVideoI2V/text_encoder_2"),
            ("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/vae/pytorch_model.pt", "models/HunyuanVideoI2V/vae"),
            ("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideoI2V/transformers")
        ],
        "load_path": [
            "models/HunyuanVideoI2V/text_encoder/model.safetensors",
            "models/HunyuanVideoI2V/text_encoder_2",
            "models/HunyuanVideoI2V/vae/pytorch_model.pt",
            "models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
        ],
    },
    "HunyuanVideo-fp8":{
        "file_list": [
            ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
@@ -753,4 +773,5 @@ Preset_model_id: TypeAlias = Literal[
    "StableDiffusion3.5-medium",
    "HunyuanVideo",
    "HunyuanVideo-fp8",
    "HunyuanVideoI2V",
]
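
The new `HunyuanVideoI2V` preset is consumed like the existing ones: passing the preset ID to `download_models` fetches every entry in `file_list` from ModelScope into `models/HunyuanVideoI2V/`, and the paths in `load_path` are the ones the example scripts later hand to `ModelManager.load_models`. A one-line sketch:

```python
from diffsynth import download_models

download_models(["HunyuanVideoI2V"])  # pulls the CLIP text encoder, the llava-llama-3-8b shards, the VAE and the DiT checkpoint
```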
@@ -4,6 +4,7 @@ from .utils import init_weights_on_device
from einops import rearrange, repeat
from tqdm import tqdm
from typing import Union, Tuple, List
from .utils import hash_state_dict_keys


def HunyuanVideoRope(latents):
@@ -236,7 +237,7 @@ class IndividualTokenRefinerBlock(torch.nn.Module):
        x = x + self.mlp(self.norm2(x)) * gate_mlp.unsqueeze(1)
        return x


class SingleTokenRefiner(torch.nn.Module):
    def __init__(self, in_channels=4096, hidden_size=3072, depth=2):
@@ -269,7 +270,7 @@ class SingleTokenRefiner(torch.nn.Module):
            x = block(x, c, mask)
        return x


class ModulateDiT(torch.nn.Module):
    def __init__(self, hidden_size, factor=6):
@@ -279,9 +280,14 @@ class ModulateDiT(torch.nn.Module):
    def forward(self, x):
        return self.linear(self.act(x))


def modulate(x, shift=None, scale=None):
def modulate(x, shift=None, scale=None, tr_shift=None, tr_scale=None, tr_token=None):
    if tr_shift is not None:
        x_zero = x[:, :tr_token] * (1 + tr_scale.unsqueeze(1)) + tr_shift.unsqueeze(1)
        x_orig = x[:, tr_token:] * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        x = torch.concat((x_zero, x_orig), dim=1)
        return x
    if scale is None and shift is None:
        return x
    elif shift is None:
@@ -290,7 +296,7 @@ def modulate(x, shift=None, scale=None):
        return x + shift.unsqueeze(1)
    else:
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


def reshape_for_broadcast(
    freqs_cis,
@@ -343,7 +349,7 @@ def rotate_half(x):
        x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
    ) # [B, S, H, D//2]
    return torch.stack([-x_imag, x_real], dim=-1).flatten(3)


def apply_rotary_emb(
    xq: torch.Tensor,
@@ -385,6 +391,15 @@ def attention(q, k, v):
    return x


def apply_gate(x, gate, tr_gate=None, tr_token=None):
    if tr_gate is not None:
        x_zero = x[:, :tr_token] * tr_gate.unsqueeze(1)
        x_orig = x[:, tr_token:] * gate.unsqueeze(1)
        return torch.concat((x_zero, x_orig), dim=1)
    else:
        return x * gate.unsqueeze(1)
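
# Illustrative sketch (not part of the patch): the token-replace path in modulate() / apply_gate()
# above modulates and gates the first `tr_token` tokens of the sequence -- the patch tokens of the
# first latent frame, i.e. the reference image -- with statistics derived from timestep 0
# (`token_replace_vec`), while all remaining tokens use the ordinary timestep modulation.
# Shapes below are made up purely for the demonstration.
import torch
B, L, D, tr_token = 1, 10, 4, 3
x = torch.randn(B, L, D)
shift, scale = torch.randn(B, D), torch.randn(B, D)        # derived from the current timestep t
tr_shift, tr_scale = torch.randn(B, D), torch.randn(B, D)  # derived from timestep 0
x_zero = x[:, :tr_token] * (1 + tr_scale.unsqueeze(1)) + tr_shift.unsqueeze(1)
x_orig = x[:, tr_token:] * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
out = torch.concat((x_zero, x_orig), dim=1)                # equivalent to modulate(x, shift, scale, tr_shift, tr_scale, tr_token)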
class MMDoubleStreamBlockComponent(torch.nn.Module):
    def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
        super().__init__()
@@ -405,11 +420,17 @@ class MMDoubleStreamBlockComponent(torch.nn.Module):
            torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size)
        )

    def forward(self, hidden_states, conditioning, freqs_cis=None):
    def forward(self, hidden_states, conditioning, freqs_cis=None, token_replace_vec=None, tr_token=None):
        mod1_shift, mod1_scale, mod1_gate, mod2_shift, mod2_scale, mod2_gate = self.mod(conditioning).chunk(6, dim=-1)
        if token_replace_vec is not None:
            assert tr_token is not None
            tr_mod1_shift, tr_mod1_scale, tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = self.mod(token_replace_vec).chunk(6, dim=-1)
        else:
            tr_mod1_shift, tr_mod1_scale, tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = None, None, None, None, None, None

        norm_hidden_states = self.norm1(hidden_states)
        norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale)
        norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale,
                                      tr_shift=tr_mod1_shift, tr_scale=tr_mod1_scale, tr_token=tr_token)
        qkv = self.to_qkv(norm_hidden_states)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)

@@ -418,15 +439,19 @@ class MMDoubleStreamBlockComponent(torch.nn.Module):

        if freqs_cis is not None:
            q, k = apply_rotary_emb(q, k, freqs_cis, head_first=False)
        return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate), (tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate)

        return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate)

    def process_ff(self, hidden_states, attn_output, mod):
    def process_ff(self, hidden_states, attn_output, mod, mod_tr=None, tr_token=None):
        mod1_gate, mod2_shift, mod2_scale, mod2_gate = mod
        hidden_states = hidden_states + self.to_out(attn_output) * mod1_gate.unsqueeze(1)
        hidden_states = hidden_states + self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale)) * mod2_gate.unsqueeze(1)
        if mod_tr is not None:
            tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = mod_tr
        else:
            tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = None, None, None, None
        hidden_states = hidden_states + apply_gate(self.to_out(attn_output), mod1_gate, tr_mod1_gate, tr_token)
        x = self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale, tr_shift=tr_mod2_shift, tr_scale=tr_mod2_scale, tr_token=tr_token))
        hidden_states = hidden_states + apply_gate(x, mod2_gate, tr_mod2_gate, tr_token)
        return hidden_states


class MMDoubleStreamBlock(torch.nn.Module):
    def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
@@ -434,18 +459,18 @@ class MMDoubleStreamBlock(torch.nn.Module):
        self.component_a = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
        self.component_b = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)

    def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis):
        (q_a, k_a, v_a), mod_a = self.component_a(hidden_states_a, conditioning, freqs_cis)
        (q_b, k_b, v_b), mod_b = self.component_b(hidden_states_b, conditioning, freqs_cis=None)
    def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis, token_replace_vec=None, tr_token=None, split_token=71):
        (q_a, k_a, v_a), mod_a, mod_tr = self.component_a(hidden_states_a, conditioning, freqs_cis, token_replace_vec, tr_token)
        (q_b, k_b, v_b), mod_b, _ = self.component_b(hidden_states_b, conditioning, freqs_cis=None)

        q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous()
        k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous()
        v_a, v_b = torch.concat([v_a, v_b[:, :71]], dim=1), v_b[:, 71:].contiguous()
        q_a, q_b = torch.concat([q_a, q_b[:, :split_token]], dim=1), q_b[:, split_token:].contiguous()
        k_a, k_b = torch.concat([k_a, k_b[:, :split_token]], dim=1), k_b[:, split_token:].contiguous()
        v_a, v_b = torch.concat([v_a, v_b[:, :split_token]], dim=1), v_b[:, split_token:].contiguous()
        attn_output_a = attention(q_a, k_a, v_a)
        attn_output_b = attention(q_b, k_b, v_b)
        attn_output_a, attn_output_b = attn_output_a[:, :-71].contiguous(), torch.concat([attn_output_a[:, -71:], attn_output_b], dim=1)
        attn_output_a, attn_output_b = attn_output_a[:, :-split_token].contiguous(), torch.concat([attn_output_a[:, -split_token:], attn_output_b], dim=1)

        hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a)
        hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a, mod_tr, tr_token)
        hidden_states_b = self.component_b.process_ff(hidden_states_b, attn_output_b, mod_b)
        return hidden_states_a, hidden_states_b

@@ -488,7 +513,7 @@ class MMSingleStreamBlockOriginal(torch.nn.Module):

        output = self.linear2(torch.cat((attn_output, self.mlp_act(mlp)), 2))
        return x + output * mod_gate.unsqueeze(1)


class MMSingleStreamBlock(torch.nn.Module):
    def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
@@ -509,11 +534,17 @@ class MMSingleStreamBlock(torch.nn.Module):
            torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size, bias=False)
        )

    def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256):
    def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256, token_replace_vec=None, tr_token=None, split_token=71):
        mod_shift, mod_scale, mod_gate = self.mod(conditioning).chunk(3, dim=-1)
        if token_replace_vec is not None:
            assert tr_token is not None
            tr_mod_shift, tr_mod_scale, tr_mod_gate = self.mod(token_replace_vec).chunk(3, dim=-1)
        else:
            tr_mod_shift, tr_mod_scale, tr_mod_gate = None, None, None

        norm_hidden_states = self.norm(hidden_states)
        norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale)
        norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale,
                                      tr_shift=tr_mod_shift, tr_scale=tr_mod_scale, tr_token=tr_token)
        qkv = self.to_qkv(norm_hidden_states)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
@@ -525,16 +556,17 @@ class MMSingleStreamBlock(torch.nn.Module):
        k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
        q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False)

        q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous()
        k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous()
        v_a, v_b = v[:, :-185].contiguous(), v[:, -185:].contiguous()
        v_len = txt_len - split_token
        q_a, q_b = torch.concat([q_a, q_b[:, :split_token]], dim=1), q_b[:, split_token:].contiguous()
        k_a, k_b = torch.concat([k_a, k_b[:, :split_token]], dim=1), k_b[:, split_token:].contiguous()
        v_a, v_b = v[:, :-v_len].contiguous(), v[:, -v_len:].contiguous()

        attn_output_a = attention(q_a, k_a, v_a)
        attn_output_b = attention(q_b, k_b, v_b)
        attn_output = torch.concat([attn_output_a, attn_output_b], dim=1)

        hidden_states = hidden_states + self.to_out(attn_output) * mod_gate.unsqueeze(1)
        hidden_states = hidden_states + self.ff(norm_hidden_states) * mod_gate.unsqueeze(1)
        hidden_states = hidden_states + apply_gate(self.to_out(attn_output), mod_gate, tr_mod_gate, tr_token)
        hidden_states = hidden_states + apply_gate(self.ff(norm_hidden_states), mod_gate, tr_mod_gate, tr_token)
        return hidden_states
@@ -555,7 +587,7 @@ class FinalLayer(torch.nn.Module):


class HunyuanVideoDiT(torch.nn.Module):
    def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40):
    def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40, guidance_embed=True):
        super().__init__()
        self.img_in = PatchEmbed(in_channels=in_channels, embed_dim=hidden_size)
        self.txt_in = SingleTokenRefiner(in_channels=text_dim, hidden_size=hidden_size)
@@ -565,7 +597,7 @@ class HunyuanVideoDiT(torch.nn.Module):
            torch.nn.SiLU(),
            torch.nn.Linear(hidden_size, hidden_size)
        )
        self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
        self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu") if guidance_embed else None
        self.double_blocks = torch.nn.ModuleList([MMDoubleStreamBlock(hidden_size) for _ in range(num_double_blocks)])
        self.single_blocks = torch.nn.ModuleList([MMSingleStreamBlock(hidden_size) for _ in range(num_single_blocks)])
        self.final_layer = FinalLayer(hidden_size)
@@ -580,7 +612,7 @@ class HunyuanVideoDiT(torch.nn.Module):
    def unpatchify(self, x, T, H, W):
        x = rearrange(x, "B (T H W) (C pT pH pW) -> B C (T pT) (H pH) (W pW)", H=H, W=W, pT=1, pH=2, pW=2)
        return x

    def enable_block_wise_offload(self, warm_device="cuda", cold_device="cpu"):
        self.warm_device = warm_device
        self.cold_device = cold_device
@@ -610,10 +642,12 @@ class HunyuanVideoDiT(torch.nn.Module):
    ):
        B, C, T, H, W = x.shape

        vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb) + self.guidance_in(guidance * 1000, dtype=torch.float32)
        vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb)
        if self.guidance_in is not None:
            vec += self.guidance_in(guidance * 1000, dtype=torch.float32)
        img = self.img_in(x)
        txt = self.txt_in(prompt_emb, t, text_mask)

        for block in tqdm(self.double_blocks, desc="Double stream blocks"):
            img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))

@@ -625,7 +659,7 @@ class HunyuanVideoDiT(torch.nn.Module):
        img = self.final_layer(img, vec)
        img = self.unpatchify(img, T=T//1, H=H//2, W=W//2)
        return img

    def enable_auto_offload(self, dtype=torch.bfloat16, device="cuda"):
        def cast_to(weight, dtype=None, device=None, copy=False):
@@ -681,7 +715,7 @@ class HunyuanVideoDiT(torch.nn.Module):
                del x_, weight_, bias_
                torch.cuda.empty_cache()
                return y_

            def block_forward(self, x, **kwargs):
                # This feature can only reduce 2GB VRAM, so we disable it.
                y = torch.zeros(x.shape[:-1] + (self.out_features,), dtype=x.dtype, device=x.device)
@@ -689,19 +723,19 @@ class HunyuanVideoDiT(torch.nn.Module):
                    for j in range((self.out_features + self.block_size - 1) // self.block_size):
                        y[..., j * self.block_size: (j + 1) * self.block_size] += self.block_forward_(x, i, j, dtype=x.dtype, device=x.device)
                return y

            def forward(self, x, **kwargs):
                weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
                return torch.nn.functional.linear(x, weight, bias)


        class RMSNorm(torch.nn.Module):
            def __init__(self, module, dtype=torch.bfloat16, device="cuda"):
                super().__init__()
                self.module = module
                self.dtype = dtype
                self.device = device

            def forward(self, hidden_states, **kwargs):
                input_dtype = hidden_states.dtype
                variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
@@ -711,30 +745,30 @@ class HunyuanVideoDiT(torch.nn.Module):
                weight = cast_weight(self.module, hidden_states, dtype=torch.bfloat16, device="cuda")
                hidden_states = hidden_states * weight
                return hidden_states

        class Conv3d(torch.nn.Conv3d):
            def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
                super().__init__(*args, **kwargs)
                self.dtype = dtype
                self.device = device

            def forward(self, x):
                weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
                return torch.nn.functional.conv3d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)

        class LayerNorm(torch.nn.LayerNorm):
            def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
                super().__init__(*args, **kwargs)
                self.dtype = dtype
                self.device = device

            def forward(self, x):
                if self.weight is not None and self.bias is not None:
                    weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
                    return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps)
                else:
                    return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)

        def replace_layer(model, dtype=torch.bfloat16, device="cuda"):
            for name, module in model.named_children():
                if isinstance(module, torch.nn.Linear):
@@ -777,12 +811,12 @@ class HunyuanVideoDiT(torch.nn.Module):
        return HunyuanVideoDiTStateDictConverter()


class HunyuanVideoDiTStateDictConverter:
    def __init__(self):
        pass

    def from_civitai(self, state_dict):
        origin_hash_key = hash_state_dict_keys(state_dict, with_shape=True)
        if "module" in state_dict:
            state_dict = state_dict["module"]
        direct_dict = {
@@ -882,4 +916,5 @@ class HunyuanVideoDiTStateDictConverter:
                state_dict_[name_] = param
            else:
                pass

        return state_dict_
@@ -1,24 +1,18 @@
from transformers import LlamaModel, LlamaConfig, DynamicCache
from transformers import LlamaModel, LlamaConfig, DynamicCache, LlavaForConditionalGeneration
from copy import deepcopy
import torch


class HunyuanVideoLLMEncoder(LlamaModel):

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.auto_offload = False

    def enable_auto_offload(self, **kwargs):
        self.auto_offload = True

    def forward(
        self,
        input_ids,
        attention_mask,
        hidden_state_skip_layer=2
    ):
    def forward(self, input_ids, attention_mask, hidden_state_skip_layer=2):
        embed_tokens = deepcopy(self.embed_tokens).to(input_ids.device) if self.auto_offload else self.embed_tokens
        inputs_embeds = embed_tokens(input_ids)

@@ -53,3 +47,22 @@ class HunyuanVideoLLMEncoder(LlamaModel):
                break

        return hidden_states


class HunyuanVideoMLLMEncoder(LlavaForConditionalGeneration):

    def __init__(self, config):
        super().__init__(config)
        self.auto_offload = False

    def enable_auto_offload(self, **kwargs):
        self.auto_offload = True

    # TODO: implement the low VRAM inference for MLLM.
    def forward(self, input_ids, pixel_values, attention_mask, hidden_state_skip_layer=2):
        outputs = super().forward(input_ids=input_ids,
                                  attention_mask=attention_mask,
                                  output_hidden_states=True,
                                  pixel_values=pixel_values)
        hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
        return hidden_state
@@ -5,13 +5,13 @@ from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline
from ..prompters import HunyuanVideoPrompter
import torch
import torchvision.transforms as transforms
from einops import rearrange
import numpy as np
from PIL import Image
from tqdm import tqdm


class HunyuanVideoPipeline(BasePipeline):

    def __init__(self, device="cuda", torch_dtype=torch.float16):
@@ -53,10 +53,58 @@ class HunyuanVideoPipeline(BasePipeline):
        pipe.enable_vram_management()
        return pipe

    def generate_crop_size_list(self, base_size=256, patch_size=32, max_ratio=4.0):
        num_patches = round((base_size / patch_size)**2)
        assert max_ratio >= 1.0
        crop_size_list = []
        wp, hp = num_patches, 1
        while wp > 0:
            if max(wp, hp) / min(wp, hp) <= max_ratio:
                crop_size_list.append((wp * patch_size, hp * patch_size))
            if (hp + 1) * wp <= num_patches:
                hp += 1
            else:
                wp -= 1
        return crop_size_list

    def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256):

    def get_closest_ratio(self, height: float, width: float, ratios: list, buckets: list):
        aspect_ratio = float(height) / float(width)
        closest_ratio_id = np.abs(ratios - aspect_ratio).argmin()
        closest_ratio = min(ratios, key=lambda ratio: abs(float(ratio) - aspect_ratio))
        return buckets[closest_ratio_id], float(closest_ratio)

    def prepare_vae_images_inputs(self, semantic_images, i2v_resolution="720p"):
        if i2v_resolution == "720p":
            bucket_hw_base_size = 960
        elif i2v_resolution == "540p":
            bucket_hw_base_size = 720
        elif i2v_resolution == "360p":
            bucket_hw_base_size = 480
        else:
            raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]")
        origin_size = semantic_images[0].size

        crop_size_list = self.generate_crop_size_list(bucket_hw_base_size, 32)
        aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])
        closest_size, closest_ratio = self.get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)
        ref_image_transform = transforms.Compose([
            transforms.Resize(closest_size),
            transforms.CenterCrop(closest_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])

        semantic_image_pixel_values = [ref_image_transform(semantic_image) for semantic_image in semantic_images]
        semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(self.device)
        target_height, target_width = closest_size
        return semantic_image_pixel_values, target_height, target_width
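
# Illustrative sketch (not part of the patch): probing the resolution-bucket selection that
# prepare_vae_images_inputs() performs via generate_crop_size_list / get_closest_ratio. The helper
# is condensed here so the snippet runs standalone; the printed pair is the (height, width) bucket
# an input image would be resized and center-cropped to before VAE encoding.
import numpy as np

def generate_crop_size_list(base_size=960, patch_size=32, max_ratio=4.0):
    num_patches = round((base_size / patch_size) ** 2)
    crop_size_list, wp, hp = [], num_patches, 1
    while wp > 0:
        if max(wp, hp) / min(wp, hp) <= max_ratio:
            crop_size_list.append((wp * patch_size, hp * patch_size))
        if (hp + 1) * wp <= num_patches:
            hp += 1
        else:
            wp -= 1
    return crop_size_list

buckets = generate_crop_size_list(960, 32)                 # the "720p" setting (bucket_hw_base_size = 960)
ratios = np.array([round(h / w, 5) for h, w in buckets])
img_w, img_h = 1280, 720                                   # an arbitrary example input size
closest = buckets[int(np.abs(ratios - img_h / img_w).argmin())]
print(closest)                                             # the (H, W) bucket chosen for this image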

    def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256, input_images=None):
        prompt_emb, pooled_prompt_emb, text_mask = self.prompter.encode_prompt(
            prompt, device=self.device, positive=positive, clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length
            prompt, device=self.device, positive=positive, clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length, images=input_images
        )
        return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_mask": text_mask}

@@ -87,6 +135,9 @@ class HunyuanVideoPipeline(BasePipeline):
        prompt,
        negative_prompt="",
        input_video=None,
        input_images=None,
        i2v_resolution="720p",
        i2v_stability=True,
        denoising_strength=1.0,
        seed=None,
        rand_device=None,
@@ -105,10 +156,17 @@ class HunyuanVideoPipeline(BasePipeline):
    ):
        # Tiler parameters
        tiler_kwargs = {"tile_size": tile_size, "tile_stride": tile_stride}

        # Scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

        # Encode input images
        if input_images is not None:
            self.load_models_to_device(['vae_encoder'])
            image_pixel_values, height, width = self.prepare_vae_images_inputs(input_images, i2v_resolution=i2v_resolution)
            with torch.autocast(device_type=self.device, dtype=torch.float16, enabled=True):
                image_latents = self.vae_encoder(image_pixel_values)

        # Initialize noise
        rand_device = self.device if rand_device is None else rand_device
        noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=self.torch_dtype).to(self.device)
@@ -118,12 +176,18 @@ class HunyuanVideoPipeline(BasePipeline):
            input_video = torch.stack(input_video, dim=2)
            latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
        elif input_images is not None and i2v_stability:
            noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=image_latents.dtype).to(self.device)
            t = torch.tensor([0.999]).to(device=self.device)
            latents = noise * t + image_latents.repeat(1, 1, (num_frames - 1) // 4 + 1, 1, 1) * (1 - t)
            latents = latents.to(dtype=image_latents.dtype)
        else:
            latents = noise

        # Encode prompts
        self.load_models_to_device(["text_encoder_1"] if self.vram_management else ["text_encoder_1", "text_encoder_2"])
        prompt_emb_posi = self.encode_prompt(prompt, positive=True)
        # The current MLLM does not support vram_management.
        self.load_models_to_device(["text_encoder_1"] if self.vram_management and input_images is None else ["text_encoder_1", "text_encoder_2"])
        prompt_emb_posi = self.encode_prompt(prompt, positive=True, input_images=input_images)
        if cfg_scale != 1.0:
            prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)

@@ -139,11 +203,16 @@ class HunyuanVideoPipeline(BasePipeline):
            timestep = timestep.unsqueeze(0).to(self.device)
            print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")

            forward_func = lets_dance_hunyuan_video
            if input_images is not None:
                latents = torch.concat([image_latents, latents[:, :, 1:, :, :]], dim=2)
                forward_func = lets_dance_hunyuan_video_i2v

            # Inference
            with torch.autocast(device_type=self.device, dtype=self.torch_dtype):
                noise_pred_posi = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_posi, **extra_input, **tea_cache_kwargs)
                noise_pred_posi = forward_func(self.dit, latents, timestep, **prompt_emb_posi, **extra_input, **tea_cache_kwargs)
                if cfg_scale != 1.0:
                    noise_pred_nega = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_nega, **extra_input)
                    noise_pred_nega = forward_func(self.dit, latents, timestep, **prompt_emb_nega, **extra_input)
                    noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
                else:
                    noise_pred = noise_pred_posi
@@ -163,7 +232,11 @@ class HunyuanVideoPipeline(BasePipeline):
            self.load_models_to_device([] if self.vram_management else ["dit"])

            # Scheduler
            latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
            if input_images is not None:
                latents = self.scheduler.step(noise_pred[:, :, 1:, :, :], self.scheduler.timesteps[progress_id], latents[:, :, 1:, :, :])
                latents = torch.concat([image_latents, latents], dim=2)
            else:
                latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)

        # Decode
        self.load_models_to_device(['vae_decoder'])
@@ -194,7 +267,7 @@ class TeaCache:
        if self.step == 0 or self.step == self.num_inference_steps - 1:
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
        else:
            coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
            rescale_func = np.poly1d(coefficients)
            self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
@@ -203,14 +276,14 @@ class TeaCache:
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
                self.previous_modulated_input = modulated_inp
        self.previous_modulated_input = modulated_inp
        self.step += 1
        if self.step == self.num_inference_steps:
            self.step = 0
        if should_calc:
            self.previous_hidden_states = img.clone()
        return not should_calc

    def store(self, hidden_states):
        self.previous_residual = hidden_states - self.previous_hidden_states
        self.previous_hidden_states = None
@@ -250,13 +323,70 @@ def lets_dance_hunyuan_video(
        print("TeaCache skip forward.")
        img = tea_cache.update(img)
    else:
        split_token = int(text_mask.sum(dim=1))
        txt_len = int(txt.shape[1])
        for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
            img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))
            img, txt = block(img, txt, vec, (freqs_cos, freqs_sin), split_token=split_token)

        x = torch.concat([img, txt], dim=1)
        for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
            x = block(x, vec, (freqs_cos, freqs_sin))
        img = x[:, :-256]
            x = block(x, vec, (freqs_cos, freqs_sin), txt_len=txt_len, split_token=split_token)
        img = x[:, :-txt_len]

    if tea_cache is not None:
        tea_cache.store(img)
    img = dit.final_layer(img, vec)
    img = dit.unpatchify(img, T=T//1, H=H//2, W=W//2)
    return img


def lets_dance_hunyuan_video_i2v(
    dit: HunyuanVideoDiT,
    x: torch.Tensor,
    t: torch.Tensor,
    prompt_emb: torch.Tensor = None,
    text_mask: torch.Tensor = None,
    pooled_prompt_emb: torch.Tensor = None,
    freqs_cos: torch.Tensor = None,
    freqs_sin: torch.Tensor = None,
    guidance: torch.Tensor = None,
    tea_cache: TeaCache = None,
    **kwargs
):
    B, C, T, H, W = x.shape
    # Uncomment the line below to keep the result identical to the official implementation.
    # guidance = guidance.to(dtype=torch.float32).to(torch.bfloat16)
    vec = dit.time_in(t, dtype=torch.bfloat16)
    vec_2 = dit.vector_in(pooled_prompt_emb)
    vec = vec + vec_2
    vec = vec + dit.guidance_in(guidance * 1000., dtype=torch.bfloat16)

    token_replace_vec = dit.time_in(torch.zeros_like(t), dtype=torch.bfloat16)
    tr_token = (H // 2) * (W // 2)
    token_replace_vec = token_replace_vec + vec_2

    img = dit.img_in(x)
    txt = dit.txt_in(prompt_emb, t, text_mask)

    # TeaCache
    if tea_cache is not None:
        tea_cache_update = tea_cache.check(dit, img, vec)
    else:
        tea_cache_update = False

    if tea_cache_update:
        print("TeaCache skip forward.")
        img = tea_cache.update(img)
    else:
        split_token = int(text_mask.sum(dim=1))
        txt_len = int(txt.shape[1])
        for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
            img, txt = block(img, txt, vec, (freqs_cos, freqs_sin), token_replace_vec, tr_token, split_token)

        x = torch.concat([img, txt], dim=1)
        for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
            x = block(x, vec, (freqs_cos, freqs_sin), txt_len, token_replace_vec, tr_token, split_token)
        img = x[:, :-txt_len]

    if tea_cache is not None:
        tea_cache.store(img)
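
# Worked example (not part of the patch) of the token bookkeeping in lets_dance_hunyuan_video_i2v
# above: with a hypothetical latent of shape (1, 16, T, 88, 160) -- e.g. a 704x1280 video after the
# 8x VAE downsample -- the 2x2 patch embedding yields (88 // 2) * (160 // 2) tokens per latent
# frame, and exactly that many leading tokens (the reference-image frame) receive the t=0
# token-replace modulation.
H, W = 88, 160
tr_token = (H // 2) * (W // 2)
print(tr_token)  # 3520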
@@ -1,8 +1,9 @@
from .base_prompter import BasePrompter
from ..models.sd3_text_encoder import SD3TextEncoder1
from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder
from transformers import CLIPTokenizer, LlamaTokenizerFast
from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder, HunyuanVideoMLLMEncoder
from transformers import CLIPTokenizer, LlamaTokenizerFast, CLIPImageProcessor
import os, torch
from typing import Union

PROMPT_TEMPLATE_ENCODE = (
    "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
@@ -18,6 +19,24 @@ PROMPT_TEMPLATE_ENCODE_VIDEO = (
    "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>")

PROMPT_TEMPLATE_ENCODE_I2V = (
    "<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the image by detailing the color, shape, size, texture, "
    "quantity, text, spatial relationships of the objects and background:<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)

PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
    "<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
    "1. The main content and theme of the video."
    "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
    "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
    "4. background environment, light, style and atmosphere."
    "5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
    "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)

PROMPT_TEMPLATE = {
    "dit-llm-encode": {
        "template": PROMPT_TEMPLATE_ENCODE,
@@ -27,6 +46,22 @@ PROMPT_TEMPLATE = {
        "template": PROMPT_TEMPLATE_ENCODE_VIDEO,
        "crop_start": 95,
    },
    "dit-llm-encode-i2v": {
        "template": PROMPT_TEMPLATE_ENCODE_I2V,
        "crop_start": 36,
        "image_emb_start": 5,
        "image_emb_end": 581,
        "image_emb_len": 576,
        "double_return_token_id": 271
    },
    "dit-llm-encode-video-i2v": {
        "template": PROMPT_TEMPLATE_ENCODE_VIDEO_I2V,
        "crop_start": 103,
        "image_emb_start": 5,
        "image_emb_end": 581,
        "image_emb_len": 576,
        "double_return_token_id": 271
    },
}

NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
@@ -56,9 +91,20 @@ class HunyuanVideoPrompter(BasePrompter):
        self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode']
        self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video']

    def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: HunyuanVideoLLMEncoder = None):
    def fetch_models(self,
                     text_encoder_1: SD3TextEncoder1 = None,
                     text_encoder_2: Union[HunyuanVideoLLMEncoder, HunyuanVideoMLLMEncoder] = None):
        self.text_encoder_1 = text_encoder_1
        self.text_encoder_2 = text_encoder_2
        if isinstance(text_encoder_2, HunyuanVideoMLLMEncoder):
            # processor
            # TODO: may need to replace processor with local implementation
            base_path = os.path.dirname(os.path.dirname(__file__))
            tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/hunyuan_video/tokenizer_2")
            self.processor = CLIPImageProcessor.from_pretrained(tokenizer_2_path)
            # template
            self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode-i2v']
            self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video-i2v']

    def apply_text_to_template(self, text, template):
        assert isinstance(template, str)
@@ -107,8 +153,89 @@ class HunyuanVideoPrompter(BasePrompter):

        return last_hidden_state, attention_mask

    def encode_prompt_using_mllm(self,
                                 prompt,
                                 images,
                                 max_length,
                                 device,
                                 crop_start,
                                 hidden_state_skip_layer=2,
                                 use_attention_mask=True,
                                 image_embed_interleave=4):
        image_outputs = self.processor(images, return_tensors="pt")["pixel_values"].to(device)
        max_length += crop_start
        inputs = self.tokenizer_2(prompt,
                                  return_tensors="pt",
                                  padding="max_length",
                                  max_length=max_length,
                                  truncation=True)
        input_ids = inputs.input_ids.to(device)
        attention_mask = inputs.attention_mask.to(device)
        last_hidden_state = self.text_encoder_2(input_ids=input_ids,
                                                attention_mask=attention_mask,
                                                hidden_state_skip_layer=hidden_state_skip_layer,
                                                pixel_values=image_outputs)

        text_crop_start = (crop_start - 1 + self.prompt_template_video.get("image_emb_len", 576))
        image_crop_start = self.prompt_template_video.get("image_emb_start", 5)
        image_crop_end = self.prompt_template_video.get("image_emb_end", 581)
        batch_indices, last_double_return_token_indices = torch.where(
            input_ids == self.prompt_template_video.get("double_return_token_id", 271))
        if last_double_return_token_indices.shape[0] == 3:
            # in case the prompt is too long
            last_double_return_token_indices = torch.cat((
                last_double_return_token_indices,
                torch.tensor([input_ids.shape[-1]]),
            ))
            batch_indices = torch.cat((batch_indices, torch.tensor([0])))
        last_double_return_token_indices = (last_double_return_token_indices.reshape(input_ids.shape[0], -1)[:, -1])
        batch_indices = batch_indices.reshape(input_ids.shape[0], -1)[:, -1]
        assistant_crop_start = (last_double_return_token_indices - 1 + self.prompt_template_video.get("image_emb_len", 576) - 4)
        assistant_crop_end = (last_double_return_token_indices - 1 + self.prompt_template_video.get("image_emb_len", 576))
        attention_mask_assistant_crop_start = (last_double_return_token_indices - 4)
        attention_mask_assistant_crop_end = last_double_return_token_indices
        text_last_hidden_state = []
        text_attention_mask = []
        image_last_hidden_state = []
        image_attention_mask = []
        for i in range(input_ids.shape[0]):
            text_last_hidden_state.append(
                torch.cat([
                    last_hidden_state[i, text_crop_start:assistant_crop_start[i].item()],
                    last_hidden_state[i, assistant_crop_end[i].item():],
                ]))
            text_attention_mask.append(
                torch.cat([
                    attention_mask[
                        i,
                        crop_start:attention_mask_assistant_crop_start[i].item(),
                    ],
                    attention_mask[i, attention_mask_assistant_crop_end[i].item():],
                ]) if use_attention_mask else None)
            image_last_hidden_state.append(last_hidden_state[i, image_crop_start:image_crop_end])
            image_attention_mask.append(
                torch.ones(image_last_hidden_state[-1].shape[0]).to(last_hidden_state.device).
                to(attention_mask.dtype) if use_attention_mask else None)

        text_last_hidden_state = torch.stack(text_last_hidden_state)
        text_attention_mask = torch.stack(text_attention_mask)
        image_last_hidden_state = torch.stack(image_last_hidden_state)
        image_attention_mask = torch.stack(image_attention_mask)

        image_last_hidden_state = image_last_hidden_state[:, ::image_embed_interleave, :]
        image_attention_mask = image_attention_mask[:, ::image_embed_interleave]

        assert (text_last_hidden_state.shape[0] == text_attention_mask.shape[0] and
                image_last_hidden_state.shape[0] == image_attention_mask.shape[0])

        last_hidden_state = torch.cat([image_last_hidden_state, text_last_hidden_state], dim=1)
        attention_mask = torch.cat([image_attention_mask, text_attention_mask], dim=1)

        return last_hidden_state, attention_mask

    def encode_prompt(self,
                      prompt,
                      images=None,
                      positive=True,
                      device="cuda",
                      clip_sequence_length=77,
@@ -116,7 +243,8 @@ class HunyuanVideoPrompter(BasePrompter):
                      data_type='video',
                      use_template=True,
                      hidden_state_skip_layer=2,
                      use_attention_mask=True):
                      use_attention_mask=True,
                      image_embed_interleave=4):

        prompt = self.process_prompt(prompt, positive=positive)

@@ -136,8 +264,12 @@ class HunyuanVideoPrompter(BasePrompter):
        pooled_prompt_emb = self.encode_prompt_using_clip(prompt, clip_sequence_length, device)

        # LLM
        prompt_emb, attention_mask = self.encode_prompt_using_llm(
            prompt_formated, llm_sequence_length, device, crop_start,
            hidden_state_skip_layer, use_attention_mask)
        if images is None:
            prompt_emb, attention_mask = self.encode_prompt_using_llm(prompt_formated, llm_sequence_length, device, crop_start,
                                                                      hidden_state_skip_layer, use_attention_mask)
        else:
            prompt_emb, attention_mask = self.encode_prompt_using_mllm(prompt_formated, images, llm_sequence_length, device,
                                                                       crop_start, hidden_state_skip_layer, use_attention_mask,
                                                                       image_embed_interleave)

        return prompt_emb, pooled_prompt_emb, attention_mask
@@ -0,0 +1,45 @@
{
  "_valid_processor_keys": [
    "images",
    "do_resize",
    "size",
    "resample",
    "do_center_crop",
    "crop_size",
    "do_rescale",
    "rescale_factor",
    "do_normalize",
    "image_mean",
    "image_std",
    "do_convert_rgb",
    "return_tensors",
    "data_format",
    "input_data_format"
  ],
  "crop_size": {
    "height": 336,
    "width": 336
  },
  "do_center_crop": true,
  "do_convert_rgb": true,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.48145466,
    0.4578275,
    0.40821073
  ],
  "image_processor_type": "CLIPImageProcessor",
  "image_std": [
    0.26862954,
    0.26130258,
    0.27577711
  ],
  "processor_class": "LlavaProcessor",
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "shortest_edge": 336
  }
}
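
The JSON above (its path is not shown in the hunk, but judging from the `fetch_models` change it is presumably the `preprocessor_config.json` placed under `tokenizer_configs/hunyuan_video/tokenizer_2`) is a standard CLIP image-processor configuration: reference images are resized to a 336-pixel shortest edge, center-cropped to 336x336, rescaled by 1/255 and normalized with the CLIP mean/std before being passed to the Llava encoder. A usage sketch, with the path assumed from the prompter diff:

```python
from PIL import Image
from transformers import CLIPImageProcessor

# Path assumed from HunyuanVideoPrompter.fetch_models above.
processor = CLIPImageProcessor.from_pretrained("diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2")
pixel_values = processor([Image.new("RGB", (512, 512))], return_tensors="pt")["pixel_values"]
print(pixel_values.shape)  # (1, 3, 336, 336)
```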
@@ -8,6 +8,12 @@
|24G|[hunyuanvideo_24G.py](hunyuanvideo_24G.py)|129|720*1280|The video is consistent with the original implementation, but it requires 5%~10% more time than [hunyuanvideo_80G.py](hunyuanvideo_80G.py)|
|6G|[hunyuanvideo_6G.py](hunyuanvideo_6G.py)|129|512*384|The base model doesn't support low resolutions. We recommend users to use some LoRA ([example](https://civitai.com/models/1032126/walking-animation-hunyuan-video)) trained using low resolutions.|

[HunyuanVideo-I2V](https://github.com/Tencent/HunyuanVideo-I2V) is the image-to-video generation version of HunyuanVideo. We also provide advanced VRAM management for this model.

|VRAM required|Example script|Frames|Resolution|Note|
|-|-|-|-|-|
|80G|[hunyuanvideo_i2v_80G.py](hunyuanvideo_i2v_80G.py)|129|720p|No VRAM management.|
|24G|[hunyuanvideo_i2v_24G.py](hunyuanvideo_i2v_24G.py)|129|720p|The video is consistent with the original implementation, but it requires 5%~10% more time than [hunyuanvideo_i2v_80G.py](hunyuanvideo_i2v_80G.py)|

## Gallery

Video generated by [hunyuanvideo_80G.py](hunyuanvideo_80G.py) and [hunyuanvideo_24G.py](hunyuanvideo_24G.py):
@@ -21,3 +27,7 @@ https://github.com/user-attachments/assets/2997f107-d02d-4ecb-89bb-5ce1a7f93817
Video to video generated by [hunyuanvideo_v2v_6G.py](./hunyuanvideo_v2v_6G.py) using [this LoRA](https://civitai.com/models/1032126/walking-animation-hunyuan-video):

https://github.com/user-attachments/assets/4b89e52e-ce42-434e-aa57-08f09dfa2b10

Video generated by [hunyuanvideo_i2v_80G.py](hunyuanvideo_i2v_80G.py) and [hunyuanvideo_i2v_24G.py](hunyuanvideo_i2v_24G.py):

https://github.com/user-attachments/assets/494f252a-c9af-440d-84ba-a8ddcdcc538a
examples/HunyuanVideo/hunyuanvideo_i2v_24G.py (new file, 43 lines)
@@ -0,0 +1,43 @@
import torch
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
from modelscope import dataset_snapshot_download
from PIL import Image


download_models(["HunyuanVideoI2V"])
model_manager = ModelManager()

# The DiT model is loaded in bfloat16.
model_manager.load_models(
    [
        "models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
    ],
    torch_dtype=torch.bfloat16,
    device="cpu"
)

# The other modules are loaded in float16.
model_manager.load_models(
    [
        "models/HunyuanVideoI2V/text_encoder/model.safetensors",
        "models/HunyuanVideoI2V/text_encoder_2",
        'models/HunyuanVideoI2V/vae/pytorch_model.pt'
    ],
    torch_dtype=torch.float16,
    device="cpu"
)
# The computation device is "cuda".
pipe = HunyuanVideoPipeline.from_model_manager(model_manager,
                                               torch_dtype=torch.bfloat16,
                                               device="cuda",
                                               enable_vram_management=True)

dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth",
                          local_dir="./",
                          allow_file_pattern=f"data/examples/hunyuanvideo/*")

i2v_resolution = "720p"
prompt = "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick."
images = [Image.open("data/examples/hunyuanvideo/0.jpg").convert('RGB')]
video = pipe(prompt, input_images=images, num_inference_steps=50, seed=0, i2v_resolution=i2v_resolution)
save_video(video, f"video_{i2v_resolution}_low_vram.mp4", fps=30, quality=6)
examples/HunyuanVideo/hunyuanvideo_i2v_80G.py (new file, 45 lines)
@@ -0,0 +1,45 @@
import torch
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
from modelscope import dataset_snapshot_download
from PIL import Image


download_models(["HunyuanVideoI2V"])
model_manager = ModelManager()

# The DiT model is loaded in bfloat16.
model_manager.load_models(
    [
        "models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
    ],
    torch_dtype=torch.bfloat16,
    device="cuda"
)

# The other modules are loaded in float16.
model_manager.load_models(
    [
        "models/HunyuanVideoI2V/text_encoder/model.safetensors",
        "models/HunyuanVideoI2V/text_encoder_2",
        'models/HunyuanVideoI2V/vae/pytorch_model.pt'
    ],
    torch_dtype=torch.float16,
    device="cuda"
)
# The computation device is "cuda".
pipe = HunyuanVideoPipeline.from_model_manager(model_manager,
                                               torch_dtype=torch.bfloat16,
                                               device="cuda",
                                               enable_vram_management=False)
# Even if you have enough VRAM, we still recommend enabling offload.
pipe.enable_cpu_offload()

dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth",
                          local_dir="./",
                          allow_file_pattern=f"data/examples/hunyuanvideo/*")

i2v_resolution = "720p"
prompt = "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick."
images = [Image.open("data/examples/hunyuanvideo/0.jpg").convert('RGB')]
video = pipe(prompt, input_images=images, num_inference_steps=50, seed=0, i2v_resolution=i2v_resolution)
save_video(video, f"video_{i2v_resolution}.mp4", fps=30, quality=6)