Merge pull request #436 from mi804/hunyuanvideo_i2v

support hunyuanvideo-i2v
This commit is contained in:
Zhongjie Duan
2025-03-13 19:38:11 +08:00
committed by GitHub
10 changed files with 555 additions and 80 deletions

View File

@@ -19,7 +19,7 @@ Until now, DiffSynth Studio has supported the following models:
* [Wan-Video](https://github.com/Wan-Video/Wan2.1) * [Wan-Video](https://github.com/Wan-Video/Wan2.1)
* [StepVideo](https://github.com/stepfun-ai/Step-Video-T2V) * [StepVideo](https://github.com/stepfun-ai/Step-Video-T2V)
* [HunyuanVideo](https://github.com/Tencent/HunyuanVideo) * [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [HunyuanVideo-I2V](https://github.com/Tencent/HunyuanVideo-I2V)
* [CogVideoX](https://huggingface.co/THUDM/CogVideoX-5b) * [CogVideoX](https://huggingface.co/THUDM/CogVideoX-5b)
* [FLUX](https://huggingface.co/black-forest-labs/FLUX.1-dev) * [FLUX](https://huggingface.co/black-forest-labs/FLUX.1-dev)
* [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) * [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
@@ -36,6 +36,7 @@ Until now, DiffSynth Studio has supported the following models:
* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5) * [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
## News ## News
- **March 25, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of HunyuanVideo open-sourced by Tencent. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.
- **February 25, 2025** We support Wan-Video, a collection of SOTA video synthesis models open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/). - **February 25, 2025** We support Wan-Video, a collection of SOTA video synthesis models open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).

View File

@@ -135,6 +135,7 @@ huggingface_model_loader_configs = [
("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"), ("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"), ("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"), ("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
("LlavaForConditionalGeneration", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoMLLMEncoder"),
("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"), ("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
] ]
patch_model_loader_configs = [ patch_model_loader_configs = [
@@ -677,6 +678,25 @@ preset_models_on_modelscope = {
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt" "models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
], ],
}, },
"HunyuanVideoI2V":{
"file_list": [
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideoI2V/text_encoder"),
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00001-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00002-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00003-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00004-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "config.json", "models/HunyuanVideoI2V/text_encoder_2"),
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model.safetensors.index.json", "models/HunyuanVideoI2V/text_encoder_2"),
("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/vae/pytorch_model.pt", "models/HunyuanVideoI2V/vae"),
("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideoI2V/transformers")
],
"load_path": [
"models/HunyuanVideoI2V/text_encoder/model.safetensors",
"models/HunyuanVideoI2V/text_encoder_2",
"models/HunyuanVideoI2V/vae/pytorch_model.pt",
"models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
],
},
"HunyuanVideo-fp8":{ "HunyuanVideo-fp8":{
"file_list": [ "file_list": [
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"), ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
@@ -753,4 +773,5 @@ Preset_model_id: TypeAlias = Literal[
"StableDiffusion3.5-medium", "StableDiffusion3.5-medium",
"HunyuanVideo", "HunyuanVideo",
"HunyuanVideo-fp8", "HunyuanVideo-fp8",
"HunyuanVideoI2V",
] ]

View File

@@ -4,6 +4,7 @@ from .utils import init_weights_on_device
from einops import rearrange, repeat from einops import rearrange, repeat
from tqdm import tqdm from tqdm import tqdm
from typing import Union, Tuple, List from typing import Union, Tuple, List
from .utils import hash_state_dict_keys
def HunyuanVideoRope(latents): def HunyuanVideoRope(latents):
@@ -281,7 +282,12 @@ class ModulateDiT(torch.nn.Module):
return self.linear(self.act(x)) return self.linear(self.act(x))
def modulate(x, shift=None, scale=None): def modulate(x, shift=None, scale=None, tr_shift=None, tr_scale=None, tr_token=None):
if tr_shift is not None:
x_zero = x[:, :tr_token] * (1 + tr_scale.unsqueeze(1)) + tr_shift.unsqueeze(1)
x_orig = x[:, tr_token:] * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
x = torch.concat((x_zero, x_orig), dim=1)
return x
if scale is None and shift is None: if scale is None and shift is None:
return x return x
elif shift is None: elif shift is None:
@@ -385,6 +391,15 @@ def attention(q, k, v):
return x return x
def apply_gate(x, gate, tr_gate=None, tr_token=None):
    """Scale ``x`` by a gate, optionally using a second gate for a leading slice.

    Both gates get a new axis inserted at position 1 (``unsqueeze(1)``) so they
    broadcast along dimension 1 of ``x``.

    Args:
        x: Tensor to be gated; it is sliced along dimension 1.
        gate: Gate applied to all of ``x``, or only to ``x[:, tr_token:]``
            when ``tr_gate`` is given.
        tr_gate: Optional gate applied to ``x[:, :tr_token]``.
        tr_token: Split index along dimension 1; only read when ``tr_gate``
            is not None.

    Returns:
        The gated tensor. When ``tr_gate`` is given, the two gated slices are
        concatenated back together along dimension 1.
    """
    if tr_gate is None:
        return x * gate.unsqueeze(1)

    # Leading slice uses the replacement gate, the remainder the regular one.
    leading = x[:, :tr_token] * tr_gate.unsqueeze(1)
    remainder = x[:, tr_token:] * gate.unsqueeze(1)
    return torch.concat((leading, remainder), dim=1)
class MMDoubleStreamBlockComponent(torch.nn.Module): class MMDoubleStreamBlockComponent(torch.nn.Module):
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4): def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
super().__init__() super().__init__()
@@ -405,11 +420,17 @@ class MMDoubleStreamBlockComponent(torch.nn.Module):
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size) torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size)
) )
def forward(self, hidden_states, conditioning, freqs_cis=None): def forward(self, hidden_states, conditioning, freqs_cis=None, token_replace_vec=None, tr_token=None):
mod1_shift, mod1_scale, mod1_gate, mod2_shift, mod2_scale, mod2_gate = self.mod(conditioning).chunk(6, dim=-1) mod1_shift, mod1_scale, mod1_gate, mod2_shift, mod2_scale, mod2_gate = self.mod(conditioning).chunk(6, dim=-1)
if token_replace_vec is not None:
assert tr_token is not None
tr_mod1_shift, tr_mod1_scale, tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = self.mod(token_replace_vec).chunk(6, dim=-1)
else:
tr_mod1_shift, tr_mod1_scale, tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = None, None, None, None, None, None
norm_hidden_states = self.norm1(hidden_states) norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale) norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale,
tr_shift=tr_mod1_shift, tr_scale=tr_mod1_scale, tr_token=tr_token)
qkv = self.to_qkv(norm_hidden_states) qkv = self.to_qkv(norm_hidden_states)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
@@ -418,13 +439,17 @@ class MMDoubleStreamBlockComponent(torch.nn.Module):
if freqs_cis is not None: if freqs_cis is not None:
q, k = apply_rotary_emb(q, k, freqs_cis, head_first=False) q, k = apply_rotary_emb(q, k, freqs_cis, head_first=False)
return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate), (tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate)
return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate) def process_ff(self, hidden_states, attn_output, mod, mod_tr=None, tr_token=None):
def process_ff(self, hidden_states, attn_output, mod):
mod1_gate, mod2_shift, mod2_scale, mod2_gate = mod mod1_gate, mod2_shift, mod2_scale, mod2_gate = mod
hidden_states = hidden_states + self.to_out(attn_output) * mod1_gate.unsqueeze(1) if mod_tr is not None:
hidden_states = hidden_states + self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale)) * mod2_gate.unsqueeze(1) tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = mod_tr
else:
tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = None, None, None, None
hidden_states = hidden_states + apply_gate(self.to_out(attn_output), mod1_gate, tr_mod1_gate, tr_token)
x = self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale, tr_shift=tr_mod2_shift, tr_scale=tr_mod2_scale, tr_token=tr_token))
hidden_states = hidden_states + apply_gate(x, mod2_gate, tr_mod2_gate, tr_token)
return hidden_states return hidden_states
@@ -434,18 +459,18 @@ class MMDoubleStreamBlock(torch.nn.Module):
self.component_a = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio) self.component_a = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
self.component_b = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio) self.component_b = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis): def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis, token_replace_vec=None, tr_token=None, split_token=71):
(q_a, k_a, v_a), mod_a = self.component_a(hidden_states_a, conditioning, freqs_cis) (q_a, k_a, v_a), mod_a, mod_tr = self.component_a(hidden_states_a, conditioning, freqs_cis, token_replace_vec, tr_token)
(q_b, k_b, v_b), mod_b = self.component_b(hidden_states_b, conditioning, freqs_cis=None) (q_b, k_b, v_b), mod_b, _ = self.component_b(hidden_states_b, conditioning, freqs_cis=None)
q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous() q_a, q_b = torch.concat([q_a, q_b[:, :split_token]], dim=1), q_b[:, split_token:].contiguous()
k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous() k_a, k_b = torch.concat([k_a, k_b[:, :split_token]], dim=1), k_b[:, split_token:].contiguous()
v_a, v_b = torch.concat([v_a, v_b[:, :71]], dim=1), v_b[:, 71:].contiguous() v_a, v_b = torch.concat([v_a, v_b[:, :split_token]], dim=1), v_b[:, split_token:].contiguous()
attn_output_a = attention(q_a, k_a, v_a) attn_output_a = attention(q_a, k_a, v_a)
attn_output_b = attention(q_b, k_b, v_b) attn_output_b = attention(q_b, k_b, v_b)
attn_output_a, attn_output_b = attn_output_a[:, :-71].contiguous(), torch.concat([attn_output_a[:, -71:], attn_output_b], dim=1) attn_output_a, attn_output_b = attn_output_a[:, :-split_token].contiguous(), torch.concat([attn_output_a[:, -split_token:], attn_output_b], dim=1)
hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a) hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a, mod_tr, tr_token)
hidden_states_b = self.component_b.process_ff(hidden_states_b, attn_output_b, mod_b) hidden_states_b = self.component_b.process_ff(hidden_states_b, attn_output_b, mod_b)
return hidden_states_a, hidden_states_b return hidden_states_a, hidden_states_b
@@ -509,11 +534,17 @@ class MMSingleStreamBlock(torch.nn.Module):
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size, bias=False) torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size, bias=False)
) )
def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256): def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256, token_replace_vec=None, tr_token=None, split_token=71):
mod_shift, mod_scale, mod_gate = self.mod(conditioning).chunk(3, dim=-1) mod_shift, mod_scale, mod_gate = self.mod(conditioning).chunk(3, dim=-1)
if token_replace_vec is not None:
assert tr_token is not None
tr_mod_shift, tr_mod_scale, tr_mod_gate = self.mod(token_replace_vec).chunk(3, dim=-1)
else:
tr_mod_shift, tr_mod_scale, tr_mod_gate = None, None, None
norm_hidden_states = self.norm(hidden_states) norm_hidden_states = self.norm(hidden_states)
norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale) norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale,
tr_shift=tr_mod_shift, tr_scale=tr_mod_scale, tr_token=tr_token)
qkv = self.to_qkv(norm_hidden_states) qkv = self.to_qkv(norm_hidden_states)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
@@ -525,16 +556,17 @@ class MMSingleStreamBlock(torch.nn.Module):
k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :] k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False) q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False)
q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous() v_len = txt_len - split_token
k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous() q_a, q_b = torch.concat([q_a, q_b[:, :split_token]], dim=1), q_b[:, split_token:].contiguous()
v_a, v_b = v[:, :-185].contiguous(), v[:, -185:].contiguous() k_a, k_b = torch.concat([k_a, k_b[:, :split_token]], dim=1), k_b[:, split_token:].contiguous()
v_a, v_b = v[:, :-v_len].contiguous(), v[:, -v_len:].contiguous()
attn_output_a = attention(q_a, k_a, v_a) attn_output_a = attention(q_a, k_a, v_a)
attn_output_b = attention(q_b, k_b, v_b) attn_output_b = attention(q_b, k_b, v_b)
attn_output = torch.concat([attn_output_a, attn_output_b], dim=1) attn_output = torch.concat([attn_output_a, attn_output_b], dim=1)
hidden_states = hidden_states + self.to_out(attn_output) * mod_gate.unsqueeze(1) hidden_states = hidden_states + apply_gate(self.to_out(attn_output), mod_gate, tr_mod_gate, tr_token)
hidden_states = hidden_states + self.ff(norm_hidden_states) * mod_gate.unsqueeze(1) hidden_states = hidden_states + apply_gate(self.ff(norm_hidden_states), mod_gate, tr_mod_gate, tr_token)
return hidden_states return hidden_states
@@ -555,7 +587,7 @@ class FinalLayer(torch.nn.Module):
class HunyuanVideoDiT(torch.nn.Module): class HunyuanVideoDiT(torch.nn.Module):
def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40): def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40, guidance_embed=True):
super().__init__() super().__init__()
self.img_in = PatchEmbed(in_channels=in_channels, embed_dim=hidden_size) self.img_in = PatchEmbed(in_channels=in_channels, embed_dim=hidden_size)
self.txt_in = SingleTokenRefiner(in_channels=text_dim, hidden_size=hidden_size) self.txt_in = SingleTokenRefiner(in_channels=text_dim, hidden_size=hidden_size)
@@ -565,7 +597,7 @@ class HunyuanVideoDiT(torch.nn.Module):
torch.nn.SiLU(), torch.nn.SiLU(),
torch.nn.Linear(hidden_size, hidden_size) torch.nn.Linear(hidden_size, hidden_size)
) )
self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu") self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu") if guidance_embed else None
self.double_blocks = torch.nn.ModuleList([MMDoubleStreamBlock(hidden_size) for _ in range(num_double_blocks)]) self.double_blocks = torch.nn.ModuleList([MMDoubleStreamBlock(hidden_size) for _ in range(num_double_blocks)])
self.single_blocks = torch.nn.ModuleList([MMSingleStreamBlock(hidden_size) for _ in range(num_single_blocks)]) self.single_blocks = torch.nn.ModuleList([MMSingleStreamBlock(hidden_size) for _ in range(num_single_blocks)])
self.final_layer = FinalLayer(hidden_size) self.final_layer = FinalLayer(hidden_size)
@@ -610,7 +642,9 @@ class HunyuanVideoDiT(torch.nn.Module):
): ):
B, C, T, H, W = x.shape B, C, T, H, W = x.shape
vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb) + self.guidance_in(guidance * 1000, dtype=torch.float32) vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb)
if self.guidance_in is not None:
vec += self.guidance_in(guidance * 1000, dtype=torch.float32)
img = self.img_in(x) img = self.img_in(x)
txt = self.txt_in(prompt_emb, t, text_mask) txt = self.txt_in(prompt_emb, t, text_mask)
@@ -777,12 +811,12 @@ class HunyuanVideoDiT(torch.nn.Module):
return HunyuanVideoDiTStateDictConverter() return HunyuanVideoDiTStateDictConverter()
class HunyuanVideoDiTStateDictConverter: class HunyuanVideoDiTStateDictConverter:
def __init__(self): def __init__(self):
pass pass
def from_civitai(self, state_dict): def from_civitai(self, state_dict):
origin_hash_key = hash_state_dict_keys(state_dict, with_shape=True)
if "module" in state_dict: if "module" in state_dict:
state_dict = state_dict["module"] state_dict = state_dict["module"]
direct_dict = { direct_dict = {
@@ -882,4 +916,5 @@ class HunyuanVideoDiTStateDictConverter:
state_dict_[name_] = param state_dict_[name_] = param
else: else:
pass pass
return state_dict_ return state_dict_

View File

@@ -1,24 +1,18 @@
from transformers import LlamaModel, LlamaConfig, DynamicCache from transformers import LlamaModel, LlamaConfig, DynamicCache, LlavaForConditionalGeneration
from copy import deepcopy from copy import deepcopy
import torch import torch
class HunyuanVideoLLMEncoder(LlamaModel): class HunyuanVideoLLMEncoder(LlamaModel):
def __init__(self, config: LlamaConfig): def __init__(self, config: LlamaConfig):
super().__init__(config) super().__init__(config)
self.auto_offload = False self.auto_offload = False
def enable_auto_offload(self, **kwargs): def enable_auto_offload(self, **kwargs):
self.auto_offload = True self.auto_offload = True
def forward(self, input_ids, attention_mask, hidden_state_skip_layer=2):
def forward(
self,
input_ids,
attention_mask,
hidden_state_skip_layer=2
):
embed_tokens = deepcopy(self.embed_tokens).to(input_ids.device) if self.auto_offload else self.embed_tokens embed_tokens = deepcopy(self.embed_tokens).to(input_ids.device) if self.auto_offload else self.embed_tokens
inputs_embeds = embed_tokens(input_ids) inputs_embeds = embed_tokens(input_ids)
@@ -53,3 +47,22 @@ class HunyuanVideoLLMEncoder(LlamaModel):
break break
return hidden_states return hidden_states
class HunyuanVideoMLLMEncoder(LlavaForConditionalGeneration):
    """Llava-based encoder that returns an intermediate hidden state.

    ``forward`` runs the parent model with ``output_hidden_states=True`` and
    returns one entry of ``hidden_states``, counted back from the last one.
    """

    def __init__(self, config):
        super().__init__(config)
        # Flag only; `forward` below does not read it yet (see the TODO).
        self.auto_offload = False

    def enable_auto_offload(self, **kwargs):
        """Set the ``auto_offload`` flag; any keyword arguments are ignored."""
        self.auto_offload = True

    # TODO: implement the low VRAM inference for MLLM.
    def forward(self, input_ids, pixel_values, attention_mask, hidden_state_skip_layer=2):
        """Return the hidden state ``hidden_state_skip_layer`` entries before the last.

        Args:
            input_ids: Passed through to the parent ``forward`` as ``input_ids``.
            pixel_values: Passed through to the parent ``forward`` as ``pixel_values``.
            attention_mask: Passed through to the parent ``forward`` as ``attention_mask``.
            hidden_state_skip_layer: Number of trailing ``hidden_states`` entries
                to skip; 0 selects the last entry.
        """
        model_outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            pixel_values=pixel_values,
        )
        selected_index = -(hidden_state_skip_layer + 1)
        return model_outputs.hidden_states[selected_index]

View File

@@ -5,13 +5,13 @@ from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline from .base import BasePipeline
from ..prompters import HunyuanVideoPrompter from ..prompters import HunyuanVideoPrompter
import torch import torch
import torchvision.transforms as transforms
from einops import rearrange from einops import rearrange
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from tqdm import tqdm from tqdm import tqdm
class HunyuanVideoPipeline(BasePipeline): class HunyuanVideoPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16): def __init__(self, device="cuda", torch_dtype=torch.float16):
@@ -53,10 +53,58 @@ class HunyuanVideoPipeline(BasePipeline):
pipe.enable_vram_management() pipe.enable_vram_management()
return pipe return pipe
def generate_crop_size_list(self, base_size=256, patch_size=32, max_ratio=4.0):
    """Enumerate crop sizes whose patch count stays within a fixed budget.

    The budget is ``round((base_size / patch_size) ** 2)`` patches. Starting
    from ``budget x 1`` patches, the walk grows the second dimension while the
    product still fits the budget and otherwise shrinks the first one. Every
    visited pair whose long/short ratio is at most ``max_ratio`` is recorded.

    Args:
        base_size: Side length used to derive the patch budget.
        patch_size: Size of one patch; each recorded size is a multiple of it.
        max_ratio: Largest allowed long-side / short-side ratio (must be >= 1).

    Returns:
        A list of ``(first_patches * patch_size, second_patches * patch_size)``
        tuples in visiting order.
    """
    num_patches = round((base_size / patch_size) ** 2)
    assert max_ratio >= 1.0

    sizes = []
    first, second = num_patches, 1
    while first > 0:
        longer = max(first, second)
        shorter = min(first, second)
        if longer / shorter <= max_ratio:
            sizes.append((first * patch_size, second * patch_size))
        # Grow the second dimension while the budget allows it; otherwise
        # shrink the first one.
        if (second + 1) * first <= num_patches:
            second += 1
        else:
            first -= 1
    return sizes
