hunyuanvideo

This commit is contained in:
Artiprocher
2024-12-18 16:43:06 +08:00
parent 447adef472
commit e5099f4e74
5 changed files with 168 additions and 31 deletions

View File

@@ -643,12 +643,14 @@ preset_models_on_modelscope = {
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae")
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideo/transformers")
],
"load_path": [
"models/HunyuanVideo/text_encoder/model.safetensors",
"models/HunyuanVideo/text_encoder_2",
"models/HunyuanVideo/vae/pytorch_model.pt"
"models/HunyuanVideo/vae/pytorch_model.pt",
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
],
},
}

View File

@@ -3,6 +3,8 @@ import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import numpy as np
from tqdm import tqdm
from einops import repeat
class CausalConv3d(nn.Module):
@@ -393,16 +395,99 @@ class HunyuanVideoVAEDecoder(nn.Module):
gradient_checkpointing=gradient_checkpointing,
)
self.post_quant_conv = nn.Conv3d(in_channels, in_channels, kernel_size=1)
self.scaling_factor = 0.476986
def decode_video(self, latents, use_temporal_tiling=False, use_spatial_tiling=False, sample_ssize=256, sample_tsize=64):
if use_temporal_tiling:
raise NotImplementedError
if use_spatial_tiling:
raise NotImplementedError
# no tiling
def forward(self, latents):
latents = latents / self.scaling_factor
latents = self.post_quant_conv(latents)
dec = self.decoder(latents)
return dec
def build_1d_mask(self, length, left_bound, right_bound, border_width):
x = torch.ones((length,))
if not left_bound:
x[:border_width] = (torch.arange(border_width) + 1) / border_width
if not right_bound:
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
return x
def build_mask(self, data, is_bound, border_width):
_, _, T, H, W = data.shape
t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
h = self.build_1d_mask(H, is_bound[2], is_bound[3], border_width[1])
w = self.build_1d_mask(W, is_bound[4], is_bound[5], border_width[2])
t = repeat(t, "T -> T H W", T=T, H=H, W=W)
h = repeat(h, "H -> T H W", T=T, H=H, W=W)
w = repeat(w, "W -> T H W", T=T, H=H, W=W)
mask = torch.stack([t, h, w]).min(dim=0).values
mask = rearrange(mask, "T H W -> 1 1 T H W")
return mask
def tile_forward(self, hidden_states, tile_size, tile_stride):
B, C, T, H, W = hidden_states.shape
size_t, size_h, size_w = tile_size
stride_t, stride_h, stride_w = tile_stride
# Split tasks
tasks = []
for t in range(0, T, stride_t):
if (t-stride_t >= 0 and t-stride_t+size_t >= T): continue
for h in range(0, H, stride_h):
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
for w in range(0, W, stride_w):
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
t_, h_, w_ = t + size_t, h + size_h, w + size_w
tasks.append((t, t_, h, h_, w, w_))
# Run
torch_dtype = self.post_quant_conv.weight.dtype
data_device = hidden_states.device
computation_device = self.post_quant_conv.weight.device
weight = torch.zeros((1, 1, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)
values = torch.zeros((B, 3, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)
for t, t_, h, h_, w, w_ in tqdm(tasks):
hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device)
hidden_states_batch = self.forward(hidden_states_batch).to(data_device)
if t > 0:
hidden_states_batch = hidden_states_batch[:, :, 1:]
mask = self.build_mask(
hidden_states_batch,
is_bound=(t==0, t_>=T, h==0, h_>=H, w==0, w_>=W),
border_width=((size_t - stride_t) * 4, (size_h - stride_h) * 8, (size_w - stride_w) * 8)
).to(dtype=torch_dtype, device=data_device)
target_t = 0 if t==0 else t * 4 + 1
target_h = h * 8
target_w = w * 8
values[
:,
:,
target_t: target_t + hidden_states_batch.shape[2],
target_h: target_h + hidden_states_batch.shape[3],
target_w: target_w + hidden_states_batch.shape[4],
] += hidden_states_batch * mask
weight[
:,
:,
target_t: target_t + hidden_states_batch.shape[2],
target_h: target_h + hidden_states_batch.shape[3],
target_w: target_w + hidden_states_batch.shape[4],
] += mask
return values / weight
def decode_video(self, latents, tile_size=(17, 32, 32), tile_stride=(12, 24, 24)):
latents = latents.to(self.post_quant_conv.weight.dtype)
return self.tile_forward(latents, tile_size=tile_size, tile_stride=tile_stride)
@staticmethod
def state_dict_converter():

View File

@@ -7,6 +7,7 @@ from .sd3_dit import SD3DiT
from .flux_dit import FluxDiT
from .hunyuan_dit import HunyuanDiT
from .cog_dit import CogDiT
from .hunyuan_video_dit import HunyuanVideoDiT
@@ -259,6 +260,14 @@ class GeneralLoRAFromPeft:
return None
class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai):
def __init__(self):
super().__init__()
self.supported_model_classes = [HunyuanVideoDiT]
self.lora_prefix = ["diffusion_model."]
self.special_keys = {}
class FluxLoRAConverter:
def __init__(self):
pass
@@ -355,4 +364,4 @@ class FluxLoRAConverter:
def get_lora_loaders():
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), GeneralLoRAFromPeft()]
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()]

