mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
hunyuanvideo
This commit is contained in:
@@ -643,12 +643,14 @@ preset_models_on_modelscope = {
|
||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
|
||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
|
||||
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae")
|
||||
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
|
||||
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideo/transformers")
|
||||
],
|
||||
"load_path": [
|
||||
"models/HunyuanVideo/text_encoder/model.safetensors",
|
||||
"models/HunyuanVideo/text_encoder_2",
|
||||
"models/HunyuanVideo/vae/pytorch_model.pt"
|
||||
"models/HunyuanVideo/vae/pytorch_model.pt",
|
||||
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from einops import repeat
|
||||
|
||||
|
||||
class CausalConv3d(nn.Module):
|
||||
@@ -393,16 +395,99 @@ class HunyuanVideoVAEDecoder(nn.Module):
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
)
|
||||
self.post_quant_conv = nn.Conv3d(in_channels, in_channels, kernel_size=1)
|
||||
self.scaling_factor = 0.476986
|
||||
|
||||
def decode_video(self, latents, use_temporal_tiling=False, use_spatial_tiling=False, sample_ssize=256, sample_tsize=64):
|
||||
if use_temporal_tiling:
|
||||
raise NotImplementedError
|
||||
if use_spatial_tiling:
|
||||
raise NotImplementedError
|
||||
# no tiling
|
||||
|
||||
def forward(self, latents):
|
||||
latents = latents / self.scaling_factor
|
||||
latents = self.post_quant_conv(latents)
|
||||
dec = self.decoder(latents)
|
||||
return dec
|
||||
|
||||
|
||||
def build_1d_mask(self, length, left_bound, right_bound, border_width):
|
||||
x = torch.ones((length,))
|
||||
if not left_bound:
|
||||
x[:border_width] = (torch.arange(border_width) + 1) / border_width
|
||||
if not right_bound:
|
||||
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
|
||||
return x
|
||||
|
||||
|
||||
def build_mask(self, data, is_bound, border_width):
|
||||
_, _, T, H, W = data.shape
|
||||
t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
|
||||
h = self.build_1d_mask(H, is_bound[2], is_bound[3], border_width[1])
|
||||
w = self.build_1d_mask(W, is_bound[4], is_bound[5], border_width[2])
|
||||
|
||||
t = repeat(t, "T -> T H W", T=T, H=H, W=W)
|
||||
h = repeat(h, "H -> T H W", T=T, H=H, W=W)
|
||||
w = repeat(w, "W -> T H W", T=T, H=H, W=W)
|
||||
|
||||
mask = torch.stack([t, h, w]).min(dim=0).values
|
||||
mask = rearrange(mask, "T H W -> 1 1 T H W")
|
||||
return mask
|
||||
|
||||
|
||||
def tile_forward(self, hidden_states, tile_size, tile_stride):
|
||||
B, C, T, H, W = hidden_states.shape
|
||||
size_t, size_h, size_w = tile_size
|
||||
stride_t, stride_h, stride_w = tile_stride
|
||||
|
||||
# Split tasks
|
||||
tasks = []
|
||||
for t in range(0, T, stride_t):
|
||||
if (t-stride_t >= 0 and t-stride_t+size_t >= T): continue
|
||||
for h in range(0, H, stride_h):
|
||||
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
|
||||
for w in range(0, W, stride_w):
|
||||
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
|
||||
t_, h_, w_ = t + size_t, h + size_h, w + size_w
|
||||
tasks.append((t, t_, h, h_, w, w_))
|
||||
|
||||
# Run
|
||||
torch_dtype = self.post_quant_conv.weight.dtype
|
||||
data_device = hidden_states.device
|
||||
computation_device = self.post_quant_conv.weight.device
|
||||
|
||||
weight = torch.zeros((1, 1, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)
|
||||
values = torch.zeros((B, 3, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)
|
||||
|
||||
for t, t_, h, h_, w, w_ in tqdm(tasks):
|
||||
hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device)
|
||||
hidden_states_batch = self.forward(hidden_states_batch).to(data_device)
|
||||
if t > 0:
|
||||
hidden_states_batch = hidden_states_batch[:, :, 1:]
|
||||
|
||||
mask = self.build_mask(
|
||||
hidden_states_batch,
|
||||
is_bound=(t==0, t_>=T, h==0, h_>=H, w==0, w_>=W),
|
||||
border_width=((size_t - stride_t) * 4, (size_h - stride_h) * 8, (size_w - stride_w) * 8)
|
||||
).to(dtype=torch_dtype, device=data_device)
|
||||
|
||||
target_t = 0 if t==0 else t * 4 + 1
|
||||
target_h = h * 8
|
||||
target_w = w * 8
|
||||
values[
|
||||
:,
|
||||
:,
|
||||
target_t: target_t + hidden_states_batch.shape[2],
|
||||
target_h: target_h + hidden_states_batch.shape[3],
|
||||
target_w: target_w + hidden_states_batch.shape[4],
|
||||
] += hidden_states_batch * mask
|
||||
weight[
|
||||
:,
|
||||
:,
|
||||
target_t: target_t + hidden_states_batch.shape[2],
|
||||
target_h: target_h + hidden_states_batch.shape[3],
|
||||
target_w: target_w + hidden_states_batch.shape[4],
|
||||
] += mask
|
||||
return values / weight
|
||||
|
||||
|
||||
def decode_video(self, latents, tile_size=(17, 32, 32), tile_stride=(12, 24, 24)):
|
||||
latents = latents.to(self.post_quant_conv.weight.dtype)
|
||||
return self.tile_forward(latents, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
|
||||
@@ -7,6 +7,7 @@ from .sd3_dit import SD3DiT
|
||||
from .flux_dit import FluxDiT
|
||||
from .hunyuan_dit import HunyuanDiT
|
||||
from .cog_dit import CogDiT
|
||||
from .hunyuan_video_dit import HunyuanVideoDiT
|
||||
|
||||
|
||||
|
||||
@@ -259,6 +260,14 @@ class GeneralLoRAFromPeft:
|
||||
return None
|
||||
|
||||
|
||||
class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.supported_model_classes = [HunyuanVideoDiT]
|
||||
self.lora_prefix = ["diffusion_model."]
|
||||
self.special_keys = {}
|
||||
|
||||
|
||||
class FluxLoRAConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
@@ -355,4 +364,4 @@ class FluxLoRAConverter:
|
||||
|
||||
|
||||
def get_lora_loaders():
|
||||
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), GeneralLoRAFromPeft()]
|
||||
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()]
|
||||
|
||||
@@ -7,10 +7,10 @@ import torch
|
||||
from transformers import LlamaModel
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
|
||||
|
||||
|
||||
class HunyuanVideoPipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
@@ -22,6 +22,13 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
self.dit: HunyuanVideoDiT = None
|
||||
self.vae_decoder: HunyuanVideoVAEDecoder = None
|
||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder']
|
||||
self.vram_management = False
|
||||
|
||||
|
||||
def enable_vram_management(self):
|
||||
self.vram_management = True
|
||||
self.enable_cpu_offload()
|
||||
self.dit.enable_auto_offload(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
|
||||
def fetch_models(self, model_manager: ModelManager):
|
||||
@@ -38,10 +45,8 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
|
||||
pipe = HunyuanVideoPipeline(device=device, torch_dtype=torch_dtype)
|
||||
pipe.fetch_models(model_manager)
|
||||
# VRAM management is automatically enabled.
|
||||
if enable_vram_management:
|
||||
pipe.enable_cpu_offload()
|
||||
pipe.dit.enable_auto_offload(dtype=torch_dtype, device=device)
|
||||
pipe.enable_vram_management()
|
||||
return pipe
|
||||
|
||||
|
||||
@@ -77,26 +82,34 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
embedded_guidance=6.0,
|
||||
cfg_scale=1.0,
|
||||
num_inference_steps=30,
|
||||
tile_size=(17, 30, 30),
|
||||
tile_stride=(12, 20, 20),
|
||||
progress_bar_cmd=lambda x: x,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Initialize noise
|
||||
latents = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
# Encode prompts
|
||||
self.load_models_to_device(["text_encoder_1", "text_encoder_2"])
|
||||
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
|
||||
if cfg_scale != 1.0:
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
|
||||
|
||||
# Extra input
|
||||
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
|
||||
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
self.load_models_to_device([])
|
||||
# Denoise
|
||||
self.load_models_to_device([] if self.vram_management else ["dit"])
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(self.device)
|
||||
with torch.autocast(device_type=self.device, dtype=self.torch_dtype):
|
||||
print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")
|
||||
print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")
|
||||
|
||||
# Inference
|
||||
with torch.autocast(device_type=self.device, dtype=self.torch_dtype):
|
||||
noise_pred_posi = self.dit(latents, timestep, **prompt_emb_posi, **extra_input)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = self.dit(latents, timestep, **prompt_emb_nega, **extra_input)
|
||||
@@ -104,12 +117,16 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
|
||||
# Scheduler
|
||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||
|
||||
# Tiler parameters
|
||||
tiler_kwargs = dict(use_temporal_tiling=False, use_spatial_tiling=False, sample_ssize=256, sample_tsize=64)
|
||||
# decode
|
||||
tiler_kwargs = {"tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae_decoder'])
|
||||
frames = self.vae_decoder.decode_video(latents, **tiler_kwargs)
|
||||
self.load_models_to_device([])
|
||||
frames = self.tensor2video(frames[0])
|
||||
|
||||
return frames
|
||||
|
||||
@@ -1,18 +1,42 @@
|
||||
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
|
||||
import torch
|
||||
torch.cuda.set_per_process_memory_fraction(1.0, 0)
|
||||
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
|
||||
|
||||
|
||||
# Download models (automatically)
|
||||
download_models(["HunyuanVideo"])
|
||||
model_manager = ModelManager()
|
||||
|
||||
# Load models
|
||||
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
|
||||
model_manager.load_models([
|
||||
"models/HunyuanVideo/vae/pytorch_model.pt",
|
||||
"t2i_models/HunyuanVideo/text_encoder/model.safetensors",
|
||||
"t2i_models/HunyuanVideo/text_encoder_2",
|
||||
])
|
||||
pipe = HunyuanVideoPipeline.from_model_manager(model_manager)
|
||||
prompt = 'A cat walks on the grass, realistic style.'
|
||||
frames = pipe(prompt)
|
||||
save_video(frames, 'test_video.mp4', fps=8, quality=5)
|
||||
# The DiT model is loaded in bfloat16.
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
|
||||
],
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
# The other modules are loaded in float16.
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/HunyuanVideo/text_encoder/model.safetensors",
|
||||
"models/HunyuanVideo/text_encoder_2",
|
||||
"models/HunyuanVideo/vae/pytorch_model.pt",
|
||||
],
|
||||
torch_dtype=torch.float16,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
# We support LoRA inference. You can use the following code to load your LoRA model.
|
||||
# model_manager.load_lora("models/lora/xxx.safetensors", lora_alpha=1.0)
|
||||
|
||||
# The computation device is "cuda".
|
||||
pipe = HunyuanVideoPipeline.from_model_manager(
|
||||
model_manager,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# Enjoy!
|
||||
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
|
||||
video = pipe(prompt, seed=0)
|
||||
save_video(video, "video.mp4", fps=30, quality=5)
|
||||
|
||||
Reference in New Issue
Block a user