def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256):
def get_closest_ratio(self, height: float, width: float, ratios: list, buckets: list):
    """Pick the bucket whose aspect ratio is nearest to ``height / width``.

    Args:
        height: Source height.
        width: Source width (must be non-zero).
        ratios: Candidate aspect ratios, as a plain sequence or a NumPy array.
        buckets: Candidates aligned index-by-index with ``ratios``.

    Returns:
        ``(bucket, ratio)``: the selected entry of ``buckets`` and its ratio as
        a Python float. On ties the earliest candidate wins.
    """
    aspect_ratio = float(height) / float(width)
    # Coerce first so a plain list works too: `list - float` raises TypeError.
    ratio_array = np.asarray(ratios, dtype=float)
    # One argmin lookup serves both the bucket and the ratio, instead of a
    # second linear scan with min(..., key=...).
    closest_ratio_id = int(np.abs(ratio_array - aspect_ratio).argmin())
    return buckets[closest_ratio_id], float(ratio_array[closest_ratio_id])
def prepare_vae_images_inputs(self, semantic_images, i2v_resolution="720p"):
    """Resize, crop and normalise reference images into a tensor.

    The resolution label selects a base size, which is expanded into candidate
    sizes by ``generate_crop_size_list``. The candidate whose ratio is nearest
    to the first image's ``size[1] / size[0]`` becomes the target size for
    every image.

    Args:
        semantic_images: Sequence of images; each must expose ``.size`` and be
            accepted by the torchvision transforms used below (presumably PIL
            images; confirm against callers).
        i2v_resolution: One of ``"720p"``, ``"540p"`` or ``"360p"``.

    Returns:
        ``(pixel_values, target_height, target_width)``. ``pixel_values`` is
        the concatenation of the transformed images with singleton axes
        inserted at positions 0 and 2, moved to ``self.device``.

    Raises:
        ValueError: If ``i2v_resolution`` is not a supported label.
    """
    for label, base in (("720p", 960), ("540p", 720), ("360p", 480)):
        if i2v_resolution == label:
            bucket_hw_base_size = base
            break
    else:
        raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]")

    first_size = semantic_images[0].size
    crop_size_list = self.generate_crop_size_list(bucket_hw_base_size, 32)
    aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])
    closest_size, _ = self.get_closest_ratio(first_size[1], first_size[0], aspect_ratios, crop_size_list)

    to_pixel_values = transforms.Compose([
        transforms.Resize(closest_size),
        transforms.CenterCrop(closest_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    transformed = []
    for image in semantic_images:
        transformed.append(to_pixel_values(image))

    pixel_values = torch.cat(transformed)
    pixel_values = pixel_values.unsqueeze(0).unsqueeze(2).to(self.device)
    target_height, target_width = closest_size
    return pixel_values, target_height, target_width
def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256, input_images=None):
prompt_emb, pooled_prompt_emb, text_mask = self.prompter.encode_prompt( prompt_emb, pooled_prompt_emb, text_mask = self.prompter.encode_prompt(
prompt, device=self.device, positive=positive, clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length prompt, device=self.device, positive=positive, clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length, images=input_images
) )
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_mask": text_mask} return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_mask": text_mask}
@@ -87,6 +135,9 @@ class HunyuanVideoPipeline(BasePipeline):
prompt, prompt,
negative_prompt="", negative_prompt="",
input_video=None, input_video=None,
input_images=None,
i2v_resolution="720p",
i2v_stability=True,
denoising_strength=1.0, denoising_strength=1.0,
seed=None, seed=None,
rand_device=None, rand_device=None,
@@ -109,6 +160,13 @@ class HunyuanVideoPipeline(BasePipeline):
# Scheduler # Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength) self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# encoder input images
if input_images is not None:
self.load_models_to_device(['vae_encoder'])
image_pixel_values, height, width = self.prepare_vae_images_inputs(input_images, i2v_resolution=i2v_resolution)
with torch.autocast(device_type=self.device, dtype=torch.float16, enabled=True):
image_latents = self.vae_encoder(image_pixel_values)
# Initialize noise # Initialize noise
rand_device = self.device if rand_device is None else rand_device rand_device = self.device if rand_device is None else rand_device
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=self.torch_dtype).to(self.device) noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=self.torch_dtype).to(self.device)
@@ -118,12 +176,18 @@ class HunyuanVideoPipeline(BasePipeline):
input_video = torch.stack(input_video, dim=2) input_video = torch.stack(input_video, dim=2)
latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device) latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
elif input_images is not None and i2v_stability:
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=image_latents.dtype).to(self.device)
t = torch.tensor([0.999]).to(device=self.device)
latents = noise * t + image_latents.repeat(1, 1, (num_frames - 1) // 4 + 1, 1, 1) * (1 - t)
latents = latents.to(dtype=image_latents.dtype)
else: else:
latents = noise latents = noise
# Encode prompts # Encode prompts
self.load_models_to_device(["text_encoder_1"] if self.vram_management else ["text_encoder_1", "text_encoder_2"]) # current mllm does not support vram_management
prompt_emb_posi = self.encode_prompt(prompt, positive=True) self.load_models_to_device(["text_encoder_1"] if self.vram_management and input_images is None else ["text_encoder_1", "text_encoder_2"])
prompt_emb_posi = self.encode_prompt(prompt, positive=True, input_images=input_images)
if cfg_scale != 1.0: if cfg_scale != 1.0:
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False) prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
@@ -139,11 +203,16 @@ class HunyuanVideoPipeline(BasePipeline):
timestep = timestep.unsqueeze(0).to(self.device) timestep = timestep.unsqueeze(0).to(self.device)
print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}") print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")
forward_func = lets_dance_hunyuan_video
if input_images is not None:
latents = torch.concat([image_latents, latents[:, :, 1:, :, :]], dim=2)
forward_func = lets_dance_hunyuan_video_i2v
# Inference # Inference
with torch.autocast(device_type=self.device, dtype=self.torch_dtype): with torch.autocast(device_type=self.device, dtype=self.torch_dtype):
noise_pred_posi = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_posi, **extra_input, **tea_cache_kwargs) noise_pred_posi = forward_func(self.dit, latents, timestep, **prompt_emb_posi, **extra_input, **tea_cache_kwargs)
if cfg_scale != 1.0: if cfg_scale != 1.0:
noise_pred_nega = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_nega, **extra_input) noise_pred_nega = forward_func(self.dit, latents, timestep, **prompt_emb_nega, **extra_input)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else: else:
noise_pred = noise_pred_posi noise_pred = noise_pred_posi
@@ -163,7 +232,11 @@ class HunyuanVideoPipeline(BasePipeline):
self.load_models_to_device([] if self.vram_management else ["dit"]) self.load_models_to_device([] if self.vram_management else ["dit"])
# Scheduler # Scheduler
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents) if input_images is not None:
latents = self.scheduler.step(noise_pred[:, :, 1:, :, :], self.scheduler.timesteps[progress_id], latents[:, :, 1:, :, :])
latents = torch.concat([image_latents, latents], dim=2)
else:
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
# Decode # Decode
self.load_models_to_device(['vae_decoder']) self.load_models_to_device(['vae_decoder'])
@@ -250,13 +323,70 @@ def lets_dance_hunyuan_video(
print("TeaCache skip forward.") print("TeaCache skip forward.")
img = tea_cache.update(img) img = tea_cache.update(img)
else: else:
split_token = int(text_mask.sum(dim=1))
txt_len = int(txt.shape[1])
for block in tqdm(dit.double_blocks, desc="Double stream blocks"): for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin)) img, txt = block(img, txt, vec, (freqs_cos, freqs_sin), split_token=split_token)
x = torch.concat([img, txt], dim=1) x = torch.concat([img, txt], dim=1)
for block in tqdm(dit.single_blocks, desc="Single stream blocks"): for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
x = block(x, vec, (freqs_cos, freqs_sin)) x = block(x, vec, (freqs_cos, freqs_sin), txt_len=txt_len, split_token=split_token)
img = x[:, :-256] img = x[:, :-txt_len]
if tea_cache is not None:
tea_cache.store(img)
img = dit.final_layer(img, vec)
img = dit.unpatchify(img, T=T//1, H=H//2, W=W//2)
return img
def lets_dance_hunyuan_video_i2v(
dit: HunyuanVideoDiT,
x: torch.Tensor,
t: torch.Tensor,
prompt_emb: torch.Tensor = None,
text_mask: torch.Tensor = None,
pooled_prompt_emb: torch.Tensor = None,
freqs_cos: torch.Tensor = None,
freqs_sin: torch.Tensor = None,
guidance: torch.Tensor = None,
tea_cache: TeaCache = None,
**kwargs
):
B, C, T, H, W = x.shape
# Uncomment below to keep same as official implementation
# guidance = guidance.to(dtype=torch.float32).to(torch.bfloat16)
vec = dit.time_in(t, dtype=torch.bfloat16)
vec_2 = dit.vector_in(pooled_prompt_emb)
vec = vec + vec_2
vec = vec + dit.guidance_in(guidance * 1000., dtype=torch.bfloat16)
token_replace_vec = dit.time_in(torch.zeros_like(t), dtype=torch.bfloat16)
tr_token = (H // 2) * (W // 2)
token_replace_vec = token_replace_vec + vec_2
img = dit.img_in(x)
txt = dit.txt_in(prompt_emb, t, text_mask)
# TeaCache
if tea_cache is not None:
tea_cache_update = tea_cache.check(dit, img, vec)
else:
tea_cache_update = False
if tea_cache_update:
print("TeaCache skip forward.")
img = tea_cache.update(img)
else:
split_token = int(text_mask.sum(dim=1))
txt_len = int(txt.shape[1])
for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin), token_replace_vec, tr_token, split_token)
x = torch.concat([img, txt], dim=1)
for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
x = block(x, vec, (freqs_cos, freqs_sin), txt_len, token_replace_vec, tr_token, split_token)
img = x[:, :-txt_len]
if tea_cache is not None: if tea_cache is not None:
tea_cache.store(img) tea_cache.store(img)

