Files
DiffSynth-Studio/diffsynth/pipelines/hunyuan_video.py
2024-12-18 11:47:34 +08:00

116 lines
5.1 KiB
Python

from ..models import ModelManager, SD3TextEncoder1, HunyuanVideoVAEDecoder
from ..models.hunyuan_video_dit import HunyuanVideoDiT
from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline
from ..prompters import HunyuanVideoPrompter
import torch
from transformers import LlamaModel
from einops import rearrange
import numpy as np
from tqdm import tqdm
from PIL import Image
class HunyuanVideoPipeline(BasePipeline):
    """Text-to-video pipeline for HunyuanVideo.

    Combines two text encoders (``SD3TextEncoder1`` and a ``LlamaModel``), the
    ``HunyuanVideoDiT`` denoiser, a ``FlowMatchScheduler`` and the
    ``HunyuanVideoVAEDecoder``. Calling the pipeline returns a list of PIL frames.
    """

    def __init__(self, device="cuda", torch_dtype=torch.float16):
        super().__init__(device=device, torch_dtype=torch_dtype)
        self.scheduler = FlowMatchScheduler(shift=7.0, sigma_min=0.0, extra_one_step=True)
        self.prompter = HunyuanVideoPrompter()
        # Sub-models start empty and are populated by fetch_models().
        self.text_encoder_1: SD3TextEncoder1 = None
        self.text_encoder_2: LlamaModel = None
        self.dit: HunyuanVideoDiT = None
        self.vae_decoder: HunyuanVideoVAEDecoder = None
        # Attribute names of the sub-models; presumably consumed by BasePipeline's
        # device-placement helpers (load_models_to_device / enable_cpu_offload) — verify.
        self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder']

    def fetch_models(self, model_manager: ModelManager):
        """Pull the four sub-models out of ``model_manager`` and hand the text encoders to the prompter."""
        self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
        self.text_encoder_2 = model_manager.fetch_model("hunyuan_video_text_encoder_2")
        self.dit = model_manager.fetch_model("hunyuan_video_dit")
        self.vae_decoder = model_manager.fetch_model("hunyuan_video_vae_decoder")
        self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)

    @staticmethod
    def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, enable_vram_management=True):
        """Build a pipeline from the models held by ``model_manager``.

        Args:
            model_manager: Source of the sub-models.
            torch_dtype: Compute dtype; defaults to ``model_manager.torch_dtype``.
            device: Target device; defaults to ``model_manager.device``.
            enable_vram_management: When true, enables pipeline CPU offload and
                the DiT's own auto-offload.

        Returns:
            A ready-to-call ``HunyuanVideoPipeline``.
        """
        if device is None:
            device = model_manager.device
        if torch_dtype is None:
            torch_dtype = model_manager.torch_dtype
        pipe = HunyuanVideoPipeline(device=device, torch_dtype=torch_dtype)
        pipe.fetch_models(model_manager)
        if enable_vram_management:
            pipe.enable_cpu_offload()
            pipe.dit.enable_auto_offload(dtype=torch_dtype, device=device)
        return pipe

    def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256):
        """Encode ``prompt`` with the prompter.

        Returns:
            A dict with keys ``prompt_emb``, ``pooled_prompt_emb`` and ``text_mask``,
            ready to be splatted into the DiT call.
        """
        prompt_emb, pooled_prompt_emb, text_mask = self.prompter.encode_prompt(
            prompt, device=self.device, positive=positive,
            clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length,
        )
        return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_mask": text_mask}

    def prepare_extra_input(self, latents=None, guidance=1.0):
        """Build the DiT's extra keyword inputs for ``latents``.

        Returns:
            A dict with the DiT's frequency tables (``freqs_cos``, ``freqs_sin``,
            from ``self.dit.prepare_freqs``) and ``guidance``: the embedded guidance
            value repeated once per batch element, on the latents' device and dtype.
        """
        freqs_cos, freqs_sin = self.dit.prepare_freqs(latents)
        guidance = torch.Tensor([guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype)
        return {"freqs_cos": freqs_cos, "freqs_sin": freqs_sin, "guidance": guidance}

    def tensor2video(self, frames):
        """Convert a ``(C, T, H, W)`` tensor into a list of ``T`` PIL images.

        Values are mapped with ``(x + 1) * 127.5`` (so ``[-1, 1]`` -> ``[0, 255]``),
        clipped to ``[0, 255]`` and cast to ``uint8``.
        """
        frames = rearrange(frames, "C T H W -> T H W C")
        frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
        frames = [Image.fromarray(frame) for frame in frames]
        return frames

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        seed=None,
        height=720,
        width=1280,
        num_frames=129,
        embedded_guidance=6.0,
        cfg_scale=1.0,
        num_inference_steps=30,
        progress_bar_cmd=lambda x: x,
        progress_bar_st=None,
    ):
        """Generate a video from ``prompt``.

        Args:
            prompt: Positive text prompt.
            negative_prompt: Negative prompt; only encoded and used when ``cfg_scale != 1.0``.
            seed: Seed forwarded to ``generate_noise``.
            height, width: Output size in pixels; the latent grid is ``height // 8`` x ``width // 8``.
            num_frames: Requested frame count; the latent has ``(num_frames - 1) // 4 + 1`` time steps.
            embedded_guidance: Value passed to the DiT as its ``guidance`` input.
            cfg_scale: Classifier-free guidance scale; ``1.0`` skips the negative pass.
            num_inference_steps: Number of scheduler steps.
            progress_bar_cmd: Wrapper applied to the timestep iterable (e.g. ``tqdm``).
            progress_bar_st: Optional object exposing ``.progress(fraction)`` (assumed to be
                a Streamlit-style progress bar); updated after every denoising step.

        Returns:
            A list of PIL images, one per decoded frame.
        """
        latents = self.generate_noise(
            (1, 16, (num_frames - 1) // 4 + 1, height // 8, width // 8),
            seed=seed, device=self.device, dtype=self.torch_dtype,
        )

        # Encode prompts (negative only when CFG is active).
        self.load_models_to_device(["text_encoder_1", "text_encoder_2"])
        prompt_emb_posi = self.encode_prompt(prompt, positive=True)
        if cfg_scale != 1.0:
            prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)

        extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
        self.scheduler.set_timesteps(num_inference_steps)
        num_steps = len(self.scheduler.timesteps)
        # torch.autocast needs a device *type* ("cuda", "cpu"); a full device string
        # such as "cuda:0" is rejected, so derive the type once up front.
        autocast_device_type = torch.device(self.device).type

        # Denoise
        self.load_models_to_device([])
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(self.device)
            print(f"Step {progress_id + 1} / {num_steps}")
            with torch.autocast(device_type=autocast_device_type, dtype=self.torch_dtype):
                noise_pred_posi = self.dit(latents, timestep, **prompt_emb_posi, **extra_input)
                if cfg_scale != 1.0:
                    noise_pred_nega = self.dit(latents, timestep, **prompt_emb_nega, **extra_input)
                    noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
                else:
                    noise_pred = noise_pred_posi
            latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
            if progress_bar_st is not None:
                progress_bar_st.progress((progress_id + 1) / num_steps)

        # Decode with temporal and spatial tiling both disabled.
        tiler_kwargs = dict(use_temporal_tiling=False, use_spatial_tiling=False, sample_ssize=256, sample_tsize=64)
        self.load_models_to_device(['vae_decoder'])
        frames = self.vae_decoder.decode_video(latents, **tiler_kwargs)
        frames = self.tensor2video(frames[0])
        return frames