mirror of https://github.com/modelscope/DiffSynth-Studio.git, synced 2026-03-22 08:40:47 +00:00
from ..models import ModelManager, SD3TextEncoder1, HunyuanVideoVAEDecoder
from .base import BasePipeline
from ..prompters import HunyuanVideoPrompter
import torch
from transformers import LlamaModel
from einops import rearrange
import numpy as np
from tqdm import tqdm
from PIL import Image


class HunyuanVideoPipeline(BasePipeline):

    def __init__(self, device="cuda", torch_dtype=torch.float16):
        super().__init__(device=device, torch_dtype=torch_dtype)
        # Following DiffSynth's ordering, text_encoder_1 refers to CLIP and
        # text_encoder_2 refers to the LLM -- the exact opposite of the
        # HunyuanVideo source code.
        self.prompter = HunyuanVideoPrompter()
        self.text_encoder_1: SD3TextEncoder1 = None
        self.text_encoder_2: LlamaModel = None
        self.vae_decoder: HunyuanVideoVAEDecoder = None
        self.model_names = ['text_encoder_1', 'text_encoder_2', 'vae_decoder']

    def fetch_models(self, model_manager: ModelManager):
        self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
        self.text_encoder_2 = model_manager.fetch_model("hunyuan_video_text_encoder_2")
        self.vae_decoder = model_manager.fetch_model("hunyuan_video_vae_decoder")
        self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)

    @staticmethod
    def from_model_manager(model_manager: ModelManager, device=None):
        pipe = HunyuanVideoPipeline(
            device=model_manager.device if device is None else device,
            torch_dtype=model_manager.torch_dtype,
        )
        pipe.fetch_models(model_manager)
        return pipe

    def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256):
        prompt_emb, pooled_prompt_emb = self.prompter.encode_prompt(
            prompt,
            device=self.device,
            positive=positive,
            clip_sequence_length=clip_sequence_length,
            llm_sequence_length=llm_sequence_length,
        )
        return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb}

    def tensor2video(self, frames):
        # Convert a decoded video tensor in [-1, 1] to a list of PIL images.
        frames = rearrange(frames, "C T H W -> T H W C")
        frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
        frames = [Image.fromarray(frame) for frame in frames]
        return frames

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        seed=None,
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
    ):
        # NOTE: this pipeline is still under construction. `negative_prompt`,
        # `seed`, and the progress-bar arguments are accepted but not used yet,
        # and the prompt-encoding step below is disabled.
        # prompt_emb_posi = self.encode_prompt(prompt, positive=True)

        # Test data: precomputed latents loaded from disk.
        latents = torch.load('latents.pt').to(device=self.device, dtype=self.torch_dtype)  # torch.Size([1, 16, 33, 90, 160])
        latents = latents[:, :, :2, :, :]  # keep only the first two latent frames

        # Tiler parameters for the VAE decoder.
        tiler_kwargs = dict(use_temporal_tiling=False, use_spatial_tiling=False, sample_ssize=256, sample_tsize=64)

        # Decode latents into video frames and convert to PIL images.
        self.load_models_to_device(['vae_decoder'])
        frames = self.vae_decoder.decode_video(latents, **tiler_kwargs)
        frames = self.tensor2video(frames[0])
        return frames
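

# Minimal usage sketch (not part of the original file). It assumes
# ModelManager(torch_dtype=..., device=...) and ModelManager.load_models()
# taking a list of local checkpoint paths, as in other DiffSynth-Studio
# pipelines; the import path and checkpoint paths are hypothetical placeholders.
#
#     from diffsynth import ModelManager
#     from diffsynth.pipelines.hunyuan_video import HunyuanVideoPipeline
#
#     model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
#     model_manager.load_models([
#         "models/sd3/text_encoder_1.safetensors",            # hypothetical path
#         "models/hunyuan_video/text_encoder_2.safetensors",  # hypothetical path
#         "models/hunyuan_video/vae_decoder.safetensors",     # hypothetical path
#     ])
#     pipe = HunyuanVideoPipeline.from_model_manager(model_manager)
#     frames = pipe(prompt="a cat running on the grass")
#     frames[0].save("frame_0.png")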