View File

@@ -1,8 +1,9 @@
from .base_prompter import BasePrompter from .base_prompter import BasePrompter
from ..models.sd3_text_encoder import SD3TextEncoder1 from ..models.sd3_text_encoder import SD3TextEncoder1
from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder, HunyuanVideoMLLMEncoder
from transformers import CLIPTokenizer, LlamaTokenizerFast from transformers import CLIPTokenizer, LlamaTokenizerFast, CLIPImageProcessor
import os, torch import os, torch
from typing import Union
PROMPT_TEMPLATE_ENCODE = ( PROMPT_TEMPLATE_ENCODE = (
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, " "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
@@ -18,6 +19,24 @@ PROMPT_TEMPLATE_ENCODE_VIDEO = (
"5. camera angles, movements, and transitions used in the video:<|eot_id|>" "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>") "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>")
# Chat-style prompt template for image data in image-to-video mode.
# The system turn carries an `<image>` placeholder, the user turn has a `{}` slot
# for the prompt text, and the string ends with an opened assistant turn.
# NOTE(review): PROMPT_TEMPLATE pairs this template with a fixed `crop_start`;
# presumably that number depends on the exact wording here - confirm before editing.
PROMPT_TEMPLATE_ENCODE_I2V = (
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the image by detailing the color, shape, size, texture, "
"quantity, text, spatial relationships of the objects and background:<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
# Chat-style prompt template for video data in image-to-video mode.
# Same layout as the image variant: `<image>` placeholder in the system turn,
# `{}` slot in the user turn, trailing opened assistant turn.
# The numbered items are concatenated without separating spaces, and an extra
# "\n\n" follows the system turn's <|eot_id|>.
# NOTE(review): PROMPT_TEMPLATE pairs this template with a fixed `crop_start`;
# presumably that number depends on this exact text - confirm before "tidying" it.
PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
"1. The main content and theme of the video."
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
"4. background environment, light, style and atmosphere."
"5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
PROMPT_TEMPLATE = { PROMPT_TEMPLATE = {
"dit-llm-encode": { "dit-llm-encode": {
"template": PROMPT_TEMPLATE_ENCODE, "template": PROMPT_TEMPLATE_ENCODE,
@@ -27,6 +46,22 @@ PROMPT_TEMPLATE = {
"template": PROMPT_TEMPLATE_ENCODE_VIDEO, "template": PROMPT_TEMPLATE_ENCODE_VIDEO,
"crop_start": 95, "crop_start": 95,
}, },
"dit-llm-encode-i2v": {
"template": PROMPT_TEMPLATE_ENCODE_I2V,
"crop_start": 36,
"image_emb_start": 5,
"image_emb_end": 581,
"image_emb_len": 576,
"double_return_token_id": 271
},
"dit-llm-encode-video-i2v": {
"template": PROMPT_TEMPLATE_ENCODE_VIDEO_I2V,
"crop_start": 103,
"image_emb_start": 5,
"image_emb_end": 581,
"image_emb_len": 576,
"double_return_token_id": 271
},
} }
NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion" NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
@@ -56,9 +91,20 @@ class HunyuanVideoPrompter(BasePrompter):
self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode'] self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode']
self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video'] self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video']
def fetch_models(self,
                 text_encoder_1: SD3TextEncoder1 = None,
                 text_encoder_2: Union[HunyuanVideoLLMEncoder, HunyuanVideoMLLMEncoder] = None):
    """Attach the two text encoders and, for an MLLM encoder, switch to image-to-video mode.

    Args:
        text_encoder_1: Stored as ``self.text_encoder_1``.
        text_encoder_2: Stored as ``self.text_encoder_2``. When it is a
            ``HunyuanVideoMLLMEncoder``, an image processor is loaded and both
            prompt templates are replaced by their ``-i2v`` variants.
    """
    self.text_encoder_1 = text_encoder_1
    self.text_encoder_2 = text_encoder_2

    # Nothing else to configure unless the second encoder is the multimodal one.
    if not isinstance(text_encoder_2, HunyuanVideoMLLMEncoder):
        return

    # Image processor, read from the tokenizer_2 config directory that sits two
    # directory levels above this file.
    # TODO: may need to replace processor with local implementation
    package_root = os.path.dirname(os.path.dirname(__file__))
    self.processor = CLIPImageProcessor.from_pretrained(
        os.path.join(package_root, "tokenizer_configs/hunyuan_video/tokenizer_2"))

    # Image-to-video prompt templates.
    self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode-i2v']
    self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video-i2v']
def apply_text_to_template(self, text, template): def apply_text_to_template(self, text, template):
assert isinstance(template, str) assert isinstance(template, str)
@@ -107,8 +153,89 @@ class HunyuanVideoPrompter(BasePrompter):
return last_hidden_state, attention_mask return last_hidden_state, attention_mask
def encode_prompt_using_mllm(self,
                             prompt,
                             images,
                             max_length,
                             device,
                             crop_start,
                             hidden_state_skip_layer=2,
                             use_attention_mask=True,
                             image_embed_interleave=4):
    """Encode a prompt together with reference images using the multimodal text encoder.

    The encoder output is split into an image part and a text part. The text
    part drops the first ``crop_start`` template tokens and the four tokens in
    front of the last "double return" token; the image part is subsampled with
    stride ``image_embed_interleave``. Both parts are concatenated (image first)
    along the sequence dimension.

    Args:
        prompt: Text handed to ``self.tokenizer_2``.
        images: Images handed to ``self.processor``.
        max_length: Token budget excluding ``crop_start``; the tokenizer pads or
            truncates to ``max_length + crop_start``.
        device: Device the token ids, mask and pixel values are moved to.
        crop_start: Number of leading tokens to remove from the text part.
        hidden_state_skip_layer: Forwarded unchanged to ``self.text_encoder_2``.
        use_attention_mask: When False, no mask is built and ``None`` is
            returned in its place (this previously crashed in ``torch.stack``).
        image_embed_interleave: Stride used to subsample the image embeddings.

    Returns:
        Tuple ``(last_hidden_state, attention_mask)``; ``attention_mask`` is
        ``None`` when ``use_attention_mask`` is False.
    """
    # The video template's image settings are used regardless of data type.
    template = self.prompt_template_video
    image_emb_len = template.get("image_emb_len", 576)
    image_crop_start = template.get("image_emb_start", 5)
    image_crop_end = template.get("image_emb_end", 581)
    double_return_token_id = template.get("double_return_token_id", 271)

    pixel_values = self.processor(images, return_tensors="pt")["pixel_values"].to(device)
    max_length += crop_start
    inputs = self.tokenizer_2(prompt,
                              return_tensors="pt",
                              padding="max_length",
                              max_length=max_length,
                              truncation=True)
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)
    last_hidden_state = self.text_encoder_2(input_ids=input_ids,
                                            attention_mask=attention_mask,
                                            hidden_state_skip_layer=hidden_state_skip_layer,
                                            pixel_values=pixel_values)
    batch_size = input_ids.shape[0]

    # Hidden-state positions are shifted by `image_emb_len - 1` relative to
    # token positions.
    # NOTE(review): presumably because the encoder expands the single image
    # placeholder token into `image_emb_len` states - confirm against the encoder.
    hidden_offset = image_emb_len - 1
    text_crop_start = crop_start + hidden_offset

    _, double_return_indices = torch.where(input_ids == double_return_token_id)
    if double_return_indices.shape[0] == 3:
        # In case the prompt is too long: use the end of the sequence as the
        # position of the last marker. The extra index must be created on the
        # same device as the existing indices, otherwise torch.cat fails.
        # NOTE(review): this count check looks like it assumes batch size 1.
        sequence_end = torch.tensor([input_ids.shape[-1]], device=double_return_indices.device)
        double_return_indices = torch.cat((double_return_indices, sequence_end))
    last_double_return = double_return_indices.reshape(batch_size, -1)[:, -1]

    text_states = []
    text_masks = []
    image_states = []
    for i in range(batch_size):
        token_end = last_double_return[i].item()
        token_start = token_end - 4  # the four tokens before the marker are cut out
        text_states.append(
            torch.cat([
                last_hidden_state[i, text_crop_start:token_start + hidden_offset],
                last_hidden_state[i, token_end + hidden_offset:],
            ]))
        if use_attention_mask:
            text_masks.append(
                torch.cat([
                    attention_mask[i, crop_start:token_start],
                    attention_mask[i, token_end:],
                ]))
        image_states.append(last_hidden_state[i, image_crop_start:image_crop_end])

    text_last_hidden_state = torch.stack(text_states)
    image_last_hidden_state = torch.stack(image_states)[:, ::image_embed_interleave, :]
    last_hidden_state = torch.cat([image_last_hidden_state, text_last_hidden_state], dim=1)

    if not use_attention_mask:
        return last_hidden_state, None

    # Every kept image position is attended to.
    image_attention_mask = torch.ones(image_last_hidden_state.shape[:2],
                                      device=last_hidden_state.device,
                                      dtype=attention_mask.dtype)
    text_attention_mask = torch.stack(text_masks)
    attention_mask = torch.cat([image_attention_mask, text_attention_mask], dim=1)
    return last_hidden_state, attention_mask
