From 4bec2983a93a21b4724bc0643c058acf7d6ba93b Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Tue, 11 Mar 2025 16:20:09 +0800 Subject: [PATCH] support hunyuanvideo_i2v --- README.md | 3 +- diffsynth/configs/model_config.py | 1 - diffsynth/models/hunyuan_video_dit.py | 119 ++++++++----- diffsynth/pipelines/hunyuan_video.py | 164 ++++++++++++++++-- diffsynth/prompters/hunyuan_video_prompter.py | 15 +- examples/HunyuanVideo/README.md | 10 ++ examples/HunyuanVideo/hunyuanvideo_i2v.py | 88 ---------- examples/HunyuanVideo/hunyuanvideo_i2v_24G.py | 43 +++++ examples/HunyuanVideo/hunyuanvideo_i2v_80G.py | 45 +++++ 9 files changed, 327 insertions(+), 161 deletions(-) delete mode 100644 examples/HunyuanVideo/hunyuanvideo_i2v.py create mode 100644 examples/HunyuanVideo/hunyuanvideo_i2v_24G.py create mode 100644 examples/HunyuanVideo/hunyuanvideo_i2v_80G.py diff --git a/README.md b/README.md index a5993a4..9dd7ab8 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ Until now, DiffSynth Studio has supported the following models: * [Wan-Video](https://github.com/Wan-Video/Wan2.1) * [StepVideo](https://github.com/stepfun-ai/Step-Video-T2V) -* [HunyuanVideo](https://github.com/Tencent/HunyuanVideo) +* [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [HunyuanVideo-I2V](https://github.com/Tencent/HunyuanVideo-I2V) * [CogVideoX](https://huggingface.co/THUDM/CogVideoX-5b) * [FLUX](https://huggingface.co/black-forest-labs/FLUX.1-dev) * [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) @@ -36,6 +36,7 @@ Until now, DiffSynth Studio has supported the following models: * [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5) ## News +- **March 25, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of HunyuanVideo open-sourced by Tencent. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details. - **February 25, 2025** We support Wan-Video, a collection of SOTA video synthesis models open-sourced by Alibaba. 
See [./examples/wanvideo/](./examples/wanvideo/). diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index 718ba73..15dcbed 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -112,7 +112,6 @@ model_loader_configs = [ (None, "aeb82dce778a03dcb4d726cb03f3c43f", ["hunyuan_video_vae_decoder", "hunyuan_video_vae_encoder"], [HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder], "diffusers"), (None, "b9588f02e78f5ccafc9d7c0294e46308", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"), (None, "84ef4bd4757f60e906b54aa6a7815dc6", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"), - (None, "ae3c22aaa28bfae6f3688f796c9814ae", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"), (None, "68beaf8429b7c11aa8ca05b1bd0058bd", ["stepvideo_vae"], [StepVideoVAE], "civitai"), (None, "5c0216a2132b082c10cb7a0e0377e681", ["stepvideo_dit"], [StepVideoModel], "civitai"), (None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"), diff --git a/diffsynth/models/hunyuan_video_dit.py b/diffsynth/models/hunyuan_video_dit.py index f008a87..1315536 100644 --- a/diffsynth/models/hunyuan_video_dit.py +++ b/diffsynth/models/hunyuan_video_dit.py @@ -237,7 +237,7 @@ class IndividualTokenRefinerBlock(torch.nn.Module): x = x + self.mlp(self.norm2(x)) * gate_mlp.unsqueeze(1) return x - + class SingleTokenRefiner(torch.nn.Module): def __init__(self, in_channels=4096, hidden_size=3072, depth=2): @@ -270,7 +270,7 @@ class SingleTokenRefiner(torch.nn.Module): x = block(x, c, mask) return x - + class ModulateDiT(torch.nn.Module): def __init__(self, hidden_size, factor=6): @@ -280,9 +280,14 @@ class ModulateDiT(torch.nn.Module): def forward(self, x): return self.linear(self.act(x)) - -def modulate(x, shift=None, scale=None): + +def modulate(x, shift=None, scale=None, tr_shift=None, tr_scale=None, tr_token=None): + if tr_shift is not None: + x_zero = x[:, :tr_token] * (1 + tr_scale.unsqueeze(1)) + 
tr_shift.unsqueeze(1) + x_orig = x[:, tr_token:] * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + x = torch.concat((x_zero, x_orig), dim=1) + return x if scale is None and shift is None: return x elif shift is None: @@ -291,7 +296,7 @@ def modulate(x, shift=None, scale=None): return x + shift.unsqueeze(1) else: return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) - + def reshape_for_broadcast( freqs_cis, @@ -344,7 +349,7 @@ def rotate_half(x): x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) ) # [B, S, H, D//2] return torch.stack([-x_imag, x_real], dim=-1).flatten(3) - + def apply_rotary_emb( xq: torch.Tensor, @@ -386,6 +391,15 @@ def attention(q, k, v): return x +def apply_gate(x, gate, tr_gate=None, tr_token=None): + if tr_gate is not None: + x_zero = x[:, :tr_token] * tr_gate.unsqueeze(1) + x_orig = x[:, tr_token:] * gate.unsqueeze(1) + return torch.concat((x_zero, x_orig), dim=1) + else: + return x * gate.unsqueeze(1) + + class MMDoubleStreamBlockComponent(torch.nn.Module): def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4): super().__init__() @@ -406,11 +420,17 @@ class MMDoubleStreamBlockComponent(torch.nn.Module): torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size) ) - def forward(self, hidden_states, conditioning, freqs_cis=None): + def forward(self, hidden_states, conditioning, freqs_cis=None, token_replace_vec=None, tr_token=None): mod1_shift, mod1_scale, mod1_gate, mod2_shift, mod2_scale, mod2_gate = self.mod(conditioning).chunk(6, dim=-1) + if token_replace_vec is not None: + assert tr_token is not None + tr_mod1_shift, tr_mod1_scale, tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = self.mod(token_replace_vec).chunk(6, dim=-1) + else: + tr_mod1_shift, tr_mod1_scale, tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = None, None, None, None, None, None norm_hidden_states = self.norm1(hidden_states) - norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale) + 
norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale, + tr_shift=tr_mod1_shift, tr_scale=tr_mod1_scale, tr_token=tr_token) qkv = self.to_qkv(norm_hidden_states) q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) @@ -419,15 +439,19 @@ class MMDoubleStreamBlockComponent(torch.nn.Module): if freqs_cis is not None: q, k = apply_rotary_emb(q, k, freqs_cis, head_first=False) + return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate), (tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate) - return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate) - - def process_ff(self, hidden_states, attn_output, mod): + def process_ff(self, hidden_states, attn_output, mod, mod_tr=None, tr_token=None): mod1_gate, mod2_shift, mod2_scale, mod2_gate = mod - hidden_states = hidden_states + self.to_out(attn_output) * mod1_gate.unsqueeze(1) - hidden_states = hidden_states + self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale)) * mod2_gate.unsqueeze(1) + if mod_tr is not None: + tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = mod_tr + else: + tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = None, None, None, None + hidden_states = hidden_states + apply_gate(self.to_out(attn_output), mod1_gate, tr_mod1_gate, tr_token) + x = self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale, tr_shift=tr_mod2_shift, tr_scale=tr_mod2_scale, tr_token=tr_token)) + hidden_states = hidden_states + apply_gate(x, mod2_gate, tr_mod2_gate, tr_token) return hidden_states - + class MMDoubleStreamBlock(torch.nn.Module): def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4): @@ -435,18 +459,18 @@ class MMDoubleStreamBlock(torch.nn.Module): self.component_a = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio) self.component_b = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio) - def forward(self, hidden_states_a, 
hidden_states_b, conditioning, freqs_cis): - (q_a, k_a, v_a), mod_a = self.component_a(hidden_states_a, conditioning, freqs_cis) - (q_b, k_b, v_b), mod_b = self.component_b(hidden_states_b, conditioning, freqs_cis=None) + def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis, token_replace_vec=None, tr_token=None, split_token=71): + (q_a, k_a, v_a), mod_a, mod_tr = self.component_a(hidden_states_a, conditioning, freqs_cis, token_replace_vec, tr_token) + (q_b, k_b, v_b), mod_b, _ = self.component_b(hidden_states_b, conditioning, freqs_cis=None) - q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous() - k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous() - v_a, v_b = torch.concat([v_a, v_b[:, :71]], dim=1), v_b[:, 71:].contiguous() + q_a, q_b = torch.concat([q_a, q_b[:, :split_token]], dim=1), q_b[:, split_token:].contiguous() + k_a, k_b = torch.concat([k_a, k_b[:, :split_token]], dim=1), k_b[:, split_token:].contiguous() + v_a, v_b = torch.concat([v_a, v_b[:, :split_token]], dim=1), v_b[:, split_token:].contiguous() attn_output_a = attention(q_a, k_a, v_a) attn_output_b = attention(q_b, k_b, v_b) - attn_output_a, attn_output_b = attn_output_a[:, :-71].contiguous(), torch.concat([attn_output_a[:, -71:], attn_output_b], dim=1) + attn_output_a, attn_output_b = attn_output_a[:, :-split_token].contiguous(), torch.concat([attn_output_a[:, -split_token:], attn_output_b], dim=1) - hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a) + hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a, mod_tr, tr_token) hidden_states_b = self.component_b.process_ff(hidden_states_b, attn_output_b, mod_b) return hidden_states_a, hidden_states_b @@ -489,7 +513,7 @@ class MMSingleStreamBlockOriginal(torch.nn.Module): output = self.linear2(torch.cat((attn_output, self.mlp_act(mlp)), 2)) return x + output * mod_gate.unsqueeze(1) - + class 
MMSingleStreamBlock(torch.nn.Module): def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4): @@ -510,11 +534,17 @@ class MMSingleStreamBlock(torch.nn.Module): torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size, bias=False) ) - def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256): + def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256, token_replace_vec=None, tr_token=None, split_token=71): mod_shift, mod_scale, mod_gate = self.mod(conditioning).chunk(3, dim=-1) + if token_replace_vec is not None: + assert tr_token is not None + tr_mod_shift, tr_mod_scale, tr_mod_gate = self.mod(token_replace_vec).chunk(3, dim=-1) + else: + tr_mod_shift, tr_mod_scale, tr_mod_gate = None, None, None norm_hidden_states = self.norm(hidden_states) - norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale) + norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale, + tr_shift=tr_mod_shift, tr_scale=tr_mod_scale, tr_token=tr_token) qkv = self.to_qkv(norm_hidden_states) q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) @@ -526,16 +556,17 @@ class MMSingleStreamBlock(torch.nn.Module): k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :] q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False) - q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous() - k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous() - v_a, v_b = v[:, :-185].contiguous(), v[:, -185:].contiguous() + v_len = txt_len - split_token + q_a, q_b = torch.concat([q_a, q_b[:, :split_token]], dim=1), q_b[:, split_token:].contiguous() + k_a, k_b = torch.concat([k_a, k_b[:, :split_token]], dim=1), k_b[:, split_token:].contiguous() + v_a, v_b = v[:, :-v_len].contiguous(), v[:, -v_len:].contiguous() attn_output_a = attention(q_a, k_a, v_a) attn_output_b = attention(q_b, k_b, v_b) attn_output = torch.concat([attn_output_a, 
attn_output_b], dim=1) - hidden_states = hidden_states + self.to_out(attn_output) * mod_gate.unsqueeze(1) - hidden_states = hidden_states + self.ff(norm_hidden_states) * mod_gate.unsqueeze(1) + hidden_states = hidden_states + apply_gate(self.to_out(attn_output), mod_gate, tr_mod_gate, tr_token) + hidden_states = hidden_states + apply_gate(self.ff(norm_hidden_states), mod_gate, tr_mod_gate, tr_token) return hidden_states @@ -581,7 +612,7 @@ class HunyuanVideoDiT(torch.nn.Module): def unpatchify(self, x, T, H, W): x = rearrange(x, "B (T H W) (C pT pH pW) -> B C (T pT) (H pH) (W pW)", H=H, W=W, pT=1, pH=2, pW=2) return x - + def enable_block_wise_offload(self, warm_device="cuda", cold_device="cpu"): self.warm_device = warm_device self.cold_device = cold_device @@ -616,7 +647,7 @@ class HunyuanVideoDiT(torch.nn.Module): vec += self.guidance_in(guidance * 1000, dtype=torch.float32) img = self.img_in(x) txt = self.txt_in(prompt_emb, t, text_mask) - + for block in tqdm(self.double_blocks, desc="Double stream blocks"): img, txt = block(img, txt, vec, (freqs_cos, freqs_sin)) @@ -628,7 +659,7 @@ class HunyuanVideoDiT(torch.nn.Module): img = self.final_layer(img, vec) img = self.unpatchify(img, T=T//1, H=H//2, W=W//2) return img - + def enable_auto_offload(self, dtype=torch.bfloat16, device="cuda"): def cast_to(weight, dtype=None, device=None, copy=False): @@ -684,7 +715,7 @@ class HunyuanVideoDiT(torch.nn.Module): del x_, weight_, bias_ torch.cuda.empty_cache() return y_ - + def block_forward(self, x, **kwargs): # This feature can only reduce 2GB VRAM, so we disable it. 
y = torch.zeros(x.shape[:-1] + (self.out_features,), dtype=x.dtype, device=x.device) @@ -692,19 +723,19 @@ class HunyuanVideoDiT(torch.nn.Module): for j in range((self.out_features + self.block_size - 1) // self.block_size): y[..., j * self.block_size: (j + 1) * self.block_size] += self.block_forward_(x, i, j, dtype=x.dtype, device=x.device) return y - + def forward(self, x, **kwargs): weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device) return torch.nn.functional.linear(x, weight, bias) - + class RMSNorm(torch.nn.Module): def __init__(self, module, dtype=torch.bfloat16, device="cuda"): super().__init__() self.module = module self.dtype = dtype self.device = device - + def forward(self, hidden_states, **kwargs): input_dtype = hidden_states.dtype variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True) @@ -714,30 +745,30 @@ class HunyuanVideoDiT(torch.nn.Module): weight = cast_weight(self.module, hidden_states, dtype=torch.bfloat16, device="cuda") hidden_states = hidden_states * weight return hidden_states - + class Conv3d(torch.nn.Conv3d): def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs): super().__init__(*args, **kwargs) self.dtype = dtype self.device = device - + def forward(self, x): weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device) return torch.nn.functional.conv3d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups) - + class LayerNorm(torch.nn.LayerNorm): def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs): super().__init__(*args, **kwargs) self.dtype = dtype self.device = device - + def forward(self, x): if self.weight is not None and self.bias is not None: weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device) return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps) else: return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 
- + def replace_layer(model, dtype=torch.bfloat16, device="cuda"): for name, module in model.named_children(): if isinstance(module, torch.nn.Linear): @@ -780,7 +811,6 @@ class HunyuanVideoDiT(torch.nn.Module): return HunyuanVideoDiTStateDictConverter() - class HunyuanVideoDiTStateDictConverter: def __init__(self): pass @@ -886,6 +916,5 @@ class HunyuanVideoDiTStateDictConverter: state_dict_[name_] = param else: pass - if origin_hash_key == "ae3c22aaa28bfae6f3688f796c9814ae": - return state_dict_, {"in_channels": 33, "guidance_embed":False} + return state_dict_ diff --git a/diffsynth/pipelines/hunyuan_video.py b/diffsynth/pipelines/hunyuan_video.py index 6fbbae9..d8a0411 100644 --- a/diffsynth/pipelines/hunyuan_video.py +++ b/diffsynth/pipelines/hunyuan_video.py @@ -5,13 +5,13 @@ from ..schedulers.flow_match import FlowMatchScheduler from .base import BasePipeline from ..prompters import HunyuanVideoPrompter import torch +import torchvision.transforms as transforms from einops import rearrange import numpy as np from PIL import Image from tqdm import tqdm - class HunyuanVideoPipeline(BasePipeline): def __init__(self, device="cuda", torch_dtype=torch.float16): @@ -53,10 +53,58 @@ class HunyuanVideoPipeline(BasePipeline): pipe.enable_vram_management() return pipe + def generate_crop_size_list(self, base_size=256, patch_size=32, max_ratio=4.0): + num_patches = round((base_size / patch_size)**2) + assert max_ratio >= 1.0 + crop_size_list = [] + wp, hp = num_patches, 1 + while wp > 0: + if max(wp, hp) / min(wp, hp) <= max_ratio: + crop_size_list.append((wp * patch_size, hp * patch_size)) + if (hp + 1) * wp <= num_patches: + hp += 1 + else: + wp -= 1 + return crop_size_list - def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256): + + def get_closest_ratio(self, height: float, width: float, ratios: list, buckets: list): + aspect_ratio = float(height) / float(width) + closest_ratio_id = np.abs(ratios - aspect_ratio).argmin() + 
closest_ratio = min(ratios, key=lambda ratio: abs(float(ratio) - aspect_ratio)) + return buckets[closest_ratio_id], float(closest_ratio) + + + def prepare_vae_images_inputs(self, semantic_images, i2v_resolution="720p"): + if i2v_resolution == "720p": + bucket_hw_base_size = 960 + elif i2v_resolution == "540p": + bucket_hw_base_size = 720 + elif i2v_resolution == "360p": + bucket_hw_base_size = 480 + else: + raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]") + origin_size = semantic_images[0].size + + crop_size_list = self.generate_crop_size_list(bucket_hw_base_size, 32) + aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list]) + closest_size, closest_ratio = self.get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list) + ref_image_transform = transforms.Compose([ + transforms.Resize(closest_size), + transforms.CenterCrop(closest_size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]) + ]) + + semantic_image_pixel_values = [ref_image_transform(semantic_image) for semantic_image in semantic_images] + semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(self.device) + target_height, target_width = closest_size + return semantic_image_pixel_values, target_height, target_width + + + def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256, input_images=None): prompt_emb, pooled_prompt_emb, text_mask = self.prompter.encode_prompt( - prompt, device=self.device, positive=positive, clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length + prompt, device=self.device, positive=positive, clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length, images=input_images ) return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_mask": text_mask} @@ -87,6 +135,9 @@ class HunyuanVideoPipeline(BasePipeline): prompt, 
negative_prompt="", input_video=None, + input_images=None, + i2v_resolution="720p", + i2v_stability=True, denoising_strength=1.0, seed=None, rand_device=None, @@ -105,10 +156,17 @@ class HunyuanVideoPipeline(BasePipeline): ): # Tiler parameters tiler_kwargs = {"tile_size": tile_size, "tile_stride": tile_stride} - + # Scheduler self.scheduler.set_timesteps(num_inference_steps, denoising_strength) + # encoder input images + if input_images is not None: + self.load_models_to_device(['vae_encoder']) + image_pixel_values, height, width = self.prepare_vae_images_inputs(input_images, i2v_resolution=i2v_resolution) + with torch.autocast(device_type=self.device, dtype=torch.float16, enabled=True): + image_latents = self.vae_encoder(image_pixel_values) + # Initialize noise rand_device = self.device if rand_device is None else rand_device noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=self.torch_dtype).to(self.device) @@ -118,12 +176,18 @@ class HunyuanVideoPipeline(BasePipeline): input_video = torch.stack(input_video, dim=2) latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device) latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) + elif input_images is not None and i2v_stability: + noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=image_latents.dtype).to(self.device) + t = torch.tensor([0.999]).to(device=self.device) + latents = noise * t + image_latents.repeat(1, 1, (num_frames - 1) // 4 + 1, 1, 1) * (1 - t) + latents = latents.to(dtype=image_latents.dtype) else: latents = noise - + # Encode prompts - self.load_models_to_device(["text_encoder_1"] if self.vram_management else ["text_encoder_1", "text_encoder_2"]) - prompt_emb_posi = self.encode_prompt(prompt, positive=True) + # current mllm does not support vram_management + 
self.load_models_to_device(["text_encoder_1"] if self.vram_management and input_images is None else ["text_encoder_1", "text_encoder_2"]) + prompt_emb_posi = self.encode_prompt(prompt, positive=True, input_images=input_images) if cfg_scale != 1.0: prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False) @@ -139,11 +203,16 @@ class HunyuanVideoPipeline(BasePipeline): timestep = timestep.unsqueeze(0).to(self.device) print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}") + forward_func = lets_dance_hunyuan_video + if input_images is not None: + latents = torch.concat([image_latents, latents[:, :, 1:, :, :]], dim=2) + forward_func = lets_dance_hunyuan_video_i2v + # Inference with torch.autocast(device_type=self.device, dtype=self.torch_dtype): - noise_pred_posi = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_posi, **extra_input, **tea_cache_kwargs) + noise_pred_posi = forward_func(self.dit, latents, timestep, **prompt_emb_posi, **extra_input, **tea_cache_kwargs) if cfg_scale != 1.0: - noise_pred_nega = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_nega, **extra_input) + noise_pred_nega = forward_func(self.dit, latents, timestep, **prompt_emb_nega, **extra_input) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) else: noise_pred = noise_pred_posi @@ -163,7 +232,11 @@ class HunyuanVideoPipeline(BasePipeline): self.load_models_to_device([] if self.vram_management else ["dit"]) # Scheduler - latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents) + if input_images is not None: + latents = self.scheduler.step(noise_pred[:, :, 1:, :, :], self.scheduler.timesteps[progress_id], latents[:, :, 1:, :, :]) + latents = torch.concat([image_latents, latents], dim=2) + else: + latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents) # Decode self.load_models_to_device(['vae_decoder']) @@ -194,7 +267,7 @@ class TeaCache: if 
self.step == 0 or self.step == self.num_inference_steps - 1: should_calc = True self.accumulated_rel_l1_distance = 0 - else: + else: coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02] rescale_func = np.poly1d(coefficients) self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) @@ -203,14 +276,14 @@ class TeaCache: else: should_calc = True self.accumulated_rel_l1_distance = 0 - self.previous_modulated_input = modulated_inp + self.previous_modulated_input = modulated_inp self.step += 1 if self.step == self.num_inference_steps: self.step = 0 if should_calc: self.previous_hidden_states = img.clone() return not should_calc - + def store(self, hidden_states): self.previous_residual = hidden_states - self.previous_hidden_states self.previous_hidden_states = None @@ -250,13 +323,70 @@ def lets_dance_hunyuan_video( print("TeaCache skip forward.") img = tea_cache.update(img) else: + split_token = int(text_mask.sum(dim=1)) + txt_len = int(txt.shape[1]) for block in tqdm(dit.double_blocks, desc="Double stream blocks"): - img, txt = block(img, txt, vec, (freqs_cos, freqs_sin)) - + img, txt = block(img, txt, vec, (freqs_cos, freqs_sin), split_token=split_token) + x = torch.concat([img, txt], dim=1) for block in tqdm(dit.single_blocks, desc="Single stream blocks"): - x = block(x, vec, (freqs_cos, freqs_sin)) - img = x[:, :-256] + x = block(x, vec, (freqs_cos, freqs_sin), txt_len=txt_len, split_token=split_token) + img = x[:, :-txt_len] + + if tea_cache is not None: + tea_cache.store(img) + img = dit.final_layer(img, vec) + img = dit.unpatchify(img, T=T//1, H=H//2, W=W//2) + return img + + +def lets_dance_hunyuan_video_i2v( + dit: HunyuanVideoDiT, + x: torch.Tensor, + t: torch.Tensor, + prompt_emb: torch.Tensor = None, + text_mask: torch.Tensor = None, + pooled_prompt_emb: torch.Tensor = None, + freqs_cos: torch.Tensor = 
None, + freqs_sin: torch.Tensor = None, + guidance: torch.Tensor = None, + tea_cache: TeaCache = None, + **kwargs +): + B, C, T, H, W = x.shape + # Uncomment below to keep same as official implementation + # guidance = guidance.to(dtype=torch.float32).to(torch.bfloat16) + vec = dit.time_in(t, dtype=torch.bfloat16) + vec_2 = dit.vector_in(pooled_prompt_emb) + vec = vec + vec_2 + vec = vec + dit.guidance_in(guidance * 1000., dtype=torch.bfloat16) + + token_replace_vec = dit.time_in(torch.zeros_like(t), dtype=torch.bfloat16) + tr_token = (H // 2) * (W // 2) + token_replace_vec = token_replace_vec + vec_2 + + img = dit.img_in(x) + txt = dit.txt_in(prompt_emb, t, text_mask) + + # TeaCache + if tea_cache is not None: + tea_cache_update = tea_cache.check(dit, img, vec) + else: + tea_cache_update = False + + if tea_cache_update: + print("TeaCache skip forward.") + img = tea_cache.update(img) + else: + split_token = int(text_mask.sum(dim=1)) + txt_len = int(txt.shape[1]) + for block in tqdm(dit.double_blocks, desc="Double stream blocks"): + img, txt = block(img, txt, vec, (freqs_cos, freqs_sin), token_replace_vec, tr_token, split_token) + + x = torch.concat([img, txt], dim=1) + for block in tqdm(dit.single_blocks, desc="Single stream blocks"): + x = block(x, vec, (freqs_cos, freqs_sin), txt_len, token_replace_vec, tr_token, split_token) + img = x[:, :-txt_len] if tea_cache is not None: tea_cache.store(img) diff --git a/diffsynth/prompters/hunyuan_video_prompter.py b/diffsynth/prompters/hunyuan_video_prompter.py index 26dc5c3..5b97356 100644 --- a/diffsynth/prompters/hunyuan_video_prompter.py +++ b/diffsynth/prompters/hunyuan_video_prompter.py @@ -87,7 +87,6 @@ class HunyuanVideoPrompter(BasePrompter): self.tokenizer_2 = LlamaTokenizerFast.from_pretrained(tokenizer_2_path, padding_side='right') self.text_encoder_1: SD3TextEncoder1 = None self.text_encoder_2: HunyuanVideoLLMEncoder = None - self.i2v_mode = False self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode'] 
self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video'] @@ -106,8 +105,6 @@ class HunyuanVideoPrompter(BasePrompter): # template self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode-i2v'] self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video-i2v'] - # mode setting - self.i2v_mode = True def apply_text_to_template(self, text, template): assert isinstance(template, str) @@ -164,10 +161,8 @@ class HunyuanVideoPrompter(BasePrompter): crop_start, hidden_state_skip_layer=2, use_attention_mask=True, - image_embed_interleave=2): - image_outputs = self.processor(images, return_tensors="pt")[ - "pixel_values" - ].to(device) + image_embed_interleave=4): + image_outputs = self.processor(images, return_tensors="pt")["pixel_values"].to(device) max_length += crop_start inputs = self.tokenizer_2(prompt, return_tensors="pt", @@ -248,7 +243,8 @@ class HunyuanVideoPrompter(BasePrompter): data_type='video', use_template=True, hidden_state_skip_layer=2, - use_attention_mask=True): + use_attention_mask=True, + image_embed_interleave=4): prompt = self.process_prompt(prompt, positive=positive) @@ -273,6 +269,7 @@ class HunyuanVideoPrompter(BasePrompter): hidden_state_skip_layer, use_attention_mask) else: prompt_emb, attention_mask = self.encode_prompt_using_mllm(prompt_formated, images, llm_sequence_length, device, - crop_start, hidden_state_skip_layer, use_attention_mask) + crop_start, hidden_state_skip_layer, use_attention_mask, + image_embed_interleave) return prompt_emb, pooled_prompt_emb, attention_mask diff --git a/examples/HunyuanVideo/README.md b/examples/HunyuanVideo/README.md index 113f131..c1359b5 100644 --- a/examples/HunyuanVideo/README.md +++ b/examples/HunyuanVideo/README.md @@ -8,6 +8,12 @@ |24G|[hunyuanvideo_24G.py](hunyuanvideo_24G.py)|129|720*1280|The video is consistent with the original implementation, but it requires 5%~10% more time than [hunyuanvideo_80G.py](hunyuanvideo_80G.py)| 
|6G|[hunyuanvideo_6G.py](hunyuanvideo_6G.py)|129|512*384|The base model doesn't support low resolutions. We recommend users to use some LoRA ([example](https://civitai.com/models/1032126/walking-animation-hunyuan-video)) trained using low resolutions.| +[HunyuanVideo-I2V](https://github.com/Tencent/HunyuanVideo-I2V) is the image-to-video generation version of HunyuanVideo. We also provide advanced VRAM management for this model. +|VRAM required|Example script|Frames|Resolution|Note| +|-|-|-|-|-| +|80G|[hunyuanvideo_i2v_80G.py](hunyuanvideo_i2v_80G.py)|129|720p|No VRAM management.| +|24G|[hunyuanvideo_i2v_24G.py](hunyuanvideo_i2v_24G.py)|129|720p|The video is consistent with the original implementation, but it requires 5%~10% more time than [hunyuanvideo_i2v_80G.py](hunyuanvideo_i2v_80G.py)| + ## Gallery Video generated by [hunyuanvideo_80G.py](hunyuanvideo_80G.py) and [hunyuanvideo_24G.py](hunyuanvideo_24G.py): @@ -21,3 +27,7 @@ https://github.com/user-attachments/assets/2997f107-d02d-4ecb-89bb-5ce1a7f93817 Video to video generated by [hunyuanvideo_v2v_6G.py](./hunyuanvideo_v2v_6G.py) using [this LoRA](https://civitai.com/models/1032126/walking-animation-hunyuan-video): https://github.com/user-attachments/assets/4b89e52e-ce42-434e-aa57-08f09dfa2b10 + +Video generated by [hunyuanvideo_i2v_80G.py](hunyuanvideo_i2v_80G.py) and [hunyuanvideo_i2v_24G.py](hunyuanvideo_i2v_24G.py): + +https://github.com/user-attachments/assets/494f252a-c9af-440d-84ba-a8ddcdcc538a diff --git a/examples/HunyuanVideo/hunyuanvideo_i2v.py b/examples/HunyuanVideo/hunyuanvideo_i2v.py deleted file mode 100644 index 26d28a1..0000000 --- a/examples/HunyuanVideo/hunyuanvideo_i2v.py +++ /dev/null @@ -1,88 +0,0 @@ -import torch -from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video -from diffsynth.prompters.hunyuan_video_prompter import HunyuanVideoPrompter -from PIL import Image -import numpy as np -import torchvision.transforms as transforms - - -def 
generate_crop_size_list(base_size=256, patch_size=32, max_ratio=4.0): - num_patches = round((base_size / patch_size)**2) - assert max_ratio >= 1.0 - crop_size_list = [] - wp, hp = num_patches, 1 - while wp > 0: - if max(wp, hp) / min(wp, hp) <= max_ratio: - crop_size_list.append((wp * patch_size, hp * patch_size)) - if (hp + 1) * wp <= num_patches: - hp += 1 - else: - wp -= 1 - return crop_size_list - - -def get_closest_ratio(height: float, width: float, ratios: list, buckets: list): - aspect_ratio = float(height) / float(width) - closest_ratio_id = np.abs(ratios - aspect_ratio).argmin() - closest_ratio = min(ratios, key=lambda ratio: abs(float(ratio) - aspect_ratio)) - return buckets[closest_ratio_id], float(closest_ratio) - - -def prepare_vae_inputs(semantic_images, i2v_resolution="720p"): - if i2v_resolution == "720p": - bucket_hw_base_size = 960 - elif i2v_resolution == "540p": - bucket_hw_base_size = 720 - elif i2v_resolution == "360p": - bucket_hw_base_size = 480 - else: - raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]") - origin_size = semantic_images[0].size - - crop_size_list = generate_crop_size_list(bucket_hw_base_size, 32) - aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list]) - closest_size, closest_ratio = get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list) - ref_image_transform = transforms.Compose([ - transforms.Resize(closest_size), - transforms.CenterCrop(closest_size), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]) - ]) - - semantic_image_pixel_values = [ref_image_transform(semantic_image) for semantic_image in semantic_images] - semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2) - return semantic_image_pixel_values - - -model_manager = ModelManager() - -# The other modules are loaded in float16. 
- -model_manager.load_models( - [ - "models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt" - ], - torch_dtype=torch.bfloat16, # you can use torch_dtype=torch.float8_e4m3fn to enable quantization. - device="cuda" -) - -model_manager.load_models( - [ - "models/HunyuanVideo/text_encoder/model.safetensors", - "models/HunyuanVideoI2V/text_encoder_2", - 'models/HunyuanVideoI2V/vae/pytorch_model.pt' - - ], - torch_dtype=torch.float16, - device="cuda" -) -# The computation device is "cuda". -pipe = HunyuanVideoPipeline.from_model_manager( - model_manager, - torch_dtype=torch.bfloat16, - device="cuda", - enable_vram_management=False -) -# Although you have enough VRAM, we still recommend you to enable offload. -pipe.enable_cpu_offload() -print() \ No newline at end of file diff --git a/examples/HunyuanVideo/hunyuanvideo_i2v_24G.py b/examples/HunyuanVideo/hunyuanvideo_i2v_24G.py new file mode 100644 index 0000000..91ae0e4 --- /dev/null +++ b/examples/HunyuanVideo/hunyuanvideo_i2v_24G.py @@ -0,0 +1,43 @@ +import torch +from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video +from modelscope import dataset_snapshot_download +from PIL import Image + + +download_models(["HunyuanVideoI2V"]) +model_manager = ModelManager() + +# The DiT model is loaded in bfloat16. +model_manager.load_models( + [ + "models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt" + ], + torch_dtype=torch.bfloat16, + device="cpu" +) + +# The other modules are loaded in float16. +model_manager.load_models( + [ + "models/HunyuanVideoI2V/text_encoder/model.safetensors", + "models/HunyuanVideoI2V/text_encoder_2", + 'models/HunyuanVideoI2V/vae/pytorch_model.pt' + ], + torch_dtype=torch.float16, + device="cpu" +) +# The computation device is "cuda". 
+pipe = HunyuanVideoPipeline.from_model_manager(model_manager, + torch_dtype=torch.bfloat16, + device="cuda", + enable_vram_management=True) + +dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/hunyuanvideo/*") + +i2v_resolution = "720p" +prompt = "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick." +images = [Image.open("data/examples/hunyuanvideo/0.jpg").convert('RGB')] +video = pipe(prompt, input_images=images, num_inference_steps=50, seed=0, i2v_resolution=i2v_resolution) +save_video(video, f"video_{i2v_resolution}_low_vram.mp4", fps=30, quality=6) diff --git a/examples/HunyuanVideo/hunyuanvideo_i2v_80G.py b/examples/HunyuanVideo/hunyuanvideo_i2v_80G.py new file mode 100644 index 0000000..fcc9f62 --- /dev/null +++ b/examples/HunyuanVideo/hunyuanvideo_i2v_80G.py @@ -0,0 +1,45 @@ +import torch +from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video +from modelscope import dataset_snapshot_download +from PIL import Image + + +download_models(["HunyuanVideoI2V"]) +model_manager = ModelManager() + +# The DiT model is loaded in bfloat16. +model_manager.load_models( + [ + "models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt" + ], + torch_dtype=torch.bfloat16, + device="cuda" +) + +# The other modules are loaded in float16. +model_manager.load_models( + [ + "models/HunyuanVideoI2V/text_encoder/model.safetensors", + "models/HunyuanVideoI2V/text_encoder_2", + 'models/HunyuanVideoI2V/vae/pytorch_model.pt' + ], + torch_dtype=torch.float16, + device="cuda" +) +# The computation device is "cuda". +pipe = HunyuanVideoPipeline.from_model_manager(model_manager, + torch_dtype=torch.bfloat16, + device="cuda", + enable_vram_management=False) +# Although you have enough VRAM, we still recommend you to enable offload. 
+pipe.enable_cpu_offload() + +dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/hunyuanvideo/*") + +i2v_resolution = "720p" +prompt = "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick." +images = [Image.open("data/examples/hunyuanvideo/0.jpg").convert('RGB')] +video = pipe(prompt, input_images=images, num_inference_steps=50, seed=0, i2v_resolution=i2v_resolution) +save_video(video, f"video_{i2v_resolution}.mp4", fps=30, quality=6)