View File

@@ -7,10 +7,10 @@ import torch
from transformers import LlamaModel
from einops import rearrange
import numpy as np
from tqdm import tqdm
from PIL import Image
class HunyuanVideoPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
@@ -22,6 +22,13 @@ class HunyuanVideoPipeline(BasePipeline):
self.dit: HunyuanVideoDiT = None
self.vae_decoder: HunyuanVideoVAEDecoder = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder']
self.vram_management = False
def enable_vram_management(self):
self.vram_management = True
self.enable_cpu_offload()
self.dit.enable_auto_offload(dtype=self.torch_dtype, device=self.device)
def fetch_models(self, model_manager: ModelManager):
@@ -38,10 +45,8 @@ class HunyuanVideoPipeline(BasePipeline):
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
pipe = HunyuanVideoPipeline(device=device, torch_dtype=torch_dtype)
pipe.fetch_models(model_manager)
# VRAM management is automatically enabled.
if enable_vram_management:
pipe.enable_cpu_offload()
pipe.dit.enable_auto_offload(dtype=torch_dtype, device=device)
pipe.enable_vram_management()
return pipe
@@ -77,26 +82,34 @@ class HunyuanVideoPipeline(BasePipeline):
embedded_guidance=6.0,
cfg_scale=1.0,
num_inference_steps=30,
tile_size=(17, 30, 30),
tile_stride=(12, 20, 20),
progress_bar_cmd=lambda x: x,
progress_bar_st=None,
):
# Initialize noise
latents = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
# Encode prompts
self.load_models_to_device(["text_encoder_1", "text_encoder_2"])
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
if cfg_scale != 1.0:
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
# Extra input
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
# Scheduler
self.scheduler.set_timesteps(num_inference_steps)
self.load_models_to_device([])
# Denoise
self.load_models_to_device([] if self.vram_management else ["dit"])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
with torch.autocast(device_type=self.device, dtype=self.torch_dtype):
print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")
print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")
# Inference
with torch.autocast(device_type=self.device, dtype=self.torch_dtype):
noise_pred_posi = self.dit(latents, timestep, **prompt_emb_posi, **extra_input)
if cfg_scale != 1.0:
noise_pred_nega = self.dit(latents, timestep, **prompt_emb_nega, **extra_input)
@@ -104,12 +117,16 @@ class HunyuanVideoPipeline(BasePipeline):
else:
noise_pred = noise_pred_posi
# Scheduler
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
# Tiler parameters
tiler_kwargs = dict(use_temporal_tiling=False, use_spatial_tiling=False, sample_ssize=256, sample_tsize=64)
# decode
tiler_kwargs = {"tile_size": tile_size, "tile_stride": tile_stride}
# Decode
self.load_models_to_device(['vae_decoder'])
frames = self.vae_decoder.decode_video(latents, **tiler_kwargs)
self.load_models_to_device([])
frames = self.tensor2video(frames[0])
return frames

View File

@@ -1,18 +1,42 @@
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
import torch
torch.cuda.set_per_process_memory_fraction(1.0, 0)
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
# Download models (automatically)
download_models(["HunyuanVideo"])
model_manager = ModelManager()
# Load models
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
model_manager.load_models([
"models/HunyuanVideo/vae/pytorch_model.pt",
"t2i_models/HunyuanVideo/text_encoder/model.safetensors",
"t2i_models/HunyuanVideo/text_encoder_2",
])
pipe = HunyuanVideoPipeline.from_model_manager(model_manager)
prompt = 'A cat walks on the grass, realistic style.'
frames = pipe(prompt)
save_video(frames, 'test_video.mp4', fps=8, quality=5)
# The DiT model is loaded in bfloat16.
model_manager.load_models(
[
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
],
torch_dtype=torch.bfloat16,
device="cpu"
)
# The other modules are loaded in float16.
model_manager.load_models(
[
"models/HunyuanVideo/text_encoder/model.safetensors",
"models/HunyuanVideo/text_encoder_2",
"models/HunyuanVideo/vae/pytorch_model.pt",
],
torch_dtype=torch.float16,
device="cpu"
)
# We support LoRA inference. You can use the following code to load your LoRA model.
# model_manager.load_lora("models/lora/xxx.safetensors", lora_alpha=1.0)
# The computation device is "cuda".
pipe = HunyuanVideoPipeline.from_model_manager(
model_manager,
torch_dtype=torch.bfloat16,
device="cuda"
)
# Enjoy!
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
video = pipe(prompt, seed=0)
save_video(video, "video.mp4", fps=30, quality=5)