def encode_prompt(self, def encode_prompt(self,
prompt, prompt,
images=None,
positive=True, positive=True,
device="cuda", device="cuda",
clip_sequence_length=77, clip_sequence_length=77,
@@ -116,7 +243,8 @@ class HunyuanVideoPrompter(BasePrompter):
data_type='video', data_type='video',
use_template=True, use_template=True,
hidden_state_skip_layer=2, hidden_state_skip_layer=2,
use_attention_mask=True): use_attention_mask=True,
image_embed_interleave=4):
prompt = self.process_prompt(prompt, positive=positive) prompt = self.process_prompt(prompt, positive=positive)
@@ -136,8 +264,12 @@ class HunyuanVideoPrompter(BasePrompter):
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, clip_sequence_length, device) pooled_prompt_emb = self.encode_prompt_using_clip(prompt, clip_sequence_length, device)
# LLM # LLM
prompt_emb, attention_mask = self.encode_prompt_using_llm( if images is None:
prompt_formated, llm_sequence_length, device, crop_start, prompt_emb, attention_mask = self.encode_prompt_using_llm(prompt_formated, llm_sequence_length, device, crop_start,
hidden_state_skip_layer, use_attention_mask) hidden_state_skip_layer, use_attention_mask)
else:
prompt_emb, attention_mask = self.encode_prompt_using_mllm(prompt_formated, images, llm_sequence_length, device,
crop_start, hidden_state_skip_layer, use_attention_mask,
image_embed_interleave)
return prompt_emb, pooled_prompt_emb, attention_mask return prompt_emb, pooled_prompt_emb, attention_mask

