mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
308 lines
11 KiB
Python
308 lines
11 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from einops import rearrange, repeat
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
from .hunyuan_video_vae_decoder import CausalConv3d, ResnetBlockCausal3D, UNetMidBlockCausal3D
|
|
|
|
|
|
class DownsampleCausal3D(nn.Module):
    """Downsampling stage implemented as a single strided ``CausalConv3d``.

    Args:
        channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size, forwarded to ``CausalConv3d``.
        bias: whether the convolution carries a bias term.
        stride: convolution stride forwarded to ``CausalConv3d`` (an int, or a
            per-axis tuple as built by ``EncoderCausal3D``).
    """

    def __init__(self, channels, out_channels, kernel_size=3, bias=True, stride=2):
        super().__init__()
        self.conv = CausalConv3d(
            channels,
            out_channels,
            kernel_size,
            stride=stride,
            bias=bias,
        )

    def forward(self, hidden_states):
        """Run the strided causal convolution and return its output."""
        return self.conv(hidden_states)
|
|
|
|
|
|
class DownEncoderBlockCausal3D(nn.Module):
    """Encoder stage: a stack of causal 3D ResNet blocks plus an optional downsampler.

    Args:
        in_channels: channel count consumed by the first ResNet block.
        out_channels: channel count produced by every ResNet block and by the
            downsampler.
        dropout: dropout probability forwarded to each ResNet block.
        num_layers: number of ResNet blocks in the stack.
        eps: epsilon forwarded to each ResNet block.
        num_groups: forwarded to each ResNet block as ``groups``.
        add_downsample: when True, a single ``DownsampleCausal3D`` is appended.
        downsample_stride: stride forwarded to that downsampler.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        dropout=0.0,
        num_layers=1,
        eps=1e-6,
        num_groups=32,
        add_downsample=True,
        downsample_stride=2,
    ):
        super().__init__()
        # Only the first ResNet changes the channel count; later ones map out -> out.
        self.resnets = nn.ModuleList([
            ResnetBlockCausal3D(
                in_channels=(out_channels if layer_idx else in_channels),
                out_channels=out_channels,
                groups=num_groups,
                dropout=dropout,
                eps=eps,
            )
            for layer_idx in range(num_layers)
        ])

        # The downsampler sits inside a one-element ModuleList, so its parameter
        # names are ``downsamplers.0.*``. This is presumably to match pretrained
        # checkpoint keys (confirm before changing).
        if add_downsample:
            self.downsamplers = nn.ModuleList([
                DownsampleCausal3D(out_channels, out_channels, stride=downsample_stride),
            ])
        else:
            self.downsamplers = None

    def forward(self, hidden_states):
        """Apply every ResNet block in order, then the downsampler(s) if present."""
        out = hidden_states
        for block in self.resnets:
            out = block(out)

        if self.downsamplers is None:
            return out

        for down in self.downsamplers:
            out = down(out)
        return out
|
|
|
|
|
|
class EncoderCausal3D(nn.Module):
    """Causal 3D convolutional encoder.

    Pipeline: ``conv_in`` -> one ``DownEncoderBlockCausal3D`` per entry of
    ``block_out_channels`` -> ``UNetMidBlockCausal3D`` -> GroupNorm -> SiLU ->
    ``conv_out``. The output has ``2 * out_channels`` channels.

    Args:
        in_channels: channels of the input tensor.
        out_channels: half the channel count produced by ``conv_out``.
        eps: epsilon for the normalization layers.
        dropout: dropout probability forwarded to the sub-blocks.
        block_out_channels: output channel count of each down block, in order.
        layers_per_block: ResNet blocks per down block.
        num_groups: GroupNorm group count.
        time_compression_ratio: ``log2`` of it is the number of down blocks
            whose downsampler uses temporal stride 2 (assumed a power of two).
        spatial_compression_ratio: ``log2`` of it is the number of leading down
            blocks whose downsampler uses spatial stride 2 (assumed a power of two).
        gradient_checkpointing: when True and the module is in training mode,
            down blocks and the mid block run under activation checkpointing.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 16,
        eps=1e-6,
        dropout=0.0,
        block_out_channels=(128, 256, 512, 512),
        layers_per_block=2,
        num_groups=32,
        time_compression_ratio: int = 4,
        spatial_compression_ratio: int = 8,
        gradient_checkpointing=False,
    ):
        super().__init__()
        self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
        self.down_blocks = nn.ModuleList([])

        # Loop invariants: how many stages halve H/W, and how many halve T.
        num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
        num_time_downsample_layers = int(np.log2(time_compression_ratio))
        num_blocks = len(block_out_channels)

        # down
        output_channel = block_out_channels[0]
        for i in range(num_blocks):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == num_blocks - 1

            # Spatial downsampling is applied in the leading stages; temporal
            # downsampling in the trailing stages, never in the final one.
            add_spatial_downsample = i < num_spatial_downsample_layers
            add_time_downsample = (
                i >= num_blocks - 1 - num_time_downsample_layers and not is_final_block
            )

            spatial_stride = 2 if add_spatial_downsample else 1
            # Stride order is (T, H, W).
            downsample_stride = (2 if add_time_downsample else 1, spatial_stride, spatial_stride)

            down_block = DownEncoderBlockCausal3D(
                in_channels=input_channel,
                out_channels=output_channel,
                dropout=dropout,
                num_layers=layers_per_block,
                eps=eps,
                num_groups=num_groups,
                add_downsample=bool(add_spatial_downsample or add_time_downsample),
                downsample_stride=downsample_stride,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlockCausal3D(
            in_channels=block_out_channels[-1],
            dropout=dropout,
            eps=eps,
            num_groups=num_groups,
            attention_head_dim=block_out_channels[-1],
        )
        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups, eps=eps)
        self.conv_act = nn.SiLU()
        self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3)

        self.gradient_checkpointing = gradient_checkpointing

    def forward(self, hidden_states):
        """Encode ``hidden_states``; returns a tensor with ``2 * out_channels`` channels."""
        hidden_states = self.conv_in(hidden_states)

        # One runner for both modes, so the down/mid sequence is written once and
        # every stage's output is always fed into the next stage.
        if self.training and self.gradient_checkpointing:
            # Explicit import: ``torch.utils.checkpoint`` is a submodule that is not
            # guaranteed to be loaded by a bare ``import torch``.
            from torch.utils.checkpoint import checkpoint

            def run(module, x):
                return checkpoint(module, x, use_reentrant=False)
        else:
            def run(module, x):
                return module(x)

        # down
        for down_block in self.down_blocks:
            hidden_states = run(down_block, hidden_states)
        # middle
        hidden_states = run(self.mid_block, hidden_states)

        # post-process
        hidden_states = self.conv_norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        return hidden_states
