Merge branch 'dev' into dev-dzj

Authored by Zhongjie Duan, committed via GitHub on 2024-12-18 11:47:34 +08:00
5 changed files with 455 additions and 9 deletions


@@ -1,10 +1,14 @@
-from ..models import ModelManager, SD3TextEncoder1
+from ..models import ModelManager, SD3TextEncoder1, HunyuanVideoVAEDecoder
from ..models.hunyuan_video_dit import HunyuanVideoDiT
from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline
from ..prompters import HunyuanVideoPrompter
import torch
from transformers import LlamaModel
from einops import rearrange
+import numpy as np
+from tqdm import tqdm
+from PIL import Image


class HunyuanVideoPipeline(BasePipeline):
@@ -16,13 +20,15 @@ class HunyuanVideoPipeline(BasePipeline):
        self.text_encoder_1: SD3TextEncoder1 = None
        self.text_encoder_2: LlamaModel = None
        self.dit: HunyuanVideoDiT = None
-       self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit']
+       self.vae_decoder: HunyuanVideoVAEDecoder = None
+       self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder']

    def fetch_models(self, model_manager: ModelManager):
        self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
        self.text_encoder_2 = model_manager.fetch_model("hunyuan_video_text_encoder_2")
        self.dit = model_manager.fetch_model("hunyuan_video_dit")
+       self.vae_decoder = model_manager.fetch_model("hunyuan_video_vae_decoder")
        self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
@@ -52,6 +58,13 @@ class HunyuanVideoPipeline(BasePipeline):
return {"freqs_cos": freqs_cos, "freqs_sin": freqs_sin, "guidance": guidance}
def tensor2video(self, frames):
frames = rearrange(frames, "C T H W -> T H W C")
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
frames = [Image.fromarray(frame) for frame in frames]
return frames
@torch.no_grad()
def __call__(
self,
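For reference (not part of this diff): the new tensor2video helper maps the decoder's [-1, 1] float output to 8-bit RGB, so -1 maps to 0 and +1 maps to 255. A minimal standalone sketch of the same conversion, with a dummy tensor shape assumed purely for illustration:

import torch
import numpy as np
from einops import rearrange
from PIL import Image

frames = torch.rand(3, 4, 64, 64) * 2 - 1               # stand-in decoder output: C T H W, values in [-1, 1]
frames = rearrange(frames, "C T H W -> T H W C")        # one H x W x C array per frame
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
frames = [Image.fromarray(frame) for frame in frames]   # list of PIL.Image frames, ready to save or encode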
@@ -93,6 +106,10 @@ class HunyuanVideoPipeline(BasePipeline):
            latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)

-       # TODO: Add VAE decode here.
-       return latents
+       # Tiler parameters
+       tiler_kwargs = dict(use_temporal_tiling=False, use_spatial_tiling=False, sample_ssize=256, sample_tsize=64)
+
+       # decode
+       self.load_models_to_device(['vae_decoder'])
+       frames = self.vae_decoder.decode_video(latents, **tiler_kwargs)
+       frames = self.tensor2video(frames[0])
+       return frames
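With the VAE decoder wired in, __call__ now returns a list of PIL images instead of raw latents. A rough usage sketch, assuming bfloat16 weights on CUDA, a device/torch_dtype constructor like the repo's other pipelines, and a prompt/size call signature; the checkpoint paths and call arguments below are placeholders, not taken from this diff:

import torch
from diffsynth import ModelManager, HunyuanVideoPipeline  # import path assumed

model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
model_manager.load_models([
    "path/to/sd3_text_encoder_1.safetensors",        # placeholder checkpoint paths
    "path/to/hunyuan_video_text_encoder_2",
    "path/to/hunyuan_video_dit.safetensors",
    "path/to/hunyuan_video_vae_decoder.safetensors",
])

pipe = HunyuanVideoPipeline(device="cuda", torch_dtype=torch.bfloat16)
pipe.fetch_models(model_manager)

frames = pipe(prompt="a cat walking on the grass", height=480, width=640, num_frames=65)
for i, frame in enumerate(frames):
    frame.save(f"frame_{i:04d}.png")                 # each element is a PIL.Image after VAE decoding

Note that both spatial and temporal tiling are disabled in the new tiler_kwargs, so the whole latent is decoded in one pass; the sample_ssize and sample_tsize values presumably set the spatial and temporal tile sizes the decoder would use if tiling were turned on to lower peak memory.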