View File

@@ -0,0 +1,45 @@
{
"_valid_processor_keys": [
"images",
"do_resize",
"size",
"resample",
"do_center_crop",
"crop_size",
"do_rescale",
"rescale_factor",
"do_normalize",
"image_mean",
"image_std",
"do_convert_rgb",
"return_tensors",
"data_format",
"input_data_format"
],
"crop_size": {
"height": 336,
"width": 336
},
"do_center_crop": true,
"do_convert_rgb": true,
"do_normalize": true,
"do_rescale": true,
"do_resize": true,
"image_mean": [
0.48145466,
0.4578275,
0.40821073
],
"image_processor_type": "CLIPImageProcessor",
"image_std": [
0.26862954,
0.26130258,
0.27577711
],
"processor_class": "LlavaProcessor",
"resample": 3,
"rescale_factor": 0.00392156862745098,
"size": {
"shortest_edge": 336
}
}

View File

@@ -8,6 +8,12 @@
|24G|[hunyuanvideo_24G.py](hunyuanvideo_24G.py)|129|720*1280|The video is consistent with the original implementation, but it requires 5%~10% more time than [hunyuanvideo_80G.py](hunyuanvideo_80G.py)| |24G|[hunyuanvideo_24G.py](hunyuanvideo_24G.py)|129|720*1280|The video is consistent with the original implementation, but it requires 5%~10% more time than [hunyuanvideo_80G.py](hunyuanvideo_80G.py)|
|6G|[hunyuanvideo_6G.py](hunyuanvideo_6G.py)|129|512*384|The base model doesn't support low resolutions. We recommend users to use some LoRA ([example](https://civitai.com/models/1032126/walking-animation-hunyuan-video)) trained using low resolutions.| |6G|[hunyuanvideo_6G.py](hunyuanvideo_6G.py)|129|512*384|The base model doesn't support low resolutions. We recommend users to use some LoRA ([example](https://civitai.com/models/1032126/walking-animation-hunyuan-video)) trained using low resolutions.|
[HunyuanVideo-I2V](https://github.com/Tencent/HunyuanVideo-I2V) is the image-to-video generation version of HunyuanVideo. We also provide advanced VRAM management for this model.
|VRAM required|Example script|Frames|Resolution|Note|
|-|-|-|-|-|
|80G|[hunyuanvideo_i2v_80G.py](hunyuanvideo_i2v_80G.py)|129|720p|No VRAM management.|
|24G|[hunyuanvideo_i2v_24G.py](hunyuanvideo_i2v_24G.py)|129|720p|The video is consistent with the original implementation, but it requires 5%~10% more time than [hunyuanvideo_i2v_80G.py](hunyuanvideo_i2v_80G.py).|
## Gallery ## Gallery
Video generated by [hunyuanvideo_80G.py](hunyuanvideo_80G.py) and [hunyuanvideo_24G.py](hunyuanvideo_24G.py): Video generated by [hunyuanvideo_80G.py](hunyuanvideo_80G.py) and [hunyuanvideo_24G.py](hunyuanvideo_24G.py):
@@ -21,3 +27,7 @@ https://github.com/user-attachments/assets/2997f107-d02d-4ecb-89bb-5ce1a7f93817
Video to video generated by [hunyuanvideo_v2v_6G.py](./hunyuanvideo_v2v_6G.py) using [this LoRA](https://civitai.com/models/1032126/walking-animation-hunyuan-video): Video to video generated by [hunyuanvideo_v2v_6G.py](./hunyuanvideo_v2v_6G.py) using [this LoRA](https://civitai.com/models/1032126/walking-animation-hunyuan-video):
https://github.com/user-attachments/assets/4b89e52e-ce42-434e-aa57-08f09dfa2b10 https://github.com/user-attachments/assets/4b89e52e-ce42-434e-aa57-08f09dfa2b10
Video generated by [hunyuanvideo_i2v_80G.py](hunyuanvideo_i2v_80G.py) and [hunyuanvideo_i2v_24G.py](hunyuanvideo_i2v_24G.py):
https://github.com/user-attachments/assets/494f252a-c9af-440d-84ba-a8ddcdcc538a

View File

@@ -0,0 +1,43 @@
"""HunyuanVideo-I2V image-to-video example with VRAM management enabled.

All weights are loaded on the CPU first; the pipeline computes on "cuda".
"""
import torch
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
from modelscope import dataset_snapshot_download
from PIL import Image

DIT_CHECKPOINTS = ["models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"]
AUXILIARY_CHECKPOINTS = [
    "models/HunyuanVideoI2V/text_encoder/model.safetensors",
    "models/HunyuanVideoI2V/text_encoder_2",
    "models/HunyuanVideoI2V/vae/pytorch_model.pt",
]

download_models(["HunyuanVideoI2V"])
manager = ModelManager()

# The DiT model is loaded in bfloat16.
manager.load_models(DIT_CHECKPOINTS, torch_dtype=torch.bfloat16, device="cpu")

# The other modules are loaded in float16.
manager.load_models(AUXILIARY_CHECKPOINTS, torch_dtype=torch.float16, device="cpu")

# The computation device is "cuda".
pipe = HunyuanVideoPipeline.from_model_manager(
    manager,
    torch_dtype=torch.bfloat16,
    device="cuda",
    enable_vram_management=True,
)

# Fetch the example input image.
dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
    local_dir="./",
    allow_file_pattern="data/examples/hunyuanvideo/*",
)

i2v_resolution = "720p"
prompt = "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick."
reference_image = Image.open("data/examples/hunyuanvideo/0.jpg").convert('RGB')
video = pipe(
    prompt,
    input_images=[reference_image],
    num_inference_steps=50,
    seed=0,
    i2v_resolution=i2v_resolution,
)
save_video(video, f"video_{i2v_resolution}_low_vram.mp4", fps=30, quality=6)

View File

@@ -0,0 +1,45 @@
"""HunyuanVideo-I2V image-to-video example without VRAM management.

All weights are loaded directly on "cuda"; CPU offload is still enabled on the pipeline.
"""
import torch
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
from modelscope import dataset_snapshot_download
from PIL import Image

DIT_CHECKPOINTS = ["models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"]
AUXILIARY_CHECKPOINTS = [
    "models/HunyuanVideoI2V/text_encoder/model.safetensors",
    "models/HunyuanVideoI2V/text_encoder_2",
    "models/HunyuanVideoI2V/vae/pytorch_model.pt",
]

download_models(["HunyuanVideoI2V"])
manager = ModelManager()

# The DiT model is loaded in bfloat16.
manager.load_models(DIT_CHECKPOINTS, torch_dtype=torch.bfloat16, device="cuda")

# The other modules are loaded in float16.
manager.load_models(AUXILIARY_CHECKPOINTS, torch_dtype=torch.float16, device="cuda")

# The computation device is "cuda".
pipe = HunyuanVideoPipeline.from_model_manager(
    manager,
    torch_dtype=torch.bfloat16,
    device="cuda",
    enable_vram_management=False,
)
# Although you have enough VRAM, we still recommend you to enable offload.
pipe.enable_cpu_offload()

# Fetch the example input image.
dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
    local_dir="./",
    allow_file_pattern="data/examples/hunyuanvideo/*",
)

i2v_resolution = "720p"
prompt = "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick."
reference_image = Image.open("data/examples/hunyuanvideo/0.jpg").convert('RGB')
video = pipe(
    prompt,
    input_images=[reference_image],
    num_inference_steps=50,
    seed=0,
    i2v_resolution=i2v_resolution,
)
save_video(video, f"video_{i2v_resolution}.mp4", fps=30, quality=6)