From 263166768ee16b082b83770ace8d08dcaf8743c2 Mon Sep 17 00:00:00 2001
From: mi804 <1576993271@qq.com>
Date: Wed, 18 Dec 2024 11:14:57 +0800
Subject: [PATCH] hunyuanvideo_vae_decoder

---
 diffsynth/configs/model_config.py        |  6 +++-
 diffsynth/models/model_manager.py        |  1 +
 diffsynth/pipelines/hunyuan_video.py     | 39 +++++++++++++++++++-----
 examples/video_synthesis/hunyuanvideo.py |  6 ++--
 4 files changed, 41 insertions(+), 11 deletions(-)

diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py
index 85b34af..453755f 100644
--- a/diffsynth/configs/model_config.py
+++ b/diffsynth/configs/model_config.py
@@ -43,6 +43,8 @@ from ..models.cog_dit import CogDiT
 
 from ..models.omnigen import OmniGenTransformer
 
+from ..models.hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
+
 from ..extensions.RIFE import IFNet
 from ..extensions.ESRGAN import RRDBNet
 
@@ -94,6 +96,7 @@ model_loader_configs = [
     (None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
     (None, "77ff18050dbc23f50382e45d51a779fe", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
     (None, "5da81baee73198a7c19e6d2fe8b5148e", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
+    (None, "aeb82dce778a03dcb4d726cb03f3c43f", ["hunyuan_video_vae_decoder"], [HunyuanVideoVAEDecoder], "diffusers"),
 ]
 huggingface_model_loader_configs = [
     # These configs are provided for detecting model type automatically.
@@ -638,11 +641,12 @@ preset_models_on_modelscope = {
             ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
             ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
             ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
-
+            ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae")
         ],
         "load_path": [
             "models/HunyuanVideo/text_encoder/model.safetensors",
             "models/HunyuanVideo/text_encoder_2",
+            "models/HunyuanVideo/vae/pytorch_model.pt"
         ],
     },
 }
diff --git a/diffsynth/models/model_manager.py b/diffsynth/models/model_manager.py
index f8351d2..1fee906 100644
--- a/diffsynth/models/model_manager.py
+++ b/diffsynth/models/model_manager.py
@@ -35,6 +35,7 @@ from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
 
 from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
 from .hunyuan_dit import HunyuanDiT
+from .hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
 
 from .flux_dit import FluxDiT
 from .flux_text_encoder import FluxTextEncoder2
diff --git a/diffsynth/pipelines/hunyuan_video.py b/diffsynth/pipelines/hunyuan_video.py
index 9e26a8f..51aa06b 100644
--- a/diffsynth/pipelines/hunyuan_video.py
+++ b/diffsynth/pipelines/hunyuan_video.py
@@ -1,9 +1,13 @@
-from ..models import ModelManager, SD3TextEncoder1
+from ..models import ModelManager, SD3TextEncoder1, HunyuanVideoVAEDecoder
 from .base import BasePipeline
 from ..prompters import HunyuanVideoPrompter
 import torch
 from transformers import LlamaModel
+from einops import rearrange
+import numpy as np
 from tqdm import tqdm
+from PIL import Image
+
 
 
 class HunyuanVideoPipeline(BasePipeline):
@@ -13,11 +17,13 @@ class HunyuanVideoPipeline(BasePipeline):
         self.prompter = HunyuanVideoPrompter()
         self.text_encoder_1: SD3TextEncoder1 = None
         self.text_encoder_2: LlamaModel = None
-
+        self.vae_decoder: HunyuanVideoVAEDecoder = None
+        self.model_names = ['text_encoder_1', 'text_encoder_2', 'vae_decoder']
 
     def fetch_models(self, model_manager: ModelManager):
         self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
         self.text_encoder_2 = model_manager.fetch_model("hunyuan_video_text_encoder_2")
+        self.vae_decoder = model_manager.fetch_model("hunyuan_video_vae_decoder")
         self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
 
     @staticmethod
@@ -31,11 +37,19 @@ class HunyuanVideoPipeline(BasePipeline):
         return pipe
 
     def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256):
-        prompt_emb, pooled_prompt_emb = self.prompter.encode_prompt(
-            prompt, device=self.device, positive=positive, clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length
-        )
+        prompt_emb, pooled_prompt_emb = self.prompter.encode_prompt(prompt,
+                                                                    device=self.device,
+                                                                    positive=positive,
+                                                                    clip_sequence_length=clip_sequence_length,
+                                                                    llm_sequence_length=llm_sequence_length)
         return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb}
 
+    def tensor2video(self, frames):
+        frames = rearrange(frames, "C T H W -> T H W C")
+        frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
+        frames = [Image.fromarray(frame) for frame in frames]
+        return frames
+
     @torch.no_grad()
     def __call__(
         self,
@@ -45,7 +59,16 @@
         progress_bar_cmd=tqdm,
         progress_bar_st=None,
     ):
-        pass
+        # encode prompt
+        # prompt_emb_posi = self.encode_prompt(prompt, positive=True)
 
-        prompt_emb_posi = self.encode_prompt(prompt, positive=True)
-        return prompt_emb_posi
\ No newline at end of file
+        # test data
+        latents = torch.load('latents.pt').to(device=self.device, dtype=self.torch_dtype)  # torch.Size([1, 16, 33, 90, 160])
+        latents = latents[:, :, :2, :, :]
+        # Tiler parameters
+        tiler_kwargs = dict(use_temporal_tiling=False, use_spatial_tiling=False, sample_ssize=256, sample_tsize=64)
+        # decode
+        self.load_models_to_device(['vae_decoder'])
+        frames = self.vae_decoder.decode_video(latents, **tiler_kwargs)
+        frames = self.tensor2video(frames[0])
+        return frames
diff --git a/examples/video_synthesis/hunyuanvideo.py b/examples/video_synthesis/hunyuanvideo.py
index e4af344..57ef567 100644
--- a/examples/video_synthesis/hunyuanvideo.py
+++ b/examples/video_synthesis/hunyuanvideo.py
@@ -1,4 +1,4 @@
-from diffsynth import ModelManager, HunyuanVideoPipeline, download_models
+from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
 import torch
 
 
@@ -8,9 +8,11 @@ download_models(["HunyuanVideo"])
 # Load models
 model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
 model_manager.load_models([
+    "models/HunyuanVideo/vae/pytorch_model.pt",
     "t2i_models/HunyuanVideo/text_encoder/model.safetensors",
     "t2i_models/HunyuanVideo/text_encoder_2",
 ])
 pipe = HunyuanVideoPipeline.from_model_manager(model_manager)
 prompt = 'A cat walks on the grass, realistic style.'
-pipe(prompt)
+frames = pipe(prompt)
+save_video(frames, 'test_video.mp4', fps=8, quality=5)
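
--
Note (not part of the commit): the conversion performed by the new tensor2video
helper can be sanity-checked in isolation. Below is a minimal, self-contained
sketch that uses a random tensor in place of real decoder output; the dummy
shape (3, 2, 64, 64) is an assumption chosen only for the demo, while the
"C T H W" layout and the [-1, 1] -> uint8 mapping come from the patch itself.

    import numpy as np
    import torch
    from einops import rearrange
    from PIL import Image

    # Stand-in for one sample of VAE decoder output: channels-first video in [-1, 1].
    frames = torch.rand(3, 2, 64, 64) * 2 - 1  # C T H W

    # Same steps as HunyuanVideoPipeline.tensor2video in this patch:
    frames = rearrange(frames, "C T H W -> T H W C")      # one HWC array per frame
    frames = ((frames.float() + 1) * 127.5).clip(0, 255)  # [-1, 1] -> [0, 255]
    frames = frames.cpu().numpy().astype(np.uint8)
    frames = [Image.fromarray(frame) for frame in frames]  # list of T PIL images

    print(len(frames), frames[0].size)  # 2 frames, each 64x64

The resulting list of PIL images is what save_video consumes in the updated
example script.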