hunyuanvideo_vae_decoder

Author: mi804
Date: 2024-12-18 11:14:57 +08:00
parent 5d1005a7c8
commit 263166768e
4 changed files with 41 additions and 11 deletions

@@ -43,6 +43,8 @@ from ..models.cog_dit import CogDiT
from ..models.omnigen import OmniGenTransformer
from ..models.hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
from ..extensions.RIFE import IFNet
from ..extensions.ESRGAN import RRDBNet
@@ -94,6 +96,7 @@ model_loader_configs = [
(None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
(None, "77ff18050dbc23f50382e45d51a779fe", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
(None, "5da81baee73198a7c19e6d2fe8b5148e", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
(None, "aeb82dce778a03dcb4d726cb03f3c43f", ["hunyuan_video_vae_decoder"], [HunyuanVideoVAEDecoder], "diffusers"),
]
huggingface_model_loader_configs = [
# These configs are provided for detecting model type automatically.
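Each row in model_loader_configs pairs a fingerprint of a checkpoint's state dict with the model names and classes to instantiate, plus the naming convention of the weights ("civitai" or "diffusers"); the new row does this for the HunyuanVideo VAE decoder. As a rough illustration only (this is not the repository's actual loader code, the helper names are hypothetical, and the field meanings are inferred from the row layout), fingerprint-based detection could look like this:

```python
import hashlib

def state_dict_hash(state_dict):
    # Fingerprint a checkpoint by hashing its sorted parameter names (MD5, like the
    # 32-character hex strings in model_loader_configs). Purely illustrative.
    keys = ",".join(sorted(state_dict.keys()))
    return hashlib.md5(keys.encode("utf-8")).hexdigest()

def detect_model_config(state_dict, loader_configs):
    # loader_configs rows are read here as
    # (unused, keys_hash, model_names, model_classes, source_format): an assumption.
    fingerprint = state_dict_hash(state_dict)
    for _, keys_hash, model_names, model_classes, source_format in loader_configs:
        if keys_hash == fingerprint:
            return model_names, model_classes, source_format
    return None
```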
@@ -638,11 +641,12 @@ preset_models_on_modelscope = {
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae")
],
"load_path": [
"models/HunyuanVideo/text_encoder/model.safetensors",
"models/HunyuanVideo/text_encoder_2",
"models/HunyuanVideo/vae/pytorch_model.pt"
],
},
}
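The new manifest entries pull the HunyuanVideo VAE checkpoint from ModelScope next to the two text encoders, and load_path lists where the downloaded files end up locally. A minimal usage sketch, assuming the project's usual download_models / ModelManager entry points and that this preset is keyed as "HunyuanVideo" (the dict key is not visible in this hunk):

```python
import torch
from diffsynth import ModelManager, download_models  # assumed top-level entry points

download_models(["HunyuanVideo"])  # assumed preset name for the entry edited above

model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
model_manager.load_models([
    "models/HunyuanVideo/text_encoder/model.safetensors",
    "models/HunyuanVideo/text_encoder_2",
    "models/HunyuanVideo/vae/pytorch_model.pt",  # matched by the new hunyuan_video_vae_decoder loader config
])
```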

@@ -35,6 +35,7 @@ from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from .hunyuan_dit import HunyuanDiT
from .hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
from .flux_dit import FluxDiT
from .flux_text_encoder import FluxTextEncoder2


@@ -1,9 +1,13 @@
from ..models import ModelManager, SD3TextEncoder1
from ..models import ModelManager, SD3TextEncoder1, HunyuanVideoVAEDecoder
from .base import BasePipeline
from ..prompters import HunyuanVideoPrompter
import torch
from transformers import LlamaModel
from einops import rearrange
import numpy as np
from tqdm import tqdm
from PIL import Image


class HunyuanVideoPipeline(BasePipeline):
@@ -13,11 +17,13 @@ class HunyuanVideoPipeline(BasePipeline):
        self.prompter = HunyuanVideoPrompter()
        self.text_encoder_1: SD3TextEncoder1 = None
        self.text_encoder_2: LlamaModel = None
        self.vae_decoder: HunyuanVideoVAEDecoder = None
        self.model_names = ['text_encoder_1', 'text_encoder_2', 'vae_decoder']

    def fetch_models(self, model_manager: ModelManager):
        self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
        self.text_encoder_2 = model_manager.fetch_model("hunyuan_video_text_encoder_2")
        self.vae_decoder = model_manager.fetch_model("hunyuan_video_vae_decoder")
        self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)

    @staticmethod
@@ -31,11 +37,19 @@ class HunyuanVideoPipeline(BasePipeline):
        return pipe

    def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256):
        prompt_emb, pooled_prompt_emb = self.prompter.encode_prompt(
            prompt, device=self.device, positive=positive, clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length
        )
        prompt_emb, pooled_prompt_emb = self.prompter.encode_prompt(prompt,
                                                                    device=self.device,
                                                                    positive=positive,
                                                                    clip_sequence_length=clip_sequence_length,
                                                                    llm_sequence_length=llm_sequence_length)
        return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb}

    def tensor2video(self, frames):
        frames = rearrange(frames, "C T H W -> T H W C")
        frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
        frames = [Image.fromarray(frame) for frame in frames]
        return frames

    @torch.no_grad()
    def __call__(
        self,
@@ -45,7 +59,16 @@ class HunyuanVideoPipeline(BasePipeline):
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
    ):
        pass
        # encode prompt
        prompt_emb_posi = self.encode_prompt(prompt, positive=True)
        # NOTE: temporary early return while the denoising loop is not implemented yet;
        # the decoding path below is exercised with pre-computed test latents.
        return prompt_emb_posi
        # test data
        latents = torch.load('latents.pt').to(device=self.device, dtype=self.torch_dtype)  # torch.Size([1, 16, 33, 90, 160]) = (B, C, T, H, W)
        latents = latents[:, :, :2, :, :]  # keep only the first two latent frames for a quick test
        # tiler parameters (tiled decoding disabled for this small test)
        tiler_kwargs = dict(use_temporal_tiling=False, use_spatial_tiling=False, sample_ssize=256, sample_tsize=64)
        # decode
        self.load_models_to_device(['vae_decoder'])
        frames = self.vae_decoder.decode_video(latents, **tiler_kwargs)
        frames = self.tensor2video(frames[0])
        return frames
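Taken together, the work-in-progress __call__ currently stops after prompt encoding, while the code below its early return shows how the new decoder is meant to be driven. A sketch of that decode path end to end; the module path, the from_model_manager constructor (implied by the `return pipe` hunk), and the meaning of the tiling parameters are assumptions, and model_manager comes from the download/load sketch earlier:

```python
import torch
from diffsynth.pipelines.hunyuan_video import HunyuanVideoPipeline  # assumed module path

pipe = HunyuanVideoPipeline.from_model_manager(model_manager)  # constructor implied by the diff

# Pre-computed latents in (B, C, T, H, W) layout, as in the test code above.
latents = torch.load("latents.pt").to(device=pipe.device, dtype=pipe.torch_dtype)

pipe.load_models_to_device(["vae_decoder"])
frames = pipe.vae_decoder.decode_video(
    latents,
    use_temporal_tiling=False,  # presumably: decode in temporal tiles of sample_tsize frames when True
    use_spatial_tiling=False,   # presumably: decode in spatial tiles of sample_ssize pixels when True
    sample_ssize=256,
    sample_tsize=64,
)
frames = pipe.tensor2video(frames[0])  # (C, T, H, W) tensor -> list of PIL images

# Save as an animated GIF with plain PIL; any video writer would work as well.
frames[0].save("decoded.gif", save_all=True, append_images=frames[1:], duration=40, loop=0)
```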