mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 16:50:47 +00:00
hunyuanvideo
This commit is contained in:
@@ -643,12 +643,14 @@ preset_models_on_modelscope = {
|
|||||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
||||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
|
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
|
||||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
|
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
|
||||||
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae")
|
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
|
||||||
|
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideo/transformers")
|
||||||
],
|
],
|
||||||
"load_path": [
|
"load_path": [
|
||||||
"models/HunyuanVideo/text_encoder/model.safetensors",
|
"models/HunyuanVideo/text_encoder/model.safetensors",
|
||||||
"models/HunyuanVideo/text_encoder_2",
|
"models/HunyuanVideo/text_encoder_2",
|
||||||
"models/HunyuanVideo/vae/pytorch_model.pt"
|
"models/HunyuanVideo/vae/pytorch_model.pt",
|
||||||
|
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
from einops import repeat
|
||||||
|
|
||||||
|
|
||||||
class CausalConv3d(nn.Module):
|
class CausalConv3d(nn.Module):
|
||||||
@@ -393,16 +395,99 @@ class HunyuanVideoVAEDecoder(nn.Module):
|
|||||||
gradient_checkpointing=gradient_checkpointing,
|
gradient_checkpointing=gradient_checkpointing,
|
||||||
)
|
)
|
||||||
self.post_quant_conv = nn.Conv3d(in_channels, in_channels, kernel_size=1)
|
self.post_quant_conv = nn.Conv3d(in_channels, in_channels, kernel_size=1)
|
||||||
|
self.scaling_factor = 0.476986
|
||||||
|
|
||||||
def decode_video(self, latents, use_temporal_tiling=False, use_spatial_tiling=False, sample_ssize=256, sample_tsize=64):
|
|
||||||
if use_temporal_tiling:
|
def forward(self, latents):
|
||||||
raise NotImplementedError
|
latents = latents / self.scaling_factor
|
||||||
if use_spatial_tiling:
|
|
||||||
raise NotImplementedError
|
|
||||||
# no tiling
|
|
||||||
latents = self.post_quant_conv(latents)
|
latents = self.post_quant_conv(latents)
|
||||||
dec = self.decoder(latents)
|
dec = self.decoder(latents)
|
||||||
return dec
|
return dec
|
||||||
|
|
||||||
|
|
||||||
|
def build_1d_mask(self, length, left_bound, right_bound, border_width):
|
||||||
|
x = torch.ones((length,))
|
||||||
|
if not left_bound:
|
||||||
|
x[:border_width] = (torch.arange(border_width) + 1) / border_width
|
||||||
|
if not right_bound:
|
||||||
|
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def build_mask(self, data, is_bound, border_width):
|
||||||
|
_, _, T, H, W = data.shape
|
||||||
|
t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
|
||||||
|
h = self.build_1d_mask(H, is_bound[2], is_bound[3], border_width[1])
|
||||||
|
w = self.build_1d_mask(W, is_bound[4], is_bound[5], border_width[2])
|
||||||
|
|
||||||
|
t = repeat(t, "T -> T H W", T=T, H=H, W=W)
|
||||||
|
h = repeat(h, "H -> T H W", T=T, H=H, W=W)
|
||||||
|
w = repeat(w, "W -> T H W", T=T, H=H, W=W)
|
||||||
|
|
||||||
|
mask = torch.stack([t, h, w]).min(dim=0).values
|
||||||
|
mask = rearrange(mask, "T H W -> 1 1 T H W")
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def tile_forward(self, hidden_states, tile_size, tile_stride):
|
||||||
|
B, C, T, H, W = hidden_states.shape
|
||||||
|
size_t, size_h, size_w = tile_size
|
||||||
|
stride_t, stride_h, stride_w = tile_stride
|
||||||
|
|
||||||
|
# Split tasks
|
||||||
|
tasks = []
|
||||||
|
for t in range(0, T, stride_t):
|
||||||
|
if (t-stride_t >= 0 and t-stride_t+size_t >= T): continue
|
||||||
|
for h in range(0, H, stride_h):
|
||||||
|
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
|
||||||
|
for w in range(0, W, stride_w):
|
||||||
|
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
|
||||||
|
t_, h_, w_ = t + size_t, h + size_h, w + size_w
|
||||||
|
tasks.append((t, t_, h, h_, w, w_))
|
||||||
|
|
||||||
|
# Run
|
||||||
|
torch_dtype = self.post_quant_conv.weight.dtype
|
||||||
|
data_device = hidden_states.device
|
||||||
|
computation_device = self.post_quant_conv.weight.device
|
||||||
|
|
||||||
|
weight = torch.zeros((1, 1, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)
|
||||||
|
values = torch.zeros((B, 3, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)
|
||||||
|
|
||||||
|
for t, t_, h, h_, w, w_ in tqdm(tasks):
|
||||||
|
hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device)
|
||||||
|
hidden_states_batch = self.forward(hidden_states_batch).to(data_device)
|
||||||
|
if t > 0:
|
||||||
|
hidden_states_batch = hidden_states_batch[:, :, 1:]
|
||||||
|
|
||||||
|
mask = self.build_mask(
|
||||||
|
hidden_states_batch,
|
||||||
|
is_bound=(t==0, t_>=T, h==0, h_>=H, w==0, w_>=W),
|
||||||
|
border_width=((size_t - stride_t) * 4, (size_h - stride_h) * 8, (size_w - stride_w) * 8)
|
||||||
|
).to(dtype=torch_dtype, device=data_device)
|
||||||
|
|
||||||
|
target_t = 0 if t==0 else t * 4 + 1
|
||||||
|
target_h = h * 8
|
||||||
|
target_w = w * 8
|
||||||
|
values[
|
||||||
|
:,
|
||||||
|
:,
|
||||||
|
target_t: target_t + hidden_states_batch.shape[2],
|
||||||
|
target_h: target_h + hidden_states_batch.shape[3],
|
||||||
|
target_w: target_w + hidden_states_batch.shape[4],
|
||||||
|
] += hidden_states_batch * mask
|
||||||
|
weight[
|
||||||
|
:,
|
||||||
|
:,
|
||||||
|
target_t: target_t + hidden_states_batch.shape[2],
|
||||||
|
target_h: target_h + hidden_states_batch.shape[3],
|
||||||
|
target_w: target_w + hidden_states_batch.shape[4],
|
||||||
|
] += mask
|
||||||
|
return values / weight
|
||||||
|
|
||||||
|
|
||||||
|
def decode_video(self, latents, tile_size=(17, 32, 32), tile_stride=(12, 24, 24)):
|
||||||
|
latents = latents.to(self.post_quant_conv.weight.dtype)
|
||||||
|
return self.tile_forward(latents, tile_size=tile_size, tile_stride=tile_stride)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def state_dict_converter():
|
def state_dict_converter():
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from .sd3_dit import SD3DiT
|
|||||||
from .flux_dit import FluxDiT
|
from .flux_dit import FluxDiT
|
||||||
from .hunyuan_dit import HunyuanDiT
|
from .hunyuan_dit import HunyuanDiT
|
||||||
from .cog_dit import CogDiT
|
from .cog_dit import CogDiT
|
||||||
|
from .hunyuan_video_dit import HunyuanVideoDiT
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -259,6 +260,14 @@ class GeneralLoRAFromPeft:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.supported_model_classes = [HunyuanVideoDiT]
|
||||||
|
self.lora_prefix = ["diffusion_model."]
|
||||||
|
self.special_keys = {}
|
||||||
|
|
||||||
|
|
||||||
class FluxLoRAConverter:
|
class FluxLoRAConverter:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
@@ -355,4 +364,4 @@ class FluxLoRAConverter:
|
|||||||
|
|
||||||
|
|
||||||
def get_lora_loaders():
|
def get_lora_loaders():
|
||||||
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), GeneralLoRAFromPeft()]
|
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()]
|
||||||
|
|||||||
@@ -7,10 +7,10 @@ import torch
|
|||||||
from transformers import LlamaModel
|
from transformers import LlamaModel
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tqdm import tqdm
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class HunyuanVideoPipeline(BasePipeline):
|
class HunyuanVideoPipeline(BasePipeline):
|
||||||
|
|
||||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||||
@@ -22,6 +22,13 @@ class HunyuanVideoPipeline(BasePipeline):
|
|||||||
self.dit: HunyuanVideoDiT = None
|
self.dit: HunyuanVideoDiT = None
|
||||||
self.vae_decoder: HunyuanVideoVAEDecoder = None
|
self.vae_decoder: HunyuanVideoVAEDecoder = None
|
||||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder']
|
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder']
|
||||||
|
self.vram_management = False
|
||||||
|
|
||||||
|
|
||||||
|
def enable_vram_management(self):
|
||||||
|
self.vram_management = True
|
||||||
|
self.enable_cpu_offload()
|
||||||
|
self.dit.enable_auto_offload(dtype=self.torch_dtype, device=self.device)
|
||||||
|
|
||||||
|
|
||||||
def fetch_models(self, model_manager: ModelManager):
|
def fetch_models(self, model_manager: ModelManager):
|
||||||
@@ -38,10 +45,8 @@ class HunyuanVideoPipeline(BasePipeline):
|
|||||||
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
|
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
|
||||||
pipe = HunyuanVideoPipeline(device=device, torch_dtype=torch_dtype)
|
pipe = HunyuanVideoPipeline(device=device, torch_dtype=torch_dtype)
|
||||||
pipe.fetch_models(model_manager)
|
pipe.fetch_models(model_manager)
|
||||||
# VRAM management is automatically enabled.
|
|
||||||
if enable_vram_management:
|
if enable_vram_management:
|
||||||
pipe.enable_cpu_offload()
|
pipe.enable_vram_management()
|
||||||
pipe.dit.enable_auto_offload(dtype=torch_dtype, device=device)
|
|
||||||
return pipe
|
return pipe
|
||||||
|
|
||||||
|
|
||||||
@@ -77,26 +82,34 @@ class HunyuanVideoPipeline(BasePipeline):
|
|||||||
embedded_guidance=6.0,
|
embedded_guidance=6.0,
|
||||||
cfg_scale=1.0,
|
cfg_scale=1.0,
|
||||||
num_inference_steps=30,
|
num_inference_steps=30,
|
||||||
|
tile_size=(17, 30, 30),
|
||||||
|
tile_stride=(12, 20, 20),
|
||||||
progress_bar_cmd=lambda x: x,
|
progress_bar_cmd=lambda x: x,
|
||||||
progress_bar_st=None,
|
progress_bar_st=None,
|
||||||
):
|
):
|
||||||
|
# Initialize noise
|
||||||
latents = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
latents = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
||||||
|
|
||||||
|
# Encode prompts
|
||||||
self.load_models_to_device(["text_encoder_1", "text_encoder_2"])
|
self.load_models_to_device(["text_encoder_1", "text_encoder_2"])
|
||||||
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
|
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
|
||||||
if cfg_scale != 1.0:
|
if cfg_scale != 1.0:
|
||||||
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
|
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
|
||||||
|
|
||||||
|
# Extra input
|
||||||
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
|
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
|
||||||
|
|
||||||
|
# Scheduler
|
||||||
self.scheduler.set_timesteps(num_inference_steps)
|
self.scheduler.set_timesteps(num_inference_steps)
|
||||||
|
|
||||||
self.load_models_to_device([])
|
# Denoise
|
||||||
|
self.load_models_to_device([] if self.vram_management else ["dit"])
|
||||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||||
timestep = timestep.unsqueeze(0).to(self.device)
|
timestep = timestep.unsqueeze(0).to(self.device)
|
||||||
with torch.autocast(device_type=self.device, dtype=self.torch_dtype):
|
print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")
|
||||||
print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")
|
|
||||||
|
|
||||||
|
# Inference
|
||||||
|
with torch.autocast(device_type=self.device, dtype=self.torch_dtype):
|
||||||
noise_pred_posi = self.dit(latents, timestep, **prompt_emb_posi, **extra_input)
|
noise_pred_posi = self.dit(latents, timestep, **prompt_emb_posi, **extra_input)
|
||||||
if cfg_scale != 1.0:
|
if cfg_scale != 1.0:
|
||||||
noise_pred_nega = self.dit(latents, timestep, **prompt_emb_nega, **extra_input)
|
noise_pred_nega = self.dit(latents, timestep, **prompt_emb_nega, **extra_input)
|
||||||
@@ -104,12 +117,16 @@ class HunyuanVideoPipeline(BasePipeline):
|
|||||||
else:
|
else:
|
||||||
noise_pred = noise_pred_posi
|
noise_pred = noise_pred_posi
|
||||||
|
|
||||||
|
# Scheduler
|
||||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||||
|
|
||||||
# Tiler parameters
|
# Tiler parameters
|
||||||
tiler_kwargs = dict(use_temporal_tiling=False, use_spatial_tiling=False, sample_ssize=256, sample_tsize=64)
|
tiler_kwargs = {"tile_size": tile_size, "tile_stride": tile_stride}
|
||||||
# decode
|
|
||||||
|
# Decode
|
||||||
self.load_models_to_device(['vae_decoder'])
|
self.load_models_to_device(['vae_decoder'])
|
||||||
frames = self.vae_decoder.decode_video(latents, **tiler_kwargs)
|
frames = self.vae_decoder.decode_video(latents, **tiler_kwargs)
|
||||||
|
self.load_models_to_device([])
|
||||||
frames = self.tensor2video(frames[0])
|
frames = self.tensor2video(frames[0])
|
||||||
|
|
||||||
return frames
|
return frames
|
||||||
|
|||||||
@@ -1,18 +1,42 @@
|
|||||||
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
|
|
||||||
import torch
|
import torch
|
||||||
|
torch.cuda.set_per_process_memory_fraction(1.0, 0)
|
||||||
|
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
|
||||||
|
|
||||||
|
|
||||||
# Download models (automatically)
|
|
||||||
download_models(["HunyuanVideo"])
|
download_models(["HunyuanVideo"])
|
||||||
|
model_manager = ModelManager()
|
||||||
|
|
||||||
# Load models
|
# The DiT model is loaded in bfloat16.
|
||||||
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
|
model_manager.load_models(
|
||||||
model_manager.load_models([
|
[
|
||||||
"models/HunyuanVideo/vae/pytorch_model.pt",
|
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
|
||||||
"t2i_models/HunyuanVideo/text_encoder/model.safetensors",
|
],
|
||||||
"t2i_models/HunyuanVideo/text_encoder_2",
|
torch_dtype=torch.bfloat16,
|
||||||
])
|
device="cpu"
|
||||||
pipe = HunyuanVideoPipeline.from_model_manager(model_manager)
|
)
|
||||||
prompt = 'A cat walks on the grass, realistic style.'
|
|
||||||
frames = pipe(prompt)
|
# The other modules are loaded in float16.
|
||||||
save_video(frames, 'test_video.mp4', fps=8, quality=5)
|
model_manager.load_models(
|
||||||
|
[
|
||||||
|
"models/HunyuanVideo/text_encoder/model.safetensors",
|
||||||
|
"models/HunyuanVideo/text_encoder_2",
|
||||||
|
"models/HunyuanVideo/vae/pytorch_model.pt",
|
||||||
|
],
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
device="cpu"
|
||||||
|
)
|
||||||
|
|
||||||
|
# We support LoRA inference. You can use the following code to load your LoRA model.
|
||||||
|
# model_manager.load_lora("models/lora/xxx.safetensors", lora_alpha=1.0)
|
||||||
|
|
||||||
|
# The computation device is "cuda".
|
||||||
|
pipe = HunyuanVideoPipeline.from_model_manager(
|
||||||
|
model_manager,
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Enjoy!
|
||||||
|
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
|
||||||
|
video = pipe(prompt, seed=0)
|
||||||
|
save_video(video, "video.mp4", fps=30, quality=5)
|
||||||
|
|||||||
Reference in New Issue
Block a user