Files
DiffSynth-Studio/diffsynth/pipelines/hunyuan_video.py
2024-12-18 11:14:57 +08:00

75 lines
3.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from ..models import ModelManager, SD3TextEncoder1, HunyuanVideoVAEDecoder
from .base import BasePipeline
from ..prompters import HunyuanVideoPrompter
import torch
from transformers import LlamaModel
from einops import rearrange
import numpy as np
from tqdm import tqdm
from PIL import Image
class HunyuanVideoPipeline(BasePipeline):
    """Pipeline object bundling two text encoders, a prompter and a video VAE decoder.

    NOTE(review): as written, ``__call__`` only exercises the VAE decoder on
    latents read from a local ``latents.pt`` file; the prompt-related
    arguments are accepted but not used yet.
    """

    def __init__(self, device="cuda", torch_dtype=torch.float16):
        """Create an empty pipeline; models are attached later by ``fetch_models``.

        Args:
            device: Forwarded to ``BasePipeline.__init__``.
            torch_dtype: Forwarded to ``BasePipeline.__init__``.
        """
        super().__init__(device=device, torch_dtype=torch_dtype)
        # Encoder numbering follows DiffSynth's ordering: text_encoder_1 is
        # CLIP and text_encoder_2 is the LLM, which is exactly the reverse of
        # the HunyuanVideo source code.
        self.text_encoder_1: SD3TextEncoder1 = None
        self.text_encoder_2: LlamaModel = None
        self.vae_decoder: HunyuanVideoVAEDecoder = None
        self.prompter = HunyuanVideoPrompter()
        # Attribute names of the models owned by this pipeline.
        self.model_names = [
            'text_encoder_1',
            'text_encoder_2',
            'vae_decoder',
        ]

    def fetch_models(self, model_manager: ModelManager):
        """Pull the pipeline's models out of ``model_manager`` and wire up the prompter.

        Args:
            model_manager: Object whose ``fetch_model(name)`` is queried once
                per model, in the order listed below.
        """
        # (attribute on this pipeline, name requested from the model manager)
        wanted = (
            ("text_encoder_1", "sd3_text_encoder_1"),
            ("text_encoder_2", "hunyuan_video_text_encoder_2"),
            ("vae_decoder", "hunyuan_video_vae_decoder"),
        )
        for attribute, registry_name in wanted:
            setattr(self, attribute, model_manager.fetch_model(registry_name))
        # The prompter receives both text encoders (not the VAE decoder).
        self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)

    @staticmethod
    def from_model_manager(model_manager: ModelManager, device=None):
        """Build a pipeline and populate it from ``model_manager``.

        Args:
            model_manager: Supplies ``torch_dtype``, the models, and the
                device when ``device`` is not given.
            device: Optional device override; ``None`` means use
                ``model_manager.device``.

        Returns:
            A ``HunyuanVideoPipeline`` on which ``fetch_models`` has been called.
        """
        if device is None:
            device = model_manager.device
        pipeline = HunyuanVideoPipeline(
            device=device,
            torch_dtype=model_manager.torch_dtype,
        )
        pipeline.fetch_models(model_manager)
        return pipeline

    def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256):
        """Encode ``prompt`` through the prompter.

        All arguments are passed straight to ``self.prompter.encode_prompt``
        together with ``device=self.device``.

        Returns:
            A dict with keys ``"prompt_emb"`` and ``"pooled_prompt_emb"``,
            holding the two values returned by the prompter, in that order.
        """
        encoded = self.prompter.encode_prompt(
            prompt,
            device=self.device,
            positive=positive,
            clip_sequence_length=clip_sequence_length,
            llm_sequence_length=llm_sequence_length,
        )
        prompt_emb, pooled_prompt_emb = encoded
        return dict(prompt_emb=prompt_emb, pooled_prompt_emb=pooled_prompt_emb)

    def tensor2video(self, frames):
        """Convert a ``(C, T, H, W)`` tensor into a list of ``T`` PIL images.

        Values are mapped linearly so that -1 becomes 0 and +1 becomes 255;
        anything outside that range is clipped before the cast to ``uint8``.
        """
        # Put time first and channels last so that each item is one H x W x C frame.
        time_major = rearrange(frames, "C T H W -> T H W C")
        scaled = (time_major.float() + 1) * 127.5
        pixels = scaled.clip(0, 255).cpu().numpy().astype(np.uint8)
        return [Image.fromarray(single_frame) for single_frame in pixels]

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        seed=None,
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
    ):
        """Decode stored test latents into video frames.

        NOTE(review): ``prompt``, ``negative_prompt``, ``seed``,
        ``progress_bar_cmd`` and ``progress_bar_st`` are currently unused.
        Prompt encoding (``self.encode_prompt(prompt, positive=True)``) is
        not wired in yet; the latents come from ``latents.pt`` in the
        current working directory instead.

        Returns:
            The list of PIL images produced by ``tensor2video`` for the
            first element of the decoder output.
        """
        # Test data. The original author's note records the stored shape as
        # torch.Size([1, 16, 33, 90, 160]).
        # NOTE(review): torch.load unpickles the file - only use trusted files.
        stored = torch.load('latents.pt')
        latents = stored.to(device=self.device, dtype=self.torch_dtype)
        # Keep only the first two entries along the third axis.
        latents = latents[:, :, :2, :, :]

        # Tiler parameters: both temporal and spatial tiling are switched off.
        tiler_kwargs = {
            "use_temporal_tiling": False,
            "use_spatial_tiling": False,
            "sample_ssize": 256,
            "sample_tsize": 64,
        }

        # Decode
        self.load_models_to_device(['vae_decoder'])
        decoded = self.vae_decoder.decode_video(latents, **tiler_kwargs)
        return self.tensor2video(decoded[0])