|
|
|
|
|
|
class HunyuanVideoVAEEncoder(nn.Module):
    """HunyuanVideo VAE encoder: ``EncoderCausal3D`` + 1x1x1 ``quant_conv`` + scaling.

    ``forward`` keeps the first ``out_channels`` of the ``2 * out_channels``
    quant-conv output channels and multiplies them by ``scaling_factor``. Those
    channels are presumably the posterior mean, with the dropped half presumably
    the log-variance; TODO confirm. ``encode_video`` runs ``forward`` over
    overlapping spatio-temporal tiles and blends the results.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=16,
        eps=1e-6,
        dropout=0.0,
        block_out_channels=(128, 256, 512, 512),
        layers_per_block=2,
        num_groups=32,
        time_compression_ratio=4,
        spatial_compression_ratio=8,
        gradient_checkpointing=False,
    ):
        super().__init__()
        self.encoder = EncoderCausal3D(
            in_channels=in_channels,
            out_channels=out_channels,
            eps=eps,
            dropout=dropout,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            num_groups=num_groups,
            time_compression_ratio=time_compression_ratio,
            spatial_compression_ratio=spatial_compression_ratio,
            gradient_checkpointing=gradient_checkpointing,
        )
        self.quant_conv = nn.Conv3d(2 * out_channels, 2 * out_channels, kernel_size=1)
        self.scaling_factor = 0.476986
        # Stored so forward/tile_forward follow the constructor arguments instead
        # of hard-coding 16 channels, 4x temporal and 8x spatial compression.
        self.out_channels = out_channels
        self.time_compression_ratio = time_compression_ratio
        self.spatial_compression_ratio = spatial_compression_ratio

    def forward(self, images):
        """Encode ``images`` (B, C, T, H, W) into scaled latents with ``out_channels`` channels."""
        latents = self.encoder(images)
        latents = self.quant_conv(latents)
        latents = latents[:, :self.out_channels]
        latents = latents * self.scaling_factor
        return latents

    def build_1d_mask(self, length, left_bound, right_bound, border_width):
        """Return a length-``length`` blending ramp along one axis.

        Values are 1 everywhere except on the non-boundary edges, where they ramp
        linearly over ``border_width`` entries: up from ``1/border_width`` to 1 on
        the left, and back down on the right. A non-positive ``border_width``
        (tiles with no overlap) yields all ones.
        """
        x = torch.ones((length,))
        if border_width <= 0:
            # Without this guard, x[-0:] would address the whole tensor and the
            # empty ramp could not be assigned to it.
            return x
        ramp = (torch.arange(border_width) + 1) / border_width
        if not left_bound:
            x[:border_width] = ramp
        if not right_bound:
            x[-border_width:] = torch.flip(ramp, dims=(0,))
        return x

    def build_mask(self, data, is_bound, border_width):
        """Build a (1, 1, T, H, W) blending mask matching ``data``'s last three dims.

        ``is_bound`` holds (t_start, t_end, h_start, h_end, w_start, w_end) flags;
        ``border_width`` holds the per-axis ramp widths. The mask is the
        element-wise minimum of the three 1D ramps.
        """
        _, _, T, H, W = data.shape
        t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
        h = self.build_1d_mask(H, is_bound[2], is_bound[3], border_width[1])
        w = self.build_1d_mask(W, is_bound[4], is_bound[5], border_width[2])

        t = repeat(t, "T -> T H W", T=T, H=H, W=W)
        h = repeat(h, "H -> T H W", T=T, H=H, W=W)
        w = repeat(w, "W -> T H W", T=T, H=H, W=W)

        mask = torch.stack([t, h, w]).min(dim=0).values
        mask = rearrange(mask, "T H W -> 1 1 T H W")
        return mask

    def tile_forward(self, hidden_states, tile_size, tile_stride):
        """Encode (B, C, T, H, W) input tile by tile and blend the latent tiles.

        Args:
            hidden_states: input video; stays on its device, and each tile is
                moved to the model's device for encoding.
            tile_size: (t, h, w) tile extent in input frames/pixels.
            tile_stride: (t, h, w) step between tile starts.

        Returns:
            Blended latents of shape
            (B, out_channels, (T - 1) // rt + 1, H // rs, W // rs), where rt/rs are
            the time/spatial compression ratios.

        NOTE(review): alignment assumes strides and tile starts are multiples of
        the compression ratios; TODO confirm for non-default settings.
        """
        B, C, T, H, W = hidden_states.shape
        size_t, size_h, size_w = tile_size
        stride_t, stride_h, stride_w = tile_stride
        ratio_t = self.time_compression_ratio
        ratio_s = self.spatial_compression_ratio

        # Split tasks: skip a start if the previous tile already reaches the end.
        tasks = []
        for t in range(0, T, stride_t):
            if (t - stride_t >= 0 and t - stride_t + size_t >= T): continue
            for h in range(0, H, stride_h):
                if (h - stride_h >= 0 and h - stride_h + size_h >= H): continue
                for w in range(0, W, stride_w):
                    if (w - stride_w >= 0 and w - stride_w + size_w >= W): continue
                    t_, h_, w_ = t + size_t, h + size_h, w + size_w
                    tasks.append((t, t_, h, h_, w, w_))

        # Run
        torch_dtype = self.quant_conv.weight.dtype
        data_device = hidden_states.device
        computation_device = self.quant_conv.weight.device

        latent_shape = ((T - 1) // ratio_t + 1, H // ratio_s, W // ratio_s)
        weight = torch.zeros((1, 1, *latent_shape), dtype=torch_dtype, device=data_device)
        values = torch.zeros((B, self.out_channels, *latent_shape), dtype=torch_dtype, device=data_device)

        # Overlap between neighbouring tiles, expressed in latent units (loop invariant).
        border_width = (
            (size_t - stride_t) // ratio_t,
            (size_h - stride_h) // ratio_s,
            (size_w - stride_w) // ratio_s,
        )

        for t, t_, h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
            hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device)
            hidden_states_batch = self.forward(hidden_states_batch).to(data_device)
            if t > 0:
                # Drop this tile's first latent frame, presumably because a causal
                # tile encodes its first frame on its own; the earlier tile
                # provides that position instead.
                hidden_states_batch = hidden_states_batch[:, :, 1:]

            mask = self.build_mask(
                hidden_states_batch,
                is_bound=(t == 0, t_ >= T, h == 0, h_ >= H, w == 0, w_ >= W),
                border_width=border_width,
            ).to(dtype=torch_dtype, device=data_device)

            target_t = 0 if t == 0 else t // ratio_t + 1
            target_h = h // ratio_s
            target_w = w // ratio_s
            region = (
                slice(None),
                slice(None),
                slice(target_t, target_t + hidden_states_batch.shape[2]),
                slice(target_h, target_h + hidden_states_batch.shape[3]),
                slice(target_w, target_w + hidden_states_batch.shape[4]),
            )
            values[region] += hidden_states_batch * mask
            weight[region] += mask
        return values / weight

    def encode_video(self, latents, tile_size=(65, 256, 256), tile_stride=(48, 192, 192)):
        """Tile-encode a video.

        ``latents`` is the input video (B, C, T, H, W), despite its name. It is
        cast to the model dtype before being passed to ``tile_forward``.
        """
        latents = latents.to(self.quant_conv.weight.dtype)
        return self.tile_forward(latents, tile_size=tile_size, tile_stride=tile_stride)

    @staticmethod
    def state_dict_converter():
        """Return the converter that filters checkpoints for this module."""
        return HunyuanVideoVAEEncoderStateDictConverter()
|
|
|
|
|
|
class HunyuanVideoVAEEncoderStateDictConverter:
    """Filter a full VAE state dict down to the entries the encoder module owns."""

    def __init__(self):
        # Stateless: nothing to configure.
        pass

    def from_diffusers(self, state_dict):
        """Return a new dict with only the ``encoder.*`` and ``quant_conv.*`` entries.

        Keys and values are copied unchanged, in ``state_dict``'s iteration order.
        """
        kept_prefixes = ('encoder.', 'quant_conv.')
        return {key: state_dict[key] for key in state_dict if key.startswith(kept_prefixes)}
|