From e5099f4e7446ed70ade8b7ee3c93a9a2f4cb40cd Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Wed, 18 Dec 2024 16:43:06 +0800 Subject: [PATCH] hunyuanvideo --- diffsynth/configs/model_config.py | 6 +- diffsynth/models/hunyuan_video_vae_decoder.py | 97 +++++++++++++++++-- diffsynth/models/lora.py | 11 ++- diffsynth/pipelines/hunyuan_video.py | 35 +++++-- examples/video_synthesis/hunyuanvideo.py | 50 +++++++--- 5 files changed, 168 insertions(+), 31 deletions(-) diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index 5c68827..b4f08ba 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -643,12 +643,14 @@ preset_models_on_modelscope = { ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"), ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"), ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"), - ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae") + ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"), + ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideo/transformers") ], "load_path": [ "models/HunyuanVideo/text_encoder/model.safetensors", "models/HunyuanVideo/text_encoder_2", - "models/HunyuanVideo/vae/pytorch_model.pt" + "models/HunyuanVideo/vae/pytorch_model.pt", + "models/HunyuanVideo/transformers/mp_rank_00_model_states.pt" ], }, } diff --git a/diffsynth/models/hunyuan_video_vae_decoder.py b/diffsynth/models/hunyuan_video_vae_decoder.py index 69d9d9b..700f10e 100644 --- a/diffsynth/models/hunyuan_video_vae_decoder.py +++ b/diffsynth/models/hunyuan_video_vae_decoder.py @@ -3,6 +3,8 @@ import torch.nn as nn import torch.nn.functional as F from einops import rearrange import numpy as np +from tqdm import tqdm +from einops import repeat class CausalConv3d(nn.Module): @@ -393,16 +395,99 @@ class HunyuanVideoVAEDecoder(nn.Module): gradient_checkpointing=gradient_checkpointing, ) self.post_quant_conv = nn.Conv3d(in_channels, in_channels, kernel_size=1) + self.scaling_factor = 0.476986 - def decode_video(self, latents, use_temporal_tiling=False, use_spatial_tiling=False, sample_ssize=256, sample_tsize=64): - if use_temporal_tiling: - raise NotImplementedError - if use_spatial_tiling: - raise NotImplementedError - # no tiling + + def forward(self, latents): + latents = latents / self.scaling_factor latents = self.post_quant_conv(latents) dec = self.decoder(latents) return dec + + + def build_1d_mask(self, length, left_bound, right_bound, border_width): + x = torch.ones((length,)) + if not left_bound: + x[:border_width] = (torch.arange(border_width) + 1) / border_width + if not right_bound: + x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,)) + return x + + + def build_mask(self, data, is_bound, border_width): + _, _, T, H, W = data.shape + t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0]) + h = self.build_1d_mask(H, is_bound[2], is_bound[3], border_width[1]) + w = self.build_1d_mask(W, is_bound[4], is_bound[5], border_width[2]) + + t = repeat(t, "T -> T H W", T=T, H=H, W=W) + h = repeat(h, "H -> T H W", T=T, H=H, W=W) + w = repeat(w, "W -> T H W", T=T, H=H, W=W) + + mask = torch.stack([t, h, w]).min(dim=0).values + mask = rearrange(mask, "T H W -> 1 1 T H W") + return mask + + + def tile_forward(self, hidden_states, tile_size, tile_stride): + B, C, T, H, W = hidden_states.shape + size_t, size_h, size_w = tile_size + stride_t, stride_h, stride_w = tile_stride + + # Split tasks + tasks = [] + for t in range(0, T, stride_t): + if (t-stride_t >= 0 and t-stride_t+size_t >= T): continue + for h in range(0, H, stride_h): + if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue + for w in range(0, W, stride_w): + if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue + t_, h_, w_ = t + size_t, h + size_h, w + size_w + tasks.append((t, t_, h, h_, w, w_)) + + # Run + torch_dtype = self.post_quant_conv.weight.dtype + data_device = hidden_states.device + computation_device = self.post_quant_conv.weight.device + + weight = torch.zeros((1, 1, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device) + values = torch.zeros((B, 3, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device) + + for t, t_, h, h_, w, w_ in tqdm(tasks): + hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device) + hidden_states_batch = self.forward(hidden_states_batch).to(data_device) + if t > 0: + hidden_states_batch = hidden_states_batch[:, :, 1:] + + mask = self.build_mask( + hidden_states_batch, + is_bound=(t==0, t_>=T, h==0, h_>=H, w==0, w_>=W), + border_width=((size_t - stride_t) * 4, (size_h - stride_h) * 8, (size_w - stride_w) * 8) + ).to(dtype=torch_dtype, device=data_device) + + target_t = 0 if t==0 else t * 4 + 1 + target_h = h * 8 + target_w = w * 8 + values[ + :, + :, + target_t: target_t + hidden_states_batch.shape[2], + target_h: target_h + hidden_states_batch.shape[3], + target_w: target_w + hidden_states_batch.shape[4], + ] += hidden_states_batch * mask + weight[ + :, + :, + target_t: target_t + hidden_states_batch.shape[2], + target_h: target_h + hidden_states_batch.shape[3], + target_w: target_w + hidden_states_batch.shape[4], + ] += mask + return values / weight + + + def decode_video(self, latents, tile_size=(17, 32, 32), tile_stride=(12, 24, 24)): + latents = latents.to(self.post_quant_conv.weight.dtype) + return self.tile_forward(latents, tile_size=tile_size, tile_stride=tile_stride) @staticmethod def state_dict_converter(): diff --git a/diffsynth/models/lora.py b/diffsynth/models/lora.py index d416864..33f952f 100644 --- a/diffsynth/models/lora.py +++ b/diffsynth/models/lora.py @@ -7,6 +7,7 @@ from .sd3_dit import SD3DiT from .flux_dit import FluxDiT from .hunyuan_dit import HunyuanDiT from .cog_dit import CogDiT +from .hunyuan_video_dit import HunyuanVideoDiT @@ -259,6 +260,14 @@ class GeneralLoRAFromPeft: return None +class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai): + def __init__(self): + super().__init__() + self.supported_model_classes = [HunyuanVideoDiT] + self.lora_prefix = ["diffusion_model."] + self.special_keys = {} + + class FluxLoRAConverter: def __init__(self): pass @@ -355,4 +364,4 @@ class FluxLoRAConverter: def get_lora_loaders(): - return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), GeneralLoRAFromPeft()] + return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()] diff --git a/diffsynth/pipelines/hunyuan_video.py b/diffsynth/pipelines/hunyuan_video.py index d44e0d8..7a8b297 100644 --- a/diffsynth/pipelines/hunyuan_video.py +++ b/diffsynth/pipelines/hunyuan_video.py @@ -7,10 +7,10 @@ import torch from transformers import LlamaModel from einops import rearrange import numpy as np -from tqdm import tqdm from PIL import Image + class HunyuanVideoPipeline(BasePipeline): def __init__(self, device="cuda", torch_dtype=torch.float16): @@ -22,6 +22,13 @@ class HunyuanVideoPipeline(BasePipeline): self.dit: HunyuanVideoDiT = None self.vae_decoder: HunyuanVideoVAEDecoder = None self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder'] + self.vram_management = False + + + def enable_vram_management(self): + self.vram_management = True + self.enable_cpu_offload() + self.dit.enable_auto_offload(dtype=self.torch_dtype, device=self.device) def fetch_models(self, model_manager: ModelManager): @@ -38,10 +45,8 @@ class HunyuanVideoPipeline(BasePipeline): if torch_dtype is None: torch_dtype = model_manager.torch_dtype pipe = HunyuanVideoPipeline(device=device, torch_dtype=torch_dtype) pipe.fetch_models(model_manager) - # VRAM management is automatically enabled. if enable_vram_management: - pipe.enable_cpu_offload() - pipe.dit.enable_auto_offload(dtype=torch_dtype, device=device) + pipe.enable_vram_management() return pipe @@ -77,26 +82,34 @@ class HunyuanVideoPipeline(BasePipeline): embedded_guidance=6.0, cfg_scale=1.0, num_inference_steps=30, + tile_size=(17, 30, 30), + tile_stride=(12, 20, 20), progress_bar_cmd=lambda x: x, progress_bar_st=None, ): + # Initialize noise latents = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) + # Encode prompts self.load_models_to_device(["text_encoder_1", "text_encoder_2"]) prompt_emb_posi = self.encode_prompt(prompt, positive=True) if cfg_scale != 1.0: prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False) + # Extra input extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance) + # Scheduler self.scheduler.set_timesteps(num_inference_steps) - self.load_models_to_device([]) + # Denoise + self.load_models_to_device([] if self.vram_management else ["dit"]) for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): timestep = timestep.unsqueeze(0).to(self.device) - with torch.autocast(device_type=self.device, dtype=self.torch_dtype): - print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}") + print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}") + # Inference + with torch.autocast(device_type=self.device, dtype=self.torch_dtype): noise_pred_posi = self.dit(latents, timestep, **prompt_emb_posi, **extra_input) if cfg_scale != 1.0: noise_pred_nega = self.dit(latents, timestep, **prompt_emb_nega, **extra_input) @@ -104,12 +117,16 @@ class HunyuanVideoPipeline(BasePipeline): else: noise_pred = noise_pred_posi + # Scheduler latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents) # Tiler parameters - tiler_kwargs = dict(use_temporal_tiling=False, use_spatial_tiling=False, sample_ssize=256, sample_tsize=64) - # decode + tiler_kwargs = {"tile_size": tile_size, "tile_stride": tile_stride} + + # Decode self.load_models_to_device(['vae_decoder']) frames = self.vae_decoder.decode_video(latents, **tiler_kwargs) + self.load_models_to_device([]) frames = self.tensor2video(frames[0]) + return frames diff --git a/examples/video_synthesis/hunyuanvideo.py b/examples/video_synthesis/hunyuanvideo.py index 57ef567..0b4c8bb 100644 --- a/examples/video_synthesis/hunyuanvideo.py +++ b/examples/video_synthesis/hunyuanvideo.py @@ -1,18 +1,42 @@ -from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video import torch +torch.cuda.set_per_process_memory_fraction(1.0, 0) +from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video -# Download models (automatically) download_models(["HunyuanVideo"]) +model_manager = ModelManager() -# Load models -model_manager = ModelManager(torch_dtype=torch.float16, device="cuda") -model_manager.load_models([ - "models/HunyuanVideo/vae/pytorch_model.pt", - "t2i_models/HunyuanVideo/text_encoder/model.safetensors", - "t2i_models/HunyuanVideo/text_encoder_2", -]) -pipe = HunyuanVideoPipeline.from_model_manager(model_manager) -prompt = 'A cat walks on the grass, realistic style.' -frames = pipe(prompt) -save_video(frames, 'test_video.mp4', fps=8, quality=5) +# The DiT model is loaded in bfloat16. +model_manager.load_models( + [ + "models/HunyuanVideo/transformers/mp_rank_00_model_states.pt" + ], + torch_dtype=torch.bfloat16, + device="cpu" +) + +# The other modules are loaded in float16. +model_manager.load_models( + [ + "models/HunyuanVideo/text_encoder/model.safetensors", + "models/HunyuanVideo/text_encoder_2", + "models/HunyuanVideo/vae/pytorch_model.pt", + ], + torch_dtype=torch.float16, + device="cpu" +) + +# We support LoRA inference. You can use the following code to load your LoRA model. +# model_manager.load_lora("models/lora/xxx.safetensors", lora_alpha=1.0) + +# The computation device is "cuda". +pipe = HunyuanVideoPipeline.from_model_manager( + model_manager, + torch_dtype=torch.bfloat16, + device="cuda" +) + +# Enjoy! +prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her." +video = pipe(prompt, seed=0) +save_video(video, "video.mp4", fps=30, quality=5)