mirror of https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support hunyuanvideo v2v
@@ -453,7 +453,7 @@ class HunyuanVideoVAEDecoder(nn.Module):
         weight = torch.zeros((1, 1, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)
         values = torch.zeros((B, 3, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)
 
-        for t, t_, h, h_, w, w_ in tqdm(tasks):
+        for t, t_, h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
            hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device)
            hidden_states_batch = self.forward(hidden_states_batch).to(data_device)
            if t > 0:
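The buffer shapes above encode the VAE's compression ratio: T latent frames decode to (T - 1) * 4 + 1 video frames, and H and W are upsampled 8x. A small illustrative check (the helper name is hypothetical, not part of the commit):

    def decoded_video_shape(T, H, W):
        # causal 3D VAE: 4x temporal, 8x spatial upsampling
        return ((T - 1) * 4 + 1, H * 8, W * 8)

    assert decoded_video_shape(17, 30, 30) == (65, 240, 240)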
@@ -1,8 +1,9 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from einops import rearrange
+from einops import rearrange, repeat
 import numpy as np
+from tqdm import tqdm
 from .hunyuan_video_vae_decoder import CausalConv3d, ResnetBlockCausal3D, UNetMidBlockCausal3D
 
 
@@ -192,12 +193,101 @@ class HunyuanVideoVAEEncoder(nn.Module):
             gradient_checkpointing=gradient_checkpointing,
         )
         self.quant_conv = nn.Conv3d(2 * out_channels, 2 * out_channels, kernel_size=1)
         self.scaling_factor = 0.476986
 
 
     def forward(self, images):
         latents = self.encoder(images)
         latents = self.quant_conv(latents)
+        # latents: (B C T H W)
+        latents = latents[:, :16]
+        latents = latents * self.scaling_factor
         return latents
+
+
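quant_conv emits 2 * out_channels = 32 channels; slicing [:, :16] presumably keeps the mean half of a (mean, logvar) Gaussian parameterization and discards the log-variance, giving a deterministic latent. A sketch of that reading (an assumption, not confirmed by the diff):

    mean, logvar = latents.chunk(2, dim=1)  # each (B, 16, T, H, W)
    latents = mean * self.scaling_factor    # deterministic: use the distribution's mode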
+    def build_1d_mask(self, length, left_bound, right_bound, border_width):
+        x = torch.ones((length,))
+        if not left_bound:
+            x[:border_width] = (torch.arange(border_width) + 1) / border_width
+        if not right_bound:
+            x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
+        return x
+
+
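build_1d_mask produces a linear feather ramp along one axis: weights rise from 1/border_width to 1 at an interior tile edge, and stay at 1 where the tile touches the frame boundary. For example (values shown for illustration, not repo code):

    # build_1d_mask(8, False, False, 3) -> [1/3, 2/3, 1, 1, 1, 1, 2/3, 1/3]
    # build_1d_mask(8, True,  False, 3) -> [1,   1,   1, 1, 1, 1, 2/3, 1/3]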
+    def build_mask(self, data, is_bound, border_width):
+        _, _, T, H, W = data.shape
+        t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
+        h = self.build_1d_mask(H, is_bound[2], is_bound[3], border_width[1])
+        w = self.build_1d_mask(W, is_bound[4], is_bound[5], border_width[2])
+
+        t = repeat(t, "T -> T H W", T=T, H=H, W=W)
+        h = repeat(h, "H -> T H W", T=T, H=H, W=W)
+        w = repeat(w, "W -> T H W", T=T, H=H, W=W)
+
+        mask = torch.stack([t, h, w]).min(dim=0).values
+        mask = rearrange(mask, "T H W -> 1 1 T H W")
+        return mask
+
+
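build_mask broadcasts the three 1D ramps over the full (T, H, W) volume and combines them with an element-wise min, so each voxel's weight is governed by whichever interior border it is closest to; the final rearrange adds singleton batch and channel dims for broadcasting. A tiny numeric check of the min rule (illustrative only):

    # If the temporal ramp gives 0.5 and the width ramp gives 0.25 at some voxel,
    # min(0.5, 1.0, 0.25) == 0.25: the tightest border dominates.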
+    def tile_forward(self, hidden_states, tile_size, tile_stride):
+        B, C, T, H, W = hidden_states.shape
+        size_t, size_h, size_w = tile_size
+        stride_t, stride_h, stride_w = tile_stride
+
+        # Split tasks
+        tasks = []
+        for t in range(0, T, stride_t):
+            if (t-stride_t >= 0 and t-stride_t+size_t >= T): continue
+            for h in range(0, H, stride_h):
+                if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
+                for w in range(0, W, stride_w):
+                    if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
+                    t_, h_, w_ = t + size_t, h + size_h, w + size_w
+                    tasks.append((t, t_, h, h_, w, w_))
+
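The skip condition drops a tile start when the previous tile along that axis already reached the end of the tensor, so no redundant trailing tile is produced. For example (illustrative arithmetic):

    # T=100, size_t=65, stride_t=48: candidate starts are t=0, 48, 96.
    # t=96 is skipped because the t=48 tile spans 48..113 and already covers frame 99.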
+        # Run
+        torch_dtype = self.quant_conv.weight.dtype
+        data_device = hidden_states.device
+        computation_device = self.quant_conv.weight.device
+
+        weight = torch.zeros((1, 1, (T - 1) // 4 + 1, H // 8, W // 8), dtype=torch_dtype, device=data_device)
+        values = torch.zeros((B, 16, (T - 1) // 4 + 1, H // 8, W // 8), dtype=torch_dtype, device=data_device)
+
+        for t, t_, h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
+            hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device)
+            hidden_states_batch = self.forward(hidden_states_batch).to(data_device)
+            if t > 0:
+                hidden_states_batch = hidden_states_batch[:, :, 1:]
+
+            mask = self.build_mask(
+                hidden_states_batch,
+                is_bound=(t==0, t_>=T, h==0, h_>=H, w==0, w_>=W),
+                border_width=((size_t - stride_t) // 4, (size_h - stride_h) // 8, (size_w - stride_w) // 8)
+            ).to(dtype=torch_dtype, device=data_device)
+
+            target_t = 0 if t==0 else t // 4 + 1
+            target_h = h // 8
+            target_w = w // 8
+            values[
+                :,
+                :,
+                target_t: target_t + hidden_states_batch.shape[2],
+                target_h: target_h + hidden_states_batch.shape[3],
+                target_w: target_w + hidden_states_batch.shape[4],
+            ] += hidden_states_batch * mask
+            weight[
+                :,
+                :,
+                target_t: target_t + hidden_states_batch.shape[2],
+                target_h: target_h + hidden_states_batch.shape[3],
+                target_w: target_w + hidden_states_batch.shape[4],
+            ] += mask
+        return values / weight
+
+
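values accumulates each tile's latents weighted by its feather mask, while weight accumulates the masks themselves, so the final division is a normalized weighted average and overlapping tiles cross-fade linearly. In effect, for each output voxel v:

    # out[v] = sum_i mask_i[v] * tile_i[v] / sum_i mask_i[v]
    # regions covered by a single tile pass through unchanged (mask / mask == 1)

Dropping the first latent frame when t > 0 (and the matching target_t = t // 4 + 1 offset) mirrors the causal (T - 1) // 4 + 1 temporal mapping, where only the t == 0 tile contributes the initial frame.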
+    def encode_video(self, latents, tile_size=(65, 256, 256), tile_stride=(48, 192, 192)):
+        latents = latents.to(self.quant_conv.weight.dtype)
+        return self.tile_forward(latents, tile_size=tile_size, tile_stride=tile_stride)
+
+
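Despite the parameter name latents, encode_video receives pixel-space video; the defaults are pixel units and correspond to latent tiles via the 4x/8x compression:

    # tile_size   (65, 256, 256) -> latent ((65 - 1) // 4 + 1, 256 // 8, 256 // 8) == (17, 32, 32)
    # tile_stride (48, 192, 192) -> latent (48 // 4, 192 // 8, 192 // 8)           == (12, 24, 24)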
     @staticmethod
     def state_dict_converter():
@@ -263,8 +263,8 @@ class GeneralLoRAFromPeft:
 class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai):
     def __init__(self):
         super().__init__()
-        self.supported_model_classes = [HunyuanVideoDiT]
-        self.lora_prefix = ["diffusion_model."]
+        self.supported_model_classes = [HunyuanVideoDiT, HunyuanVideoDiT]
+        self.lora_prefix = ["diffusion_model.", "transformer."]
         self.special_keys = {}
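supported_model_classes and lora_prefix read as parallel lists, one candidate model class per key prefix, which is presumably why HunyuanVideoDiT is listed twice: LoRA checkpoints keyed with either "diffusion_model." or "transformer." can then match the same DiT. A sketch of that pairing (an assumption about the base class, not shown in this diff):

    # for model_class, prefix in zip(self.supported_model_classes, self.lora_prefix):
    #     strip `prefix` from the LoRA state-dict keys, then try matching model_class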
@@ -72,16 +72,21 @@ class HunyuanVideoPipeline(BasePipeline):
         frames = [Image.fromarray(frame) for frame in frames]
         return frames
 
-    def encode_video(self, frames):
-        # frames : (B, C, T, H, W)
-        latents = self.vae_encoder(frames)
+    def encode_video(self, frames, tile_size=(17, 30, 30), tile_stride=(12, 20, 20)):
+        tile_size = ((tile_size[0] - 1) * 4 + 1, tile_size[1] * 8, tile_size[2] * 8)
+        tile_stride = (tile_stride[0] * 4, tile_stride[1] * 8, tile_stride[2] * 8)
+        latents = self.vae_encoder.encode_video(frames, tile_size=tile_size, tile_stride=tile_stride)
         return latents
 
 
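The pipeline-level tile parameters are given in latent units and converted to the pixel units that HunyuanVideoVAEEncoder.encode_video expects:

    # tile_size   (17, 30, 30) -> ((17 - 1) * 4 + 1, 30 * 8, 30 * 8) == (65, 240, 240) pixels
    # tile_stride (12, 20, 20) -> (12 * 4, 20 * 8, 20 * 8)           == (48, 160, 160) pixels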
     @torch.no_grad()
     def __call__(
         self,
         prompt,
         negative_prompt="",
+        input_video=None,
+        denoising_strength=1.0,
         seed=None,
         height=720,
         width=1280,
@@ -94,8 +99,22 @@ class HunyuanVideoPipeline(BasePipeline):
         progress_bar_cmd=lambda x: x,
         progress_bar_st=None,
     ):
+        # Tiler parameters
+        tiler_kwargs = {"tile_size": tile_size, "tile_stride": tile_stride}
+
+        # Scheduler
+        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
+
         # Initialize noise
-        latents = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
+        noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
+        if input_video is not None:
+            self.load_models_to_device(['vae_encoder'])
+            input_video = self.preprocess_images(input_video)
+            input_video = torch.stack(input_video, dim=2)
+            latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
+            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
+        else:
+            latents = noise
 
         # Encode prompts
         self.load_models_to_device(["text_encoder_1"] if self.vram_management else ["text_encoder_1", "text_encoder_2"])
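This block is the core of the v2v support: set_timesteps now takes denoising_strength, which presumably truncates the schedule, and the encoded input video is noised to the first remaining timestep before denoising begins. In a rectified-flow scheduler, add_noise is typically a linear interpolation; a hedged sketch (not necessarily this repo's exact implementation):

    def add_noise(latents, noise, timestep, num_train_timesteps=1000):
        sigma = timestep / num_train_timesteps  # assumption: timestep normalized to [0, 1]
        return (1 - sigma) * latents + sigma * noise

With denoising_strength=1.0 the schedule is untruncated and the latents are effectively pure noise (text-to-video behavior); smaller values start later in the schedule and keep more of the input video's structure.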
@@ -106,9 +125,6 @@ class HunyuanVideoPipeline(BasePipeline):
         # Extra input
         extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
 
-        # Scheduler
-        self.scheduler.set_timesteps(num_inference_steps)
-
         # Denoise
         self.load_models_to_device([] if self.vram_management else ["dit"])
         for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
@@ -126,9 +142,6 @@ class HunyuanVideoPipeline(BasePipeline):
 
             # Scheduler
             latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
 
-        # Tiler parameters
-        tiler_kwargs = {"tile_size": tile_size, "tile_stride": tile_stride}
-
         # Decode
         self.load_models_to_device(['vae_decoder'])
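Put together, a minimal v2v invocation would look roughly like the following; pipe is an initialized HunyuanVideoPipeline, input_frames is a list of PIL images, and everything beyond the parameter names visible in this diff is assumed:

    video = pipe(
        prompt="a cat surfing a wave, cinematic",
        input_video=input_frames,         # enables v2v; omit for text-to-video
        denoising_strength=0.7,           # < 1.0 preserves the input's structure
        num_frames=len(input_frames),     # assumption: should match the input length
        seed=0,
        height=720,
        width=1280,
    )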