From 405ca6be33b4bf938695302d3a5159aa109d47fd Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 23 Dec 2024 20:43:47 +0800 Subject: [PATCH] support hunyuanvideo v2v --- diffsynth/models/hunyuan_video_vae_decoder.py | 2 +- diffsynth/models/hunyuan_video_vae_encoder.py | 94 ++++++++++++++++++- diffsynth/models/lora.py | 4 +- diffsynth/pipelines/hunyuan_video.py | 35 ++++--- examples/HunyuanVideo/README.md | 4 + examples/HunyuanVideo/hunyuanvideo_v2v_6G.py | 50 ++++++++++ 6 files changed, 173 insertions(+), 16 deletions(-) create mode 100644 examples/HunyuanVideo/hunyuanvideo_v2v_6G.py diff --git a/diffsynth/models/hunyuan_video_vae_decoder.py b/diffsynth/models/hunyuan_video_vae_decoder.py index 700f10e..ae09ff8 100644 --- a/diffsynth/models/hunyuan_video_vae_decoder.py +++ b/diffsynth/models/hunyuan_video_vae_decoder.py @@ -453,7 +453,7 @@ class HunyuanVideoVAEDecoder(nn.Module): weight = torch.zeros((1, 1, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device) values = torch.zeros((B, 3, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device) - for t, t_, h, h_, w, w_ in tqdm(tasks): + for t, t_, h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"): hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device) hidden_states_batch = self.forward(hidden_states_batch).to(data_device) if t > 0: diff --git a/diffsynth/models/hunyuan_video_vae_encoder.py b/diffsynth/models/hunyuan_video_vae_encoder.py index ec7fd14..faaaeb9 100644 --- a/diffsynth/models/hunyuan_video_vae_encoder.py +++ b/diffsynth/models/hunyuan_video_vae_encoder.py @@ -1,8 +1,9 @@ import torch import torch.nn as nn import torch.nn.functional as F -from einops import rearrange +from einops import rearrange, repeat import numpy as np +from tqdm import tqdm from .hunyuan_video_vae_decoder import CausalConv3d, ResnetBlockCausal3D, UNetMidBlockCausal3D @@ -192,12 +193,101 @@ class HunyuanVideoVAEEncoder(nn.Module): gradient_checkpointing=gradient_checkpointing, ) self.quant_conv = nn.Conv3d(2 * out_channels, 2 * out_channels, kernel_size=1) + self.scaling_factor = 0.476986 + def forward(self, images): latents = self.encoder(images) latents = self.quant_conv(latents) - # latents: (B C T H W) + latents = latents[:, :16] + latents = latents * self.scaling_factor return latents + + + def build_1d_mask(self, length, left_bound, right_bound, border_width): + x = torch.ones((length,)) + if not left_bound: + x[:border_width] = (torch.arange(border_width) + 1) / border_width + if not right_bound: + x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,)) + return x + + + def build_mask(self, data, is_bound, border_width): + _, _, T, H, W = data.shape + t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0]) + h = self.build_1d_mask(H, is_bound[2], is_bound[3], border_width[1]) + w = self.build_1d_mask(W, is_bound[4], is_bound[5], border_width[2]) + + t = repeat(t, "T -> T H W", T=T, H=H, W=W) + h = repeat(h, "H -> T H W", T=T, H=H, W=W) + w = repeat(w, "W -> T H W", T=T, H=H, W=W) + + mask = torch.stack([t, h, w]).min(dim=0).values + mask = rearrange(mask, "T H W -> 1 1 T H W") + return mask + + + def tile_forward(self, hidden_states, tile_size, tile_stride): + B, C, T, H, W = hidden_states.shape + size_t, size_h, size_w = tile_size + stride_t, stride_h, stride_w = tile_stride + + # Split tasks + tasks = [] + for t in range(0, T, stride_t): + if (t-stride_t >= 0 and t-stride_t+size_t >= T): continue + for h in range(0, H, stride_h): + if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue + for w in range(0, W, stride_w): + if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue + t_, h_, w_ = t + size_t, h + size_h, w + size_w + tasks.append((t, t_, h, h_, w, w_)) + + # Run + torch_dtype = self.quant_conv.weight.dtype + data_device = hidden_states.device + computation_device = self.quant_conv.weight.device + + weight = torch.zeros((1, 1, (T - 1) // 4 + 1, H // 8, W // 8), dtype=torch_dtype, device=data_device) + values = torch.zeros((B, 16, (T - 1) // 4 + 1, H // 8, W // 8), dtype=torch_dtype, device=data_device) + + for t, t_, h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"): + hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device) + hidden_states_batch = self.forward(hidden_states_batch).to(data_device) + if t > 0: + hidden_states_batch = hidden_states_batch[:, :, 1:] + + mask = self.build_mask( + hidden_states_batch, + is_bound=(t==0, t_>=T, h==0, h_>=H, w==0, w_>=W), + border_width=((size_t - stride_t) // 4, (size_h - stride_h) // 8, (size_w - stride_w) // 8) + ).to(dtype=torch_dtype, device=data_device) + + target_t = 0 if t==0 else t // 4 + 1 + target_h = h // 8 + target_w = w // 8 + values[ + :, + :, + target_t: target_t + hidden_states_batch.shape[2], + target_h: target_h + hidden_states_batch.shape[3], + target_w: target_w + hidden_states_batch.shape[4], + ] += hidden_states_batch * mask + weight[ + :, + :, + target_t: target_t + hidden_states_batch.shape[2], + target_h: target_h + hidden_states_batch.shape[3], + target_w: target_w + hidden_states_batch.shape[4], + ] += mask + return values / weight + + + def encode_video(self, latents, tile_size=(65, 256, 256), tile_stride=(48, 192, 192)): + latents = latents.to(self.quant_conv.weight.dtype) + return self.tile_forward(latents, tile_size=tile_size, tile_stride=tile_stride) + @staticmethod def state_dict_converter(): diff --git a/diffsynth/models/lora.py b/diffsynth/models/lora.py index 33f952f..81a9a8a 100644 --- a/diffsynth/models/lora.py +++ b/diffsynth/models/lora.py @@ -263,8 +263,8 @@ class GeneralLoRAFromPeft: class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai): def __init__(self): super().__init__() - self.supported_model_classes = [HunyuanVideoDiT] - self.lora_prefix = ["diffusion_model."] + self.supported_model_classes = [HunyuanVideoDiT, HunyuanVideoDiT] + self.lora_prefix = ["diffusion_model.", "transformer."] self.special_keys = {} diff --git a/diffsynth/pipelines/hunyuan_video.py b/diffsynth/pipelines/hunyuan_video.py index 53e527f..ad81ca4 100644 --- a/diffsynth/pipelines/hunyuan_video.py +++ b/diffsynth/pipelines/hunyuan_video.py @@ -72,16 +72,21 @@ class HunyuanVideoPipeline(BasePipeline): frames = [Image.fromarray(frame) for frame in frames] return frames - def encode_video(self, frames): - # frames : (B, C, T, H, W) - latents = self.vae_encoder(frames) + + def encode_video(self, frames, tile_size=(17, 30, 30), tile_stride=(12, 20, 20)): + tile_size = ((tile_size[0] - 1) * 4 + 1, tile_size[1] * 8, tile_size[2] * 8) + tile_stride = (tile_stride[0] * 4, tile_stride[1] * 8, tile_stride[2] * 8) + latents = self.vae_encoder.encode_video(frames, tile_size=tile_size, tile_stride=tile_stride) return latents - + + @torch.no_grad() def __call__( self, prompt, negative_prompt="", + input_video=None, + denoising_strength=1.0, seed=None, height=720, width=1280, @@ -94,8 +99,22 @@ class HunyuanVideoPipeline(BasePipeline): progress_bar_cmd=lambda x: x, progress_bar_st=None, ): + # Tiler parameters + tiler_kwargs = {"tile_size": tile_size, "tile_stride": tile_stride} + + # Scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength) + # Initialize noise - latents = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) + noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) + if input_video is not None: + self.load_models_to_device(['vae_encoder']) + input_video = self.preprocess_images(input_video) + input_video = torch.stack(input_video, dim=2) + latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device) + latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) + else: + latents = noise # Encode prompts self.load_models_to_device(["text_encoder_1"] if self.vram_management else ["text_encoder_1", "text_encoder_2"]) @@ -106,9 +125,6 @@ class HunyuanVideoPipeline(BasePipeline): # Extra input extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance) - # Scheduler - self.scheduler.set_timesteps(num_inference_steps) - # Denoise self.load_models_to_device([] if self.vram_management else ["dit"]) for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): @@ -126,9 +142,6 @@ class HunyuanVideoPipeline(BasePipeline): # Scheduler latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents) - - # Tiler parameters - tiler_kwargs = {"tile_size": tile_size, "tile_stride": tile_stride} # Decode self.load_models_to_device(['vae_decoder']) diff --git a/examples/HunyuanVideo/README.md b/examples/HunyuanVideo/README.md index f26b280..113f131 100644 --- a/examples/HunyuanVideo/README.md +++ b/examples/HunyuanVideo/README.md @@ -17,3 +17,7 @@ https://github.com/user-attachments/assets/48dd24bb-0cc6-40d2-88c3-10feed3267e9 Video generated by [hunyuanvideo_6G.py](hunyuanvideo_6G.py) using [this LoRA](https://civitai.com/models/1032126/walking-animation-hunyuan-video): https://github.com/user-attachments/assets/2997f107-d02d-4ecb-89bb-5ce1a7f93817 + +Video to video generated by [hunyuanvideo_v2v_6G.py](./hunyuanvideo_v2v_6G.py) using [this LoRA](https://civitai.com/models/1032126/walking-animation-hunyuan-video): + +https://github.com/user-attachments/assets/4b89e52e-ce42-434e-aa57-08f09dfa2b10 diff --git a/examples/HunyuanVideo/hunyuanvideo_v2v_6G.py b/examples/HunyuanVideo/hunyuanvideo_v2v_6G.py new file mode 100644 index 0000000..50c68f1 --- /dev/null +++ b/examples/HunyuanVideo/hunyuanvideo_v2v_6G.py @@ -0,0 +1,50 @@ +import torch +torch.cuda.set_per_process_memory_fraction(6/80, 0) +from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video, FlowMatchScheduler + + +download_models(["HunyuanVideo"]) +model_manager = ModelManager() + +# The DiT model is loaded in bfloat16. +model_manager.load_models( + [ + "models/HunyuanVideo/transformers/mp_rank_00_model_states.pt" + ], + torch_dtype=torch.bfloat16, # you can use torch_dtype=torch.float8_e4m3fn to enable quantization. + device="cpu" +) + +# The other modules are loaded in float16. +model_manager.load_models( + [ + "models/HunyuanVideo/text_encoder/model.safetensors", + "models/HunyuanVideo/text_encoder_2", + "models/HunyuanVideo/vae/pytorch_model.pt", + ], + torch_dtype=torch.float16, + device="cpu" +) + +# We support LoRA inference. You can use the following code to load your LoRA model. +# Example LoRA: https://civitai.com/models/1032126/walking-animation-hunyuan-video +model_manager.load_lora("models/lora/kxsr_walking_anim_v1-5.safetensors", lora_alpha=1.0) + +# The computation device is "cuda". +pipe = HunyuanVideoPipeline.from_model_manager( + model_manager, + torch_dtype=torch.bfloat16, + device="cuda" +) +# This LoRA requires shift=9.0. +pipe.scheduler = FlowMatchScheduler(shift=9.0, sigma_min=0.0, extra_one_step=True) + +# Text-to-video +prompt = f"kxsr, full body, no crop. A girl is walking. CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her." +video = pipe(prompt, seed=1, height=512, width=384, num_frames=129, num_inference_steps=18, tile_size=(17, 16, 16), tile_stride=(12, 12, 12)) +save_video(video, f"video.mp4", fps=30, quality=6) + +# Video-to-video +prompt = f"kxsr, full body, no crop. A girl is walking. CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, purple dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her." +video = pipe(prompt, seed=1, height=512, width=384, num_frames=129, num_inference_steps=18, tile_size=(17, 16, 16), tile_stride=(12, 12, 12), input_video=video, denoising_strength=0.85) +save_video(video, f"video_edited.mp4", fps=30, quality=6)