mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 00:38:11 +00:00
Merge branch 'dev' into dev-dzj
This commit is contained in:
@@ -1,10 +1,14 @@
|
||||
from ..models import ModelManager, SD3TextEncoder1
|
||||
from ..models import ModelManager, SD3TextEncoder1, HunyuanVideoVAEDecoder
|
||||
from ..models.hunyuan_video_dit import HunyuanVideoDiT
|
||||
from ..schedulers.flow_match import FlowMatchScheduler
|
||||
from .base import BasePipeline
|
||||
from ..prompters import HunyuanVideoPrompter
|
||||
import torch
|
||||
from transformers import LlamaModel
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class HunyuanVideoPipeline(BasePipeline):
|
||||
@@ -16,13 +20,15 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
self.text_encoder_1: SD3TextEncoder1 = None
|
||||
self.text_encoder_2: LlamaModel = None
|
||||
self.dit: HunyuanVideoDiT = None
|
||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit']
|
||||
|
||||
self.vae_decoder: HunyuanVideoVAEDecoder = None
|
||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder']
|
||||
|
||||
|
||||
def fetch_models(self, model_manager: ModelManager):
|
||||
self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
|
||||
self.text_encoder_2 = model_manager.fetch_model("hunyuan_video_text_encoder_2")
|
||||
self.dit = model_manager.fetch_model("hunyuan_video_dit")
|
||||
self.vae_decoder = model_manager.fetch_model("hunyuan_video_vae_decoder")
|
||||
self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
|
||||
|
||||
|
||||
@@ -52,6 +58,13 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
return {"freqs_cos": freqs_cos, "freqs_sin": freqs_sin, "guidance": guidance}
|
||||
|
||||
|
||||
def tensor2video(self, frames):
|
||||
frames = rearrange(frames, "C T H W -> T H W C")
|
||||
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
|
||||
frames = [Image.fromarray(frame) for frame in frames]
|
||||
return frames
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
@@ -93,6 +106,10 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
|
||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||
|
||||
# TODO: Add VAE decode here.
|
||||
|
||||
return latents
|
||||
# Tiler parameters
|
||||
tiler_kwargs = dict(use_temporal_tiling=False, use_spatial_tiling=False, sample_ssize=256, sample_tsize=64)
|
||||
# decode
|
||||
self.load_models_to_device(['vae_decoder'])
|
||||
frames = self.vae_decoder.decode_video(latents, **tiler_kwargs)
|
||||
frames = self.tensor2video(frames[0])
|
||||
return frames
|
||||
|
||||
Reference in New Issue
Block a user