Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-19 06:48:12 +00:00)
@@ -453,7 +453,7 @@ class HunyuanVideoVAEDecoder(nn.Module):
         weight = torch.zeros((1, 1, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)
         values = torch.zeros((B, 3, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)
 
-        for t, t_, h, h_, w, w_ in tqdm(tasks):
+        for t, t_, h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
             hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device)
             hidden_states_batch = self.forward(hidden_states_batch).to(data_device)
             if t > 0:
@@ -1,8 +1,9 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from einops import rearrange
+from einops import rearrange, repeat
 import numpy as np
+from tqdm import tqdm
 from .hunyuan_video_vae_decoder import CausalConv3d, ResnetBlockCausal3D, UNetMidBlockCausal3D
 
 
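The two new imports serve the tiled encoder added below: repeat broadcasts 1D blending ramps into a 3D volume, and tqdm reports progress over the tile list. A minimal standalone sketch of the broadcasting step (illustrative only, not part of the change):

import torch
from einops import rearrange, repeat

# Example 1D ramps along each axis (values are arbitrary here).
t = torch.ones(5)
h = torch.linspace(0.25, 1.0, 4)
w = torch.ones(3)

# Broadcast each ramp over the full (T, H, W) grid, then take the per-voxel
# minimum so a voxel is down-weighted if it lies near any non-boundary edge.
t3 = repeat(t, "T -> T H W", T=5, H=4, W=3)
h3 = repeat(h, "H -> T H W", T=5, H=4, W=3)
w3 = repeat(w, "W -> T H W", T=5, H=4, W=3)
mask = torch.stack([t3, h3, w3]).min(dim=0).values
mask = rearrange(mask, "T H W -> 1 1 T H W")   # broadcastable over (B, C, T, H, W)
print(mask.shape)                              # torch.Size([1, 1, 5, 4, 3])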
@@ -192,12 +193,101 @@ class HunyuanVideoVAEEncoder(nn.Module):
             gradient_checkpointing=gradient_checkpointing,
         )
         self.quant_conv = nn.Conv3d(2 * out_channels, 2 * out_channels, kernel_size=1)
+        self.scaling_factor = 0.476986
 
 
     def forward(self, images):
         latents = self.encoder(images)
         latents = self.quant_conv(latents)
-        # latents: (B C T H W)
+        latents = latents[:, :16]
+        latents = latents * self.scaling_factor
         return latents
 
 
+    def build_1d_mask(self, length, left_bound, right_bound, border_width):
+        x = torch.ones((length,))
+        if not left_bound:
+            x[:border_width] = (torch.arange(border_width) + 1) / border_width
+        if not right_bound:
+            x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
+        return x
+
+
+    def build_mask(self, data, is_bound, border_width):
+        _, _, T, H, W = data.shape
+        t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
+        h = self.build_1d_mask(H, is_bound[2], is_bound[3], border_width[1])
+        w = self.build_1d_mask(W, is_bound[4], is_bound[5], border_width[2])
+
+        t = repeat(t, "T -> T H W", T=T, H=H, W=W)
+        h = repeat(h, "H -> T H W", T=T, H=H, W=W)
+        w = repeat(w, "W -> T H W", T=T, H=H, W=W)
+
+        mask = torch.stack([t, h, w]).min(dim=0).values
+        mask = rearrange(mask, "T H W -> 1 1 T H W")
+        return mask
+
+
+    def tile_forward(self, hidden_states, tile_size, tile_stride):
+        B, C, T, H, W = hidden_states.shape
+        size_t, size_h, size_w = tile_size
+        stride_t, stride_h, stride_w = tile_stride
+
+        # Split tasks
+        tasks = []
+        for t in range(0, T, stride_t):
+            if (t-stride_t >= 0 and t-stride_t+size_t >= T): continue
+            for h in range(0, H, stride_h):
+                if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
+                for w in range(0, W, stride_w):
+                    if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
+                    t_, h_, w_ = t + size_t, h + size_h, w + size_w
+                    tasks.append((t, t_, h, h_, w, w_))
+
+        # Run
+        torch_dtype = self.quant_conv.weight.dtype
+        data_device = hidden_states.device
+        computation_device = self.quant_conv.weight.device
+
+        weight = torch.zeros((1, 1, (T - 1) // 4 + 1, H // 8, W // 8), dtype=torch_dtype, device=data_device)
+        values = torch.zeros((B, 16, (T - 1) // 4 + 1, H // 8, W // 8), dtype=torch_dtype, device=data_device)
+
+        for t, t_, h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
+            hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device)
+            hidden_states_batch = self.forward(hidden_states_batch).to(data_device)
+            if t > 0:
+                hidden_states_batch = hidden_states_batch[:, :, 1:]
+
+            mask = self.build_mask(
+                hidden_states_batch,
+                is_bound=(t==0, t_>=T, h==0, h_>=H, w==0, w_>=W),
+                border_width=((size_t - stride_t) // 4, (size_h - stride_h) // 8, (size_w - stride_w) // 8)
+            ).to(dtype=torch_dtype, device=data_device)
+
+            target_t = 0 if t==0 else t // 4 + 1
+            target_h = h // 8
+            target_w = w // 8
+            values[
+                :,
+                :,
+                target_t: target_t + hidden_states_batch.shape[2],
+                target_h: target_h + hidden_states_batch.shape[3],
+                target_w: target_w + hidden_states_batch.shape[4],
+            ] += hidden_states_batch * mask
+            weight[
+                :,
+                :,
+                target_t: target_t + hidden_states_batch.shape[2],
+                target_h: target_h + hidden_states_batch.shape[3],
+                target_w: target_w + hidden_states_batch.shape[4],
+            ] += mask
+        return values / weight
+
+
+    def encode_video(self, latents, tile_size=(65, 256, 256), tile_stride=(48, 192, 192)):
+        latents = latents.to(self.quant_conv.weight.dtype)
+        return self.tile_forward(latents, tile_size=tile_size, tile_stride=tile_stride)
+
+
     @staticmethod
     def state_dict_converter():
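These new methods implement tiled encoding: the video is split into overlapping tiles, each tile is encoded independently on the computation device, and the latent tiles are blended back together on the data device. build_1d_mask and build_mask produce a feathering weight that ramps linearly across the overlap (except at the true boundaries of the volume), and tile_forward accumulates tile * mask into values and mask into weight, so each output voxel becomes the weighted average values / weight over every tile that covers it. The scaling_factor brings the latents onto the scale expected by the rest of the pipeline. A tiny standalone illustration of the 1D ramp (not part of the diff; the helper below just mirrors build_1d_mask):

import torch

def ramp_mask(length, left_bound, right_bound, border_width):
    # Interior weights are 1.0; edges that are NOT at the boundary of the
    # full volume get a linear ramp over `border_width` samples, so
    # overlapping tiles cross-fade instead of producing visible seams.
    x = torch.ones((length,))
    if not left_bound:
        x[:border_width] = (torch.arange(border_width) + 1) / border_width
    if not right_bound:
        x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
    return x

print(ramp_mask(8, left_bound=True, right_bound=False, border_width=3))
# tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.6667, 0.3333])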
@@ -263,8 +263,8 @@ class GeneralLoRAFromPeft:
 class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai):
     def __init__(self):
         super().__init__()
-        self.supported_model_classes = [HunyuanVideoDiT]
-        self.lora_prefix = ["diffusion_model."]
+        self.supported_model_classes = [HunyuanVideoDiT, HunyuanVideoDiT]
+        self.lora_prefix = ["diffusion_model.", "transformer."]
         self.special_keys = {}
 
 
@@ -72,16 +72,21 @@ class HunyuanVideoPipeline(BasePipeline):
         frames = [Image.fromarray(frame) for frame in frames]
         return frames
 
-    def encode_video(self, frames):
-        # frames : (B, C, T, H, W)
-        latents = self.vae_encoder(frames)
+    def encode_video(self, frames, tile_size=(17, 30, 30), tile_stride=(12, 20, 20)):
+        tile_size = ((tile_size[0] - 1) * 4 + 1, tile_size[1] * 8, tile_size[2] * 8)
+        tile_stride = (tile_stride[0] * 4, tile_stride[1] * 8, tile_stride[2] * 8)
+        latents = self.vae_encoder.encode_video(frames, tile_size=tile_size, tile_stride=tile_stride)
         return latents
 
 
     @torch.no_grad()
     def __call__(
         self,
         prompt,
         negative_prompt="",
+        input_video=None,
+        denoising_strength=1.0,
         seed=None,
         height=720,
         width=1280,
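The tile_size and tile_stride arguments of the pipeline's encode_video are given in latent units and converted to pixel-space frame counts and resolutions before being handed to the VAE encoder (the VAE compresses 4x in time and 8x spatially). A quick check of the default mapping (illustration only):

tile_size_latent = (17, 30, 30)       # (T, H, W) in latent units
tile_stride_latent = (12, 20, 20)

tile_size_pixels = ((tile_size_latent[0] - 1) * 4 + 1,
                    tile_size_latent[1] * 8,
                    tile_size_latent[2] * 8)        # (65, 240, 240): frames x pixels
tile_stride_pixels = (tile_stride_latent[0] * 4,
                      tile_stride_latent[1] * 8,
                      tile_stride_latent[2] * 8)    # (48, 160, 160)
print(tile_size_pixels, tile_stride_pixels)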
@@ -94,8 +99,22 @@ class HunyuanVideoPipeline(BasePipeline):
         progress_bar_cmd=lambda x: x,
         progress_bar_st=None,
     ):
+        # Tiler parameters
+        tiler_kwargs = {"tile_size": tile_size, "tile_stride": tile_stride}
+
+        # Scheduler
+        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
+
         # Initialize noise
-        latents = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
+        noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
+        if input_video is not None:
+            self.load_models_to_device(['vae_encoder'])
+            input_video = self.preprocess_images(input_video)
+            input_video = torch.stack(input_video, dim=2)
+            latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
+            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
+        else:
+            latents = noise
 
         # Encode prompts
         self.load_models_to_device(["text_encoder_1"] if self.vram_management else ["text_encoder_1", "text_encoder_2"])
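For video-to-video, the input frames are encoded to latents and then partially re-noised at the first scheduled timestep; denoising_strength (passed to set_timesteps) decides how far into the noise schedule that first timestep sits, so lower values preserve more of the input video. A rough, generic sketch of what such an add_noise step does in a flow-matching scheduler (an illustration of the idea, not DiffSynth's exact implementation):

import torch

def add_noise_sketch(latents, noise, sigma):
    # Linear interpolation toward noise, as in rectified-flow style schedulers;
    # sigma is derived from the first timestep, which depends on denoising_strength.
    return (1 - sigma) * latents + sigma * noise

latents = torch.randn(1, 16, 33, 64, 48)
noise = torch.randn_like(latents)
noisy = add_noise_sketch(latents, noise, sigma=0.85)   # input latents weighted by 0.15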
@@ -106,9 +125,6 @@ class HunyuanVideoPipeline(BasePipeline):
         # Extra input
         extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
 
-        # Scheduler
-        self.scheduler.set_timesteps(num_inference_steps)
-
         # Denoise
         self.load_models_to_device([] if self.vram_management else ["dit"])
         for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
@@ -126,9 +142,6 @@ class HunyuanVideoPipeline(BasePipeline):
 
         # Scheduler
         latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
 
-        # Tiler parameters
-        tiler_kwargs = {"tile_size": tile_size, "tile_stride": tile_stride}
-
         # Decode
         self.load_models_to_device(['vae_decoder'])
@@ -17,3 +17,7 @@ https://github.com/user-attachments/assets/48dd24bb-0cc6-40d2-88c3-10feed3267e9
 Video generated by [hunyuanvideo_6G.py](hunyuanvideo_6G.py) using [this LoRA](https://civitai.com/models/1032126/walking-animation-hunyuan-video):
 
 https://github.com/user-attachments/assets/2997f107-d02d-4ecb-89bb-5ce1a7f93817
+
+Video to video generated by [hunyuanvideo_v2v_6G.py](./hunyuanvideo_v2v_6G.py) using [this LoRA](https://civitai.com/models/1032126/walking-animation-hunyuan-video):
+
+https://github.com/user-attachments/assets/4b89e52e-ce42-434e-aa57-08f09dfa2b10
examples/HunyuanVideo/hunyuanvideo_v2v_6G.py (new file, 50 lines)
@@ -0,0 +1,50 @@
+import torch
+torch.cuda.set_per_process_memory_fraction(6/80, 0)
+from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video, FlowMatchScheduler
+
+
+download_models(["HunyuanVideo"])
+model_manager = ModelManager()
+
+# The DiT model is loaded in bfloat16.
+model_manager.load_models(
+    [
+        "models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
+    ],
+    torch_dtype=torch.bfloat16, # you can use torch_dtype=torch.float8_e4m3fn to enable quantization.
+    device="cpu"
+)
+
+# The other modules are loaded in float16.
+model_manager.load_models(
+    [
+        "models/HunyuanVideo/text_encoder/model.safetensors",
+        "models/HunyuanVideo/text_encoder_2",
+        "models/HunyuanVideo/vae/pytorch_model.pt",
+    ],
+    torch_dtype=torch.float16,
+    device="cpu"
+)
+
+# We support LoRA inference. You can use the following code to load your LoRA model.
+# Example LoRA: https://civitai.com/models/1032126/walking-animation-hunyuan-video
+model_manager.load_lora("models/lora/kxsr_walking_anim_v1-5.safetensors", lora_alpha=1.0)
+
+# The computation device is "cuda".
+pipe = HunyuanVideoPipeline.from_model_manager(
+    model_manager,
+    torch_dtype=torch.bfloat16,
+    device="cuda"
+)
+# This LoRA requires shift=9.0.
+pipe.scheduler = FlowMatchScheduler(shift=9.0, sigma_min=0.0, extra_one_step=True)
+
+# Text-to-video
+prompt = f"kxsr, full body, no crop. A girl is walking. CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
+video = pipe(prompt, seed=1, height=512, width=384, num_frames=129, num_inference_steps=18, tile_size=(17, 16, 16), tile_stride=(12, 12, 12))
+save_video(video, f"video.mp4", fps=30, quality=6)
+
+# Video-to-video
+prompt = f"kxsr, full body, no crop. A girl is walking. CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, purple dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
+video = pipe(prompt, seed=1, height=512, width=384, num_frames=129, num_inference_steps=18, tile_size=(17, 16, 16), tile_stride=(12, 12, 12), input_video=video, denoising_strength=0.85)
+save_video(video, f"video_edited.mp4", fps=30, quality=6)
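The set_per_process_memory_fraction(6/80, 0) call at the top caps the process at roughly 6 GB on device 0, which is how the example emulates a 6 GB card; the 80 in the denominator assumes an 80 GB GPU. A device-agnostic variant, if one were wanted, could derive the fraction from the actual card (sketch only, hypothetical helper):

import torch

def cap_vram(gigabytes=6, device=0):
    # Compute the fraction from the device's real capacity instead of
    # hard-coding an 80 GB card; on a true 6 GB GPU this is effectively a no-op.
    total = torch.cuda.get_device_properties(device).total_memory
    torch.cuda.set_per_process_memory_fraction(min(1.0, gigabytes * 1024**3 / total), device)

if torch.cuda.is_available():
    cap_vram(6, 0)

The example also passes tile_size=(17, 16, 16) and tile_stride=(12, 12, 12), smaller than the pipeline defaults, which lowers peak VAE memory at the cost of encoding and decoding more overlapping tiles.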