support hunyuanvideo_i2v

Author: mi804
Date: 2025-03-11 16:20:09 +08:00
Parent: 945b43492e
Commit: 4bec2983a9

9 changed files with 327 additions and 161 deletions

View File

@@ -19,7 +19,7 @@ Until now, DiffSynth Studio has supported the following models:
 * [Wan-Video](https://github.com/Wan-Video/Wan2.1)
 * [StepVideo](https://github.com/stepfun-ai/Step-Video-T2V)
-* [HunyuanVideo](https://github.com/Tencent/HunyuanVideo)
+* [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [HunyuanVideo-I2V](https://github.com/Tencent/HunyuanVideo-I2V)
 * [CogVideoX](https://huggingface.co/THUDM/CogVideoX-5b)
 * [FLUX](https://huggingface.co/black-forest-labs/FLUX.1-dev)
 * [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
@@ -36,6 +36,7 @@ Until now, DiffSynth Studio has supported the following models:
 * [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
 
 ## News
+- **March 25, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of HunyuanVideo open-sourced by Tencent. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.
 - **February 25, 2025** We support Wan-Video, a collection of SOTA video synthesis models open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).

View File

@@ -112,6 +112,7 @@ model_loader_configs = [
     (None, "aeb82dce778a03dcb4d726cb03f3c43f", ["hunyuan_video_vae_decoder", "hunyuan_video_vae_encoder"], [HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder], "diffusers"),
     (None, "b9588f02e78f5ccafc9d7c0294e46308", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
     (None, "84ef4bd4757f60e906b54aa6a7815dc6", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
+    (None, "ae3c22aaa28bfae6f3688f796c9814ae", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
     (None, "68beaf8429b7c11aa8ca05b1bd0058bd", ["stepvideo_vae"], [StepVideoVAE], "civitai"),
     (None, "5c0216a2132b082c10cb7a0e0377e681", ["stepvideo_dit"], [StepVideoModel], "civitai"),
     (None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),

View File

@@ -282,7 +282,12 @@ class ModulateDiT(torch.nn.Module):
         return self.linear(self.act(x))
 
-def modulate(x, shift=None, scale=None):
+def modulate(x, shift=None, scale=None, tr_shift=None, tr_scale=None, tr_token=None):
+    if tr_shift is not None:
+        x_zero = x[:, :tr_token] * (1 + tr_scale.unsqueeze(1)) + tr_shift.unsqueeze(1)
+        x_orig = x[:, tr_token:] * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+        x = torch.concat((x_zero, x_orig), dim=1)
+        return x
     if scale is None and shift is None:
         return x
     elif shift is None:
@@ -386,6 +391,15 @@ def attention(q, k, v):
     return x
 
+def apply_gate(x, gate, tr_gate=None, tr_token=None):
+    if tr_gate is not None:
+        x_zero = x[:, :tr_token] * tr_gate.unsqueeze(1)
+        x_orig = x[:, tr_token:] * gate.unsqueeze(1)
+        return torch.concat((x_zero, x_orig), dim=1)
+    else:
+        return x * gate.unsqueeze(1)
+
 class MMDoubleStreamBlockComponent(torch.nn.Module):
     def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
         super().__init__()
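The token-replace helpers above split every sequence at `tr_token`, the number of patch tokens belonging to the image-conditioned first latent frame, and modulate/gate that prefix with parameters derived from a separate conditioning vector. A minimal sketch with toy shapes, assuming `modulate` and `apply_gate` above are in scope:

```python
import torch

# Toy shapes: x is (B, L, D); shift/scale/gate are (B, D).
B, L, D, tr_token = 1, 10, 4, 3
x = torch.randn(B, L, D)
shift, scale, gate = torch.randn(B, D), torch.randn(B, D), torch.randn(B, D)
# Identity modulation and unit gate for the token-replace prefix,
# so the first tr_token tokens pass through unchanged here.
tr_shift, tr_scale, tr_gate = torch.zeros(B, D), torch.zeros(B, D), torch.ones(B, D)

x_mod = modulate(x, shift=shift, scale=scale,
                 tr_shift=tr_shift, tr_scale=tr_scale, tr_token=tr_token)
assert torch.allclose(x_mod[:, :tr_token], x[:, :tr_token])  # prefix untouched

out = apply_gate(x_mod, gate, tr_gate=tr_gate, tr_token=tr_token)
assert out.shape == (B, L, D)
```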
@@ -406,11 +420,17 @@ class MMDoubleStreamBlockComponent(torch.nn.Module):
             torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size)
         )
 
-    def forward(self, hidden_states, conditioning, freqs_cis=None):
+    def forward(self, hidden_states, conditioning, freqs_cis=None, token_replace_vec=None, tr_token=None):
         mod1_shift, mod1_scale, mod1_gate, mod2_shift, mod2_scale, mod2_gate = self.mod(conditioning).chunk(6, dim=-1)
+        if token_replace_vec is not None:
+            assert tr_token is not None
+            tr_mod1_shift, tr_mod1_scale, tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = self.mod(token_replace_vec).chunk(6, dim=-1)
+        else:
+            tr_mod1_shift, tr_mod1_scale, tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = None, None, None, None, None, None
 
         norm_hidden_states = self.norm1(hidden_states)
-        norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale)
+        norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale,
+                                      tr_shift=tr_mod1_shift, tr_scale=tr_mod1_scale, tr_token=tr_token)
         qkv = self.to_qkv(norm_hidden_states)
         q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
@@ -419,13 +439,17 @@ class MMDoubleStreamBlockComponent(torch.nn.Module):
         if freqs_cis is not None:
             q, k = apply_rotary_emb(q, k, freqs_cis, head_first=False)
-        return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate)
+        return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate), (tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate)
 
-    def process_ff(self, hidden_states, attn_output, mod):
+    def process_ff(self, hidden_states, attn_output, mod, mod_tr=None, tr_token=None):
         mod1_gate, mod2_shift, mod2_scale, mod2_gate = mod
-        hidden_states = hidden_states + self.to_out(attn_output) * mod1_gate.unsqueeze(1)
-        hidden_states = hidden_states + self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale)) * mod2_gate.unsqueeze(1)
+        if mod_tr is not None:
+            tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = mod_tr
+        else:
+            tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = None, None, None, None
+        hidden_states = hidden_states + apply_gate(self.to_out(attn_output), mod1_gate, tr_mod1_gate, tr_token)
+        x = self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale, tr_shift=tr_mod2_shift, tr_scale=tr_mod2_scale, tr_token=tr_token))
+        hidden_states = hidden_states + apply_gate(x, mod2_gate, tr_mod2_gate, tr_token)
         return hidden_states
@@ -435,18 +459,18 @@ class MMDoubleStreamBlock(torch.nn.Module):
         self.component_a = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
         self.component_b = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
 
-    def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis):
-        (q_a, k_a, v_a), mod_a = self.component_a(hidden_states_a, conditioning, freqs_cis)
-        (q_b, k_b, v_b), mod_b = self.component_b(hidden_states_b, conditioning, freqs_cis=None)
-        q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous()
-        k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous()
-        v_a, v_b = torch.concat([v_a, v_b[:, :71]], dim=1), v_b[:, 71:].contiguous()
+    def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis, token_replace_vec=None, tr_token=None, split_token=71):
+        (q_a, k_a, v_a), mod_a, mod_tr = self.component_a(hidden_states_a, conditioning, freqs_cis, token_replace_vec, tr_token)
+        (q_b, k_b, v_b), mod_b, _ = self.component_b(hidden_states_b, conditioning, freqs_cis=None)
+        q_a, q_b = torch.concat([q_a, q_b[:, :split_token]], dim=1), q_b[:, split_token:].contiguous()
+        k_a, k_b = torch.concat([k_a, k_b[:, :split_token]], dim=1), k_b[:, split_token:].contiguous()
+        v_a, v_b = torch.concat([v_a, v_b[:, :split_token]], dim=1), v_b[:, split_token:].contiguous()
         attn_output_a = attention(q_a, k_a, v_a)
         attn_output_b = attention(q_b, k_b, v_b)
-        attn_output_a, attn_output_b = attn_output_a[:, :-71].contiguous(), torch.concat([attn_output_a[:, -71:], attn_output_b], dim=1)
-        hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a)
+        attn_output_a, attn_output_b = attn_output_a[:, :-split_token].contiguous(), torch.concat([attn_output_a[:, -split_token:], attn_output_b], dim=1)
+        hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a, mod_tr, tr_token)
         hidden_states_b = self.component_b.process_ff(hidden_states_b, attn_output_b, mod_b)
         return hidden_states_a, hidden_states_b
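The `split_token` parameter replaces the previously hardcoded 71: it is the number of valid (unpadded, right-padded tokenizer) text tokens, so the valid text tokens attend jointly with the image tokens while the padding tail forms its own attention group and is merged back afterwards. A toy illustration of the index bookkeeping, with assumed sizes:

```python
import torch

B, H, D = 1, 2, 8                      # batch, heads, head dim (toy values)
img_len, txt_len, split_token = 5, 7, 3
q_a = torch.randn(B, img_len, H, D)    # image-stream queries
q_b = torch.randn(B, txt_len, H, D)    # text-stream queries

# Valid text tokens join the image tokens for attention;
# the padded tail is attended separately.
q_joint = torch.concat([q_a, q_b[:, :split_token]], dim=1)
q_pad = q_b[:, split_token:].contiguous()
assert q_joint.shape[1] == img_len + split_token
assert q_pad.shape[1] == txt_len - split_token
```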
@@ -510,11 +534,17 @@ class MMSingleStreamBlock(torch.nn.Module):
             torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size, bias=False)
         )
 
-    def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256):
+    def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256, token_replace_vec=None, tr_token=None, split_token=71):
         mod_shift, mod_scale, mod_gate = self.mod(conditioning).chunk(3, dim=-1)
+        if token_replace_vec is not None:
+            assert tr_token is not None
+            tr_mod_shift, tr_mod_scale, tr_mod_gate = self.mod(token_replace_vec).chunk(3, dim=-1)
+        else:
+            tr_mod_shift, tr_mod_scale, tr_mod_gate = None, None, None
 
         norm_hidden_states = self.norm(hidden_states)
-        norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale)
+        norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale,
+                                      tr_shift=tr_mod_shift, tr_scale=tr_mod_scale, tr_token=tr_token)
         qkv = self.to_qkv(norm_hidden_states)
         q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
@@ -526,16 +556,17 @@ class MMSingleStreamBlock(torch.nn.Module):
         k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
         q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False)
-        q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous()
-        k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous()
-        v_a, v_b = v[:, :-185].contiguous(), v[:, -185:].contiguous()
+        v_len = txt_len - split_token
+        q_a, q_b = torch.concat([q_a, q_b[:, :split_token]], dim=1), q_b[:, split_token:].contiguous()
+        k_a, k_b = torch.concat([k_a, k_b[:, :split_token]], dim=1), k_b[:, split_token:].contiguous()
+        v_a, v_b = v[:, :-v_len].contiguous(), v[:, -v_len:].contiguous()
         attn_output_a = attention(q_a, k_a, v_a)
         attn_output_b = attention(q_b, k_b, v_b)
         attn_output = torch.concat([attn_output_a, attn_output_b], dim=1)
-        hidden_states = hidden_states + self.to_out(attn_output) * mod_gate.unsqueeze(1)
-        hidden_states = hidden_states + self.ff(norm_hidden_states) * mod_gate.unsqueeze(1)
+        hidden_states = hidden_states + apply_gate(self.to_out(attn_output), mod_gate, tr_mod_gate, tr_token)
+        hidden_states = hidden_states + apply_gate(self.ff(norm_hidden_states), mod_gate, tr_mod_gate, tr_token)
         return hidden_states
@@ -780,7 +811,6 @@ class HunyuanVideoDiT(torch.nn.Module):
         return HunyuanVideoDiTStateDictConverter()
 
 class HunyuanVideoDiTStateDictConverter:
     def __init__(self):
         pass
@@ -886,6 +916,8 @@
                 state_dict_[name_] = param
             else:
                 pass
+        if origin_hash_key == "ae3c22aaa28bfae6f3688f796c9814ae":
+            return state_dict_, {"in_channels": 33, "guidance_embed": False}
         return state_dict_

View File

@@ -5,13 +5,14 @@ from ..schedulers.flow_match import FlowMatchScheduler
 from .base import BasePipeline
 from ..prompters import HunyuanVideoPrompter
 import torch
+import torchvision.transforms as transforms
 from einops import rearrange
 import numpy as np
 from PIL import Image
 from tqdm import tqdm
 
 
 class HunyuanVideoPipeline(BasePipeline):
     def __init__(self, device="cuda", torch_dtype=torch.float16):
def __init__(self, device="cuda", torch_dtype=torch.float16): def __init__(self, device="cuda", torch_dtype=torch.float16):
@@ -53,10 +53,58 @@ class HunyuanVideoPipeline(BasePipeline):
         pipe.enable_vram_management()
         return pipe
 
+    def generate_crop_size_list(self, base_size=256, patch_size=32, max_ratio=4.0):
+        num_patches = round((base_size / patch_size)**2)
+        assert max_ratio >= 1.0
+        crop_size_list = []
+        wp, hp = num_patches, 1
+        while wp > 0:
+            if max(wp, hp) / min(wp, hp) <= max_ratio:
+                crop_size_list.append((wp * patch_size, hp * patch_size))
+            if (hp + 1) * wp <= num_patches:
+                hp += 1
+            else:
+                wp -= 1
+        return crop_size_list
+
+    def get_closest_ratio(self, height: float, width: float, ratios: list, buckets: list):
+        aspect_ratio = float(height) / float(width)
+        closest_ratio_id = np.abs(ratios - aspect_ratio).argmin()
+        closest_ratio = min(ratios, key=lambda ratio: abs(float(ratio) - aspect_ratio))
+        return buckets[closest_ratio_id], float(closest_ratio)
+
+    def prepare_vae_images_inputs(self, semantic_images, i2v_resolution="720p"):
+        if i2v_resolution == "720p":
+            bucket_hw_base_size = 960
+        elif i2v_resolution == "540p":
+            bucket_hw_base_size = 720
+        elif i2v_resolution == "360p":
+            bucket_hw_base_size = 480
+        else:
+            raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]")
+        origin_size = semantic_images[0].size
+        crop_size_list = self.generate_crop_size_list(bucket_hw_base_size, 32)
+        aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])
+        closest_size, closest_ratio = self.get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)
+        ref_image_transform = transforms.Compose([
+            transforms.Resize(closest_size),
+            transforms.CenterCrop(closest_size),
+            transforms.ToTensor(),
+            transforms.Normalize([0.5], [0.5])
+        ])
+        semantic_image_pixel_values = [ref_image_transform(semantic_image) for semantic_image in semantic_images]
+        semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(self.device)
+        target_height, target_width = closest_size
+        return semantic_image_pixel_values, target_height, target_width
+
-    def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256):
+    def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256, input_images=None):
         prompt_emb, pooled_prompt_emb, text_mask = self.prompter.encode_prompt(
-            prompt, device=self.device, positive=positive, clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length
+            prompt, device=self.device, positive=positive, clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length, images=input_images
         )
         return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_mask": text_mask}
@@ -87,6 +135,9 @@ class HunyuanVideoPipeline(BasePipeline):
         prompt,
         negative_prompt="",
         input_video=None,
+        input_images=None,
+        i2v_resolution="720p",
+        i2v_stability=True,
         denoising_strength=1.0,
         seed=None,
         rand_device=None,
@@ -109,6 +160,13 @@ class HunyuanVideoPipeline(BasePipeline):
         # Scheduler
         self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
 
+        # Encode input images
+        if input_images is not None:
+            self.load_models_to_device(['vae_encoder'])
+            image_pixel_values, height, width = self.prepare_vae_images_inputs(input_images, i2v_resolution=i2v_resolution)
+            with torch.autocast(device_type=self.device, dtype=torch.float16, enabled=True):
+                image_latents = self.vae_encoder(image_pixel_values)
+
         # Initialize noise
         rand_device = self.device if rand_device is None else rand_device
         noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=self.torch_dtype).to(self.device)
@@ -118,12 +176,18 @@ class HunyuanVideoPipeline(BasePipeline):
             input_video = torch.stack(input_video, dim=2)
             latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
             latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
+        elif input_images is not None and i2v_stability:
+            noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=image_latents.dtype).to(self.device)
+            t = torch.tensor([0.999]).to(device=self.device)
+            latents = noise * t + image_latents.repeat(1, 1, (num_frames - 1) // 4 + 1, 1, 1) * (1 - t)
+            latents = latents.to(dtype=image_latents.dtype)
         else:
             latents = noise
 
         # Encode prompts
-        self.load_models_to_device(["text_encoder_1"] if self.vram_management else ["text_encoder_1", "text_encoder_2"])
-        prompt_emb_posi = self.encode_prompt(prompt, positive=True)
+        # The current MLLM does not support VRAM management.
+        self.load_models_to_device(["text_encoder_1"] if self.vram_management and input_images is None else ["text_encoder_1", "text_encoder_2"])
+        prompt_emb_posi = self.encode_prompt(prompt, positive=True, input_images=input_images)
         if cfg_scale != 1.0:
             prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
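The `i2v_stability` branch starts denoising from almost pure noise with a 0.1% leak of the encoded reference image broadcast across all latent frames (t = 0.999). A toy restatement of the arithmetic, with assumed shapes:

```python
import torch

image_latents = torch.randn(1, 16, 1, 4, 4)    # encoded reference frame (toy)
num_latent_frames = 8                          # (num_frames - 1) // 4 + 1 in the pipeline
noise = torch.randn(1, 16, num_latent_frames, 4, 4)

t = torch.tensor([0.999])
latents = noise * t + image_latents.repeat(1, 1, num_latent_frames, 1, 1) * (1 - t)
assert latents.shape == noise.shape
```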
@@ -139,11 +203,16 @@ class HunyuanVideoPipeline(BasePipeline):
             timestep = timestep.unsqueeze(0).to(self.device)
             print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")
 
+            forward_func = lets_dance_hunyuan_video
+            if input_images is not None:
+                latents = torch.concat([image_latents, latents[:, :, 1:, :, :]], dim=2)
+                forward_func = lets_dance_hunyuan_video_i2v
+
             # Inference
             with torch.autocast(device_type=self.device, dtype=self.torch_dtype):
-                noise_pred_posi = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_posi, **extra_input, **tea_cache_kwargs)
+                noise_pred_posi = forward_func(self.dit, latents, timestep, **prompt_emb_posi, **extra_input, **tea_cache_kwargs)
                 if cfg_scale != 1.0:
-                    noise_pred_nega = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_nega, **extra_input)
+                    noise_pred_nega = forward_func(self.dit, latents, timestep, **prompt_emb_nega, **extra_input)
                     noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
                 else:
                     noise_pred = noise_pred_posi
@@ -163,6 +232,10 @@ class HunyuanVideoPipeline(BasePipeline):
             self.load_models_to_device([] if self.vram_management else ["dit"])
 
             # Scheduler
-            latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
+            if input_images is not None:
+                latents = self.scheduler.step(noise_pred[:, :, 1:, :, :], self.scheduler.timesteps[progress_id], latents[:, :, 1:, :, :])
+                latents = torch.concat([image_latents, latents], dim=2)
+            else:
+                latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
 
         # Decode
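Each step therefore pins the first latent frame to the clean image latent before the forward pass and only steps the remaining frames, re-attaching the clean frame afterwards. A toy sketch (the subtraction is a stand-in for the real `scheduler.step` update, not its actual rule):

```python
import torch

image_latents = torch.randn(1, 16, 1, 4, 4)    # clean first frame (toy)
latents = torch.randn(1, 16, 8, 4, 4)
noise_pred = torch.randn(1, 16, 8, 4, 4)

latents = torch.concat([image_latents, latents[:, :, 1:]], dim=2)  # pin frame 0
stepped = latents[:, :, 1:] - 0.1 * noise_pred[:, :, 1:]           # placeholder update
latents = torch.concat([image_latents, stepped], dim=2)            # keep frame 0 clean
```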
@@ -250,13 +323,70 @@ def lets_dance_hunyuan_video(
print("TeaCache skip forward.") print("TeaCache skip forward.")
img = tea_cache.update(img) img = tea_cache.update(img)
else: else:
split_token = int(text_mask.sum(dim=1))
txt_len = int(txt.shape[1])
for block in tqdm(dit.double_blocks, desc="Double stream blocks"): for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin)) img, txt = block(img, txt, vec, (freqs_cos, freqs_sin), split_token=split_token)
x = torch.concat([img, txt], dim=1) x = torch.concat([img, txt], dim=1)
for block in tqdm(dit.single_blocks, desc="Single stream blocks"): for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
x = block(x, vec, (freqs_cos, freqs_sin)) x = block(x, vec, (freqs_cos, freqs_sin), txt_len=txt_len, split_token=split_token)
img = x[:, :-256] img = x[:, :-txt_len]
if tea_cache is not None:
tea_cache.store(img)
img = dit.final_layer(img, vec)
img = dit.unpatchify(img, T=T//1, H=H//2, W=W//2)
return img
def lets_dance_hunyuan_video_i2v(
dit: HunyuanVideoDiT,
x: torch.Tensor,
t: torch.Tensor,
prompt_emb: torch.Tensor = None,
text_mask: torch.Tensor = None,
pooled_prompt_emb: torch.Tensor = None,
freqs_cos: torch.Tensor = None,
freqs_sin: torch.Tensor = None,
guidance: torch.Tensor = None,
tea_cache: TeaCache = None,
**kwargs
):
B, C, T, H, W = x.shape
# Uncomment below to keep same as official implementation
# guidance = guidance.to(dtype=torch.float32).to(torch.bfloat16)
vec = dit.time_in(t, dtype=torch.bfloat16)
vec_2 = dit.vector_in(pooled_prompt_emb)
vec = vec + vec_2
vec = vec + dit.guidance_in(guidance * 1000., dtype=torch.bfloat16)
token_replace_vec = dit.time_in(torch.zeros_like(t), dtype=torch.bfloat16)
tr_token = (H // 2) * (W // 2)
token_replace_vec = token_replace_vec + vec_2
img = dit.img_in(x)
txt = dit.txt_in(prompt_emb, t, text_mask)
# TeaCache
if tea_cache is not None:
tea_cache_update = tea_cache.check(dit, img, vec)
else:
tea_cache_update = False
if tea_cache_update:
print("TeaCache skip forward.")
img = tea_cache.update(img)
else:
split_token = int(text_mask.sum(dim=1))
txt_len = int(txt.shape[1])
for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin), token_replace_vec, tr_token, split_token)
x = torch.concat([img, txt], dim=1)
for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
x = block(x, vec, (freqs_cos, freqs_sin), txt_len, token_replace_vec, tr_token, split_token)
img = x[:, :-txt_len]
if tea_cache is not None: if tea_cache is not None:
tea_cache.store(img) tea_cache.store(img)
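Two quantities drive the token replacement above: `token_replace_vec` is the timestep embedding at t = 0, reflecting that the first frame is already clean, and `tr_token` is the number of patch tokens in one latent frame. The VAE downsamples spatially by 8 and the DiT patchifies by 2 (hence `H//2`, `W//2` in `unpatchify`), so for an assumed 720x1280 bucket:

```python
# Worked example under an assumed 720x1280 resolution bucket.
height, width = 720, 1280
latent_h, latent_w = height // 8, width // 8   # 90, 160 after the VAE
tr_token = (latent_h // 2) * (latent_w // 2)   # 45 * 80 patch tokens per frame
assert tr_token == 3600
```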

View File

@@ -87,7 +87,6 @@ class HunyuanVideoPrompter(BasePrompter):
         self.tokenizer_2 = LlamaTokenizerFast.from_pretrained(tokenizer_2_path, padding_side='right')
         self.text_encoder_1: SD3TextEncoder1 = None
         self.text_encoder_2: HunyuanVideoLLMEncoder = None
-        self.i2v_mode = False
         self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode']
         self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video']
@@ -106,8 +105,6 @@ class HunyuanVideoPrompter(BasePrompter):
         # template
         self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode-i2v']
         self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video-i2v']
-        # mode setting
-        self.i2v_mode = True
 
     def apply_text_to_template(self, text, template):
         assert isinstance(template, str)
@@ -164,10 +161,8 @@ class HunyuanVideoPrompter(BasePrompter):
                                  crop_start,
                                  hidden_state_skip_layer=2,
                                  use_attention_mask=True,
-                                 image_embed_interleave=2):
+                                 image_embed_interleave=4):
-        image_outputs = self.processor(images, return_tensors="pt")[
-            "pixel_values"
-        ].to(device)
+        image_outputs = self.processor(images, return_tensors="pt")["pixel_values"].to(device)
         max_length += crop_start
         inputs = self.tokenizer_2(prompt,
                                   return_tensors="pt",
@@ -248,7 +243,8 @@ class HunyuanVideoPrompter(BasePrompter):
                       data_type='video',
                       use_template=True,
                       hidden_state_skip_layer=2,
-                      use_attention_mask=True):
+                      use_attention_mask=True,
+                      image_embed_interleave=4):
 
         prompt = self.process_prompt(prompt, positive=positive)
@@ -273,6 +269,7 @@ class HunyuanVideoPrompter(BasePrompter):
                                                                     hidden_state_skip_layer, use_attention_mask)
         else:
             prompt_emb, attention_mask = self.encode_prompt_using_mllm(prompt_formated, images, llm_sequence_length, device,
-                                                                       crop_start, hidden_state_skip_layer, use_attention_mask)
+                                                                       crop_start, hidden_state_skip_layer, use_attention_mask,
+                                                                       image_embed_interleave)
 
         return prompt_emb, pooled_prompt_emb, attention_mask

View File

@@ -8,6 +8,12 @@
 |24G|[hunyuanvideo_24G.py](hunyuanvideo_24G.py)|129|720*1280|The video is consistent with the original implementation, but it requires 5%~10% more time than [hunyuanvideo_80G.py](hunyuanvideo_80G.py)|
 |6G|[hunyuanvideo_6G.py](hunyuanvideo_6G.py)|129|512*384|The base model doesn't support low resolutions. We recommend users to use some LoRA ([example](https://civitai.com/models/1032126/walking-animation-hunyuan-video)) trained using low resolutions.|
 
+[HunyuanVideo-I2V](https://github.com/Tencent/HunyuanVideo-I2V) is the image-to-video generation version of HunyuanVideo. We also provide advanced VRAM management for this model.
+
+|VRAM required|Example script|Frames|Resolution|Note|
+|-|-|-|-|-|
+|80G|[hunyuanvideo_i2v_80G.py](hunyuanvideo_i2v_80G.py)|129|720p|No VRAM management.|
+|24G|[hunyuanvideo_i2v_24G.py](hunyuanvideo_i2v_24G.py)|129|720p|The video is consistent with the original implementation, but it requires 5%~10% more time than [hunyuanvideo_i2v_80G.py](hunyuanvideo_i2v_80G.py)|
+
 ## Gallery
 
 Video generated by [hunyuanvideo_80G.py](hunyuanvideo_80G.py) and [hunyuanvideo_24G.py](hunyuanvideo_24G.py):
@@ -21,3 +27,7 @@ https://github.com/user-attachments/assets/2997f107-d02d-4ecb-89bb-5ce1a7f93817
 Video to video generated by [hunyuanvideo_v2v_6G.py](./hunyuanvideo_v2v_6G.py) using [this LoRA](https://civitai.com/models/1032126/walking-animation-hunyuan-video):
 
 https://github.com/user-attachments/assets/4b89e52e-ce42-434e-aa57-08f09dfa2b10
 
+Video generated by [hunyuanvideo_i2v_80G.py](hunyuanvideo_i2v_80G.py) and [hunyuanvideo_i2v_24G.py](hunyuanvideo_i2v_24G.py):
+
+https://github.com/user-attachments/assets/494f252a-c9af-440d-84ba-a8ddcdcc538a
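In practice the new scripts reduce to a couple of extra arguments on the pipeline call. A condensed sketch (model loading omitted; `pipe` is built exactly as in the scripts below, and the prompt is a placeholder):

```python
from PIL import Image
from diffsynth import save_video

images = [Image.open("data/examples/hunyuanvideo/0.jpg").convert("RGB")]
video = pipe("your prompt here", input_images=images,
             num_inference_steps=50, seed=0, i2v_resolution="720p")
save_video(video, "video_720p.mp4", fps=30, quality=6)
```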

View File

@@ -1,88 +0,0 @@
-import torch
-from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
-from diffsynth.prompters.hunyuan_video_prompter import HunyuanVideoPrompter
-from PIL import Image
-import numpy as np
-import torchvision.transforms as transforms
-
-
-def generate_crop_size_list(base_size=256, patch_size=32, max_ratio=4.0):
-    num_patches = round((base_size / patch_size)**2)
-    assert max_ratio >= 1.0
-    crop_size_list = []
-    wp, hp = num_patches, 1
-    while wp > 0:
-        if max(wp, hp) / min(wp, hp) <= max_ratio:
-            crop_size_list.append((wp * patch_size, hp * patch_size))
-        if (hp + 1) * wp <= num_patches:
-            hp += 1
-        else:
-            wp -= 1
-    return crop_size_list
-
-
-def get_closest_ratio(height: float, width: float, ratios: list, buckets: list):
-    aspect_ratio = float(height) / float(width)
-    closest_ratio_id = np.abs(ratios - aspect_ratio).argmin()
-    closest_ratio = min(ratios, key=lambda ratio: abs(float(ratio) - aspect_ratio))
-    return buckets[closest_ratio_id], float(closest_ratio)
-
-
-def prepare_vae_inputs(semantic_images, i2v_resolution="720p"):
-    if i2v_resolution == "720p":
-        bucket_hw_base_size = 960
-    elif i2v_resolution == "540p":
-        bucket_hw_base_size = 720
-    elif i2v_resolution == "360p":
-        bucket_hw_base_size = 480
-    else:
-        raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]")
-    origin_size = semantic_images[0].size
-    crop_size_list = generate_crop_size_list(bucket_hw_base_size, 32)
-    aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])
-    closest_size, closest_ratio = get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)
-    ref_image_transform = transforms.Compose([
-        transforms.Resize(closest_size),
-        transforms.CenterCrop(closest_size),
-        transforms.ToTensor(),
-        transforms.Normalize([0.5], [0.5])
-    ])
-    semantic_image_pixel_values = [ref_image_transform(semantic_image) for semantic_image in semantic_images]
-    semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2)
-    return semantic_image_pixel_values
-
-
-model_manager = ModelManager()
-# The other modules are loaded in float16.
-model_manager.load_models(
-    [
-        "models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
-    ],
-    torch_dtype=torch.bfloat16, # you can use torch_dtype=torch.float8_e4m3fn to enable quantization.
-    device="cuda"
-)
-model_manager.load_models(
-    [
-        "models/HunyuanVideo/text_encoder/model.safetensors",
-        "models/HunyuanVideoI2V/text_encoder_2",
-        'models/HunyuanVideoI2V/vae/pytorch_model.pt'
-    ],
-    torch_dtype=torch.float16,
-    device="cuda"
-)
-# The computation device is "cuda".
-pipe = HunyuanVideoPipeline.from_model_manager(
-    model_manager,
-    torch_dtype=torch.bfloat16,
-    device="cuda",
-    enable_vram_management=False
-)
-# Although you have enough VRAM, we still recommend you to enable offload.
-pipe.enable_cpu_offload()
-print()

View File

@@ -0,0 +1,43 @@
+import torch
+from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
+from modelscope import dataset_snapshot_download
+from PIL import Image
+
+download_models(["HunyuanVideoI2V"])
+model_manager = ModelManager()
+
+# The DiT model is loaded in bfloat16.
+model_manager.load_models(
+    [
+        "models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
+    ],
+    torch_dtype=torch.bfloat16,
+    device="cpu"
+)
+
+# The other modules are loaded in float16.
+model_manager.load_models(
+    [
+        "models/HunyuanVideoI2V/text_encoder/model.safetensors",
+        "models/HunyuanVideoI2V/text_encoder_2",
+        'models/HunyuanVideoI2V/vae/pytorch_model.pt'
+    ],
+    torch_dtype=torch.float16,
+    device="cpu"
+)
+
+# The computation device is "cuda".
+pipe = HunyuanVideoPipeline.from_model_manager(model_manager,
+                                               torch_dtype=torch.bfloat16,
+                                               device="cuda",
+                                               enable_vram_management=True)
+
+dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth",
+                          local_dir="./",
+                          allow_file_pattern=f"data/examples/hunyuanvideo/*")
+
+i2v_resolution = "720p"
+prompt = "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick."
+images = [Image.open("data/examples/hunyuanvideo/0.jpg").convert('RGB')]
+video = pipe(prompt, input_images=images, num_inference_steps=50, seed=0, i2v_resolution=i2v_resolution)
+save_video(video, f"video_{i2v_resolution}_low_vram.mp4", fps=30, quality=6)

View File

@@ -0,0 +1,45 @@
+import torch
+from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
+from modelscope import dataset_snapshot_download
+from PIL import Image
+
+download_models(["HunyuanVideoI2V"])
+model_manager = ModelManager()
+
+# The DiT model is loaded in bfloat16.
+model_manager.load_models(
+    [
+        "models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
+    ],
+    torch_dtype=torch.bfloat16,
+    device="cuda"
+)
+
+# The other modules are loaded in float16.
+model_manager.load_models(
+    [
+        "models/HunyuanVideoI2V/text_encoder/model.safetensors",
+        "models/HunyuanVideoI2V/text_encoder_2",
+        'models/HunyuanVideoI2V/vae/pytorch_model.pt'
+    ],
+    torch_dtype=torch.float16,
+    device="cuda"
+)
+
+# The computation device is "cuda".
+pipe = HunyuanVideoPipeline.from_model_manager(model_manager,
+                                               torch_dtype=torch.bfloat16,
+                                               device="cuda",
+                                               enable_vram_management=False)
+
+# Even if you have enough VRAM, we still recommend enabling CPU offload.
+pipe.enable_cpu_offload()
+
+dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth",
+                          local_dir="./",
+                          allow_file_pattern=f"data/examples/hunyuanvideo/*")
+
+i2v_resolution = "720p"
+prompt = "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick."
+images = [Image.open("data/examples/hunyuanvideo/0.jpg").convert('RGB')]
+video = pipe(prompt, input_images=images, num_inference_steps=50, seed=0, i2v_resolution=i2v_resolution)
+save_video(video, f"video_{i2v_resolution}.mp4", fps=30, quality=6)