mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 01:48:13 +00:00
hunyuanvideo_vae_decoder
This commit is contained in:
@@ -43,6 +43,8 @@ from ..models.cog_dit import CogDiT
|
|||||||
|
|
||||||
from ..models.omnigen import OmniGenTransformer
|
from ..models.omnigen import OmniGenTransformer
|
||||||
|
|
||||||
|
from ..models.hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
|
||||||
|
|
||||||
from ..extensions.RIFE import IFNet
|
from ..extensions.RIFE import IFNet
|
||||||
from ..extensions.ESRGAN import RRDBNet
|
from ..extensions.ESRGAN import RRDBNet
|
||||||
|
|
||||||
@@ -94,6 +96,7 @@ model_loader_configs = [
|
|||||||
(None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
|
(None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
|
||||||
(None, "77ff18050dbc23f50382e45d51a779fe", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
(None, "77ff18050dbc23f50382e45d51a779fe", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
||||||
(None, "5da81baee73198a7c19e6d2fe8b5148e", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
|
(None, "5da81baee73198a7c19e6d2fe8b5148e", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
|
||||||
|
(None, "aeb82dce778a03dcb4d726cb03f3c43f", ["hunyuan_video_vae_decoder"], [HunyuanVideoVAEDecoder], "diffusers"),
|
||||||
]
|
]
|
||||||
huggingface_model_loader_configs = [
|
huggingface_model_loader_configs = [
|
||||||
# These configs are provided for detecting model type automatically.
|
# These configs are provided for detecting model type automatically.
|
||||||
@@ -638,11 +641,12 @@ preset_models_on_modelscope = {
|
|||||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
||||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
|
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
|
||||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
|
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
|
||||||
|
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae")
|
||||||
],
|
],
|
||||||
"load_path": [
|
"load_path": [
|
||||||
"models/HunyuanVideo/text_encoder/model.safetensors",
|
"models/HunyuanVideo/text_encoder/model.safetensors",
|
||||||
"models/HunyuanVideo/text_encoder_2",
|
"models/HunyuanVideo/text_encoder_2",
|
||||||
|
"models/HunyuanVideo/vae/pytorch_model.pt"
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
|||||||
|
|
||||||
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
||||||
from .hunyuan_dit import HunyuanDiT
|
from .hunyuan_dit import HunyuanDiT
|
||||||
|
from .hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
|
||||||
|
|
||||||
from .flux_dit import FluxDiT
|
from .flux_dit import FluxDiT
|
||||||
from .flux_text_encoder import FluxTextEncoder2
|
from .flux_text_encoder import FluxTextEncoder2
|
||||||
|
|||||||
@@ -1,9 +1,13 @@
|
|||||||
from ..models import ModelManager, SD3TextEncoder1
|
from ..models import ModelManager, SD3TextEncoder1, HunyuanVideoVAEDecoder
|
||||||
from .base import BasePipeline
|
from .base import BasePipeline
|
||||||
from ..prompters import HunyuanVideoPrompter
|
from ..prompters import HunyuanVideoPrompter
|
||||||
import torch
|
import torch
|
||||||
from transformers import LlamaModel
|
from transformers import LlamaModel
|
||||||
|
from einops import rearrange
|
||||||
|
import numpy as np
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
class HunyuanVideoPipeline(BasePipeline):
|
class HunyuanVideoPipeline(BasePipeline):
|
||||||
|
|
||||||
@@ -13,11 +17,13 @@ class HunyuanVideoPipeline(BasePipeline):
|
|||||||
self.prompter = HunyuanVideoPrompter()
|
self.prompter = HunyuanVideoPrompter()
|
||||||
self.text_encoder_1: SD3TextEncoder1 = None
|
self.text_encoder_1: SD3TextEncoder1 = None
|
||||||
self.text_encoder_2: LlamaModel = None
|
self.text_encoder_2: LlamaModel = None
|
||||||
|
self.vae_decoder: HunyuanVideoVAEDecoder = None
|
||||||
|
self.model_names = ['text_encoder_1', 'text_encoder_2', 'vae_decoder']
|
||||||
|
|
||||||
def fetch_models(self, model_manager: ModelManager):
|
def fetch_models(self, model_manager: ModelManager):
|
||||||
self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
|
self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
|
||||||
self.text_encoder_2 = model_manager.fetch_model("hunyuan_video_text_encoder_2")
|
self.text_encoder_2 = model_manager.fetch_model("hunyuan_video_text_encoder_2")
|
||||||
|
self.vae_decoder = model_manager.fetch_model("hunyuan_video_vae_decoder")
|
||||||
self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
|
self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -31,11 +37,19 @@ class HunyuanVideoPipeline(BasePipeline):
|
|||||||
return pipe
|
return pipe
|
||||||
|
|
||||||
def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256):
|
def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256):
|
||||||
prompt_emb, pooled_prompt_emb = self.prompter.encode_prompt(
|
prompt_emb, pooled_prompt_emb = self.prompter.encode_prompt(prompt,
|
||||||
prompt, device=self.device, positive=positive, clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length
|
device=self.device,
|
||||||
)
|
positive=positive,
|
||||||
|
clip_sequence_length=clip_sequence_length,
|
||||||
|
llm_sequence_length=llm_sequence_length)
|
||||||
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb}
|
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb}
|
||||||
|
|
||||||
|
def tensor2video(self, frames):
|
||||||
|
frames = rearrange(frames, "C T H W -> T H W C")
|
||||||
|
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
|
||||||
|
frames = [Image.fromarray(frame) for frame in frames]
|
||||||
|
return frames
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@@ -45,7 +59,16 @@ class HunyuanVideoPipeline(BasePipeline):
|
|||||||
progress_bar_cmd=tqdm,
|
progress_bar_cmd=tqdm,
|
||||||
progress_bar_st=None,
|
progress_bar_st=None,
|
||||||
):
|
):
|
||||||
pass
|
# encode prompt
|
||||||
|
# prompt_emb_posi = self.encode_prompt(prompt, positive=True)
|
||||||
|
|
||||||
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
|
# test data
|
||||||
return prompt_emb_posi
|
latents = torch.load('latents.pt').to(device=self.device, dtype=self.torch_dtype) # torch.Size([1, 16, 33, 90, 160])
|
||||||
|
latents = latents[:, :, :2, :, :]
|
||||||
|
# Tiler parameters
|
||||||
|
tiler_kwargs = dict(use_temporal_tiling=False, use_spatial_tiling=False, sample_ssize=256, sample_tsize=64)
|
||||||
|
# decode
|
||||||
|
self.load_models_to_device(['vae_decoder'])
|
||||||
|
frames = self.vae_decoder.decode_video(latents, **tiler_kwargs)
|
||||||
|
frames = self.tensor2video(frames[0])
|
||||||
|
return frames
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models
|
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
@@ -8,9 +8,11 @@ download_models(["HunyuanVideo"])
|
|||||||
# Load models
|
# Load models
|
||||||
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
|
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
|
||||||
model_manager.load_models([
|
model_manager.load_models([
|
||||||
|
"models/HunyuanVideo/vae/pytorch_model.pt",
|
||||||
"t2i_models/HunyuanVideo/text_encoder/model.safetensors",
|
"t2i_models/HunyuanVideo/text_encoder/model.safetensors",
|
||||||
"t2i_models/HunyuanVideo/text_encoder_2",
|
"t2i_models/HunyuanVideo/text_encoder_2",
|
||||||
])
|
])
|
||||||
pipe = HunyuanVideoPipeline.from_model_manager(model_manager)
|
pipe = HunyuanVideoPipeline.from_model_manager(model_manager)
|
||||||
prompt = 'A cat walks on the grass, realistic style.'
|
prompt = 'A cat walks on the grass, realistic style.'
|
||||||
pipe(prompt)
|
frames = pipe(prompt)
|
||||||
|
save_video(frames, 'test_video.mp4', fps=8, quality=5)
|
||||||
|
|||||||
Reference in New Issue
Block a user