mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
@@ -44,6 +44,7 @@ from ..models.cog_dit import CogDiT
|
||||
from ..models.omnigen import OmniGenTransformer
|
||||
|
||||
from ..models.hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
|
||||
from ..models.hunyuan_video_vae_encoder import HunyuanVideoVAEEncoder
|
||||
|
||||
from ..extensions.RIFE import IFNet
|
||||
from ..extensions.ESRGAN import RRDBNet
|
||||
@@ -97,7 +98,7 @@ model_loader_configs = [
|
||||
(None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
|
||||
(None, "77ff18050dbc23f50382e45d51a779fe", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
||||
(None, "5da81baee73198a7c19e6d2fe8b5148e", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
|
||||
(None, "aeb82dce778a03dcb4d726cb03f3c43f", ["hunyuan_video_vae_decoder"], [HunyuanVideoVAEDecoder], "diffusers"),
|
||||
(None, "aeb82dce778a03dcb4d726cb03f3c43f", ["hunyuan_video_vae_decoder", "hunyuan_video_vae_encoder"], [HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder], "diffusers"),
|
||||
(None, "b9588f02e78f5ccafc9d7c0294e46308", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
|
||||
]
|
||||
huggingface_model_loader_configs = [
|
||||
|
||||
217
diffsynth/models/hunyuan_video_vae_encoder.py
Normal file
217
diffsynth/models/hunyuan_video_vae_encoder.py
Normal file
@@ -0,0 +1,217 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
from .hunyuan_video_vae_decoder import CausalConv3d, ResnetBlockCausal3D, UNetMidBlockCausal3D
|
||||
|
||||
|
||||
class DownsampleCausal3D(nn.Module):
|
||||
|
||||
def __init__(self, channels, out_channels, kernel_size=3, bias=True, stride=2):
|
||||
super().__init__()
|
||||
self.conv = CausalConv3d(channels, out_channels, kernel_size, stride=stride, bias=bias)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.conv(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DownEncoderBlockCausal3D(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
dropout=0.0,
|
||||
num_layers=1,
|
||||
eps=1e-6,
|
||||
num_groups=32,
|
||||
add_downsample=True,
|
||||
downsample_stride=2,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
resnets = []
|
||||
for i in range(num_layers):
|
||||
cur_in_channel = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
ResnetBlockCausal3D(
|
||||
in_channels=cur_in_channel,
|
||||
out_channels=out_channels,
|
||||
groups=num_groups,
|
||||
dropout=dropout,
|
||||
eps=eps,
|
||||
))
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
self.downsamplers = None
|
||||
if add_downsample:
|
||||
self.downsamplers = nn.ModuleList([DownsampleCausal3D(
|
||||
out_channels,
|
||||
out_channels,
|
||||
stride=downsample_stride,
|
||||
)])
|
||||
|
||||
def forward(self, hidden_states):
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class EncoderCausal3D(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 16,
|
||||
eps=1e-6,
|
||||
dropout=0.0,
|
||||
block_out_channels=[128, 256, 512, 512],
|
||||
layers_per_block=2,
|
||||
num_groups=32,
|
||||
time_compression_ratio: int = 4,
|
||||
spatial_compression_ratio: int = 8,
|
||||
gradient_checkpointing=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
for i in range(len(block_out_channels)):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
|
||||
num_time_downsample_layers = int(np.log2(time_compression_ratio))
|
||||
|
||||
add_spatial_downsample = bool(i < num_spatial_downsample_layers)
|
||||
add_time_downsample = bool(i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block)
|
||||
|
||||
downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
|
||||
downsample_stride_T = (2,) if add_time_downsample else (1,)
|
||||
downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
|
||||
down_block = DownEncoderBlockCausal3D(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
dropout=dropout,
|
||||
num_layers=layers_per_block,
|
||||
eps=eps,
|
||||
num_groups=num_groups,
|
||||
add_downsample=bool(add_spatial_downsample or add_time_downsample),
|
||||
downsample_stride=downsample_stride,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlockCausal3D(
|
||||
in_channels=block_out_channels[-1],
|
||||
dropout=dropout,
|
||||
eps=eps,
|
||||
num_groups=num_groups,
|
||||
attention_head_dim=block_out_channels[-1],
|
||||
)
|
||||
# out
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups, eps=eps)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3)
|
||||
|
||||
self.gradient_checkpointing = gradient_checkpointing
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
# down
|
||||
for down_block in self.down_blocks:
|
||||
torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(down_block),
|
||||
hidden_states,
|
||||
use_reentrant=False,
|
||||
)
|
||||
# middle
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block),
|
||||
hidden_states,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
# down
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = down_block(hidden_states)
|
||||
# middle
|
||||
hidden_states = self.mid_block(hidden_states)
|
||||
# post-process
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoVAEEncoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=3,
|
||||
out_channels=16,
|
||||
eps=1e-6,
|
||||
dropout=0.0,
|
||||
block_out_channels=[128, 256, 512, 512],
|
||||
layers_per_block=2,
|
||||
num_groups=32,
|
||||
time_compression_ratio=4,
|
||||
spatial_compression_ratio=8,
|
||||
gradient_checkpointing=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = EncoderCausal3D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
eps=eps,
|
||||
dropout=dropout,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
num_groups=num_groups,
|
||||
time_compression_ratio=time_compression_ratio,
|
||||
spatial_compression_ratio=spatial_compression_ratio,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
)
|
||||
self.quant_conv = nn.Conv3d(2 * out_channels, 2 * out_channels, kernel_size=1)
|
||||
|
||||
def forward(self, images):
|
||||
latents = self.encoder(images)
|
||||
latents = self.quant_conv(latents)
|
||||
# latents: (B C T H W)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return HunyuanVideoVAEEncoderStateDictConverter()
|
||||
|
||||
|
||||
class HunyuanVideoVAEEncoderStateDictConverter:
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name.startswith('encoder.') or name.startswith('quant_conv.'):
|
||||
state_dict_[name] = state_dict[name]
|
||||
return state_dict_
|
||||
@@ -36,6 +36,7 @@ from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
||||
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
||||
from .hunyuan_dit import HunyuanDiT
|
||||
from .hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
|
||||
from .hunyuan_video_vae_encoder import HunyuanVideoVAEEncoder
|
||||
|
||||
from .flux_dit import FluxDiT
|
||||
from .flux_text_encoder import FluxTextEncoder2
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from ..models import ModelManager, SD3TextEncoder1, HunyuanVideoVAEDecoder
|
||||
from ..models import ModelManager, SD3TextEncoder1, HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder
|
||||
from ..models.hunyuan_video_dit import HunyuanVideoDiT
|
||||
from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder
|
||||
from ..schedulers.flow_match import FlowMatchScheduler
|
||||
@@ -21,7 +21,8 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
self.text_encoder_2: HunyuanVideoLLMEncoder = None
|
||||
self.dit: HunyuanVideoDiT = None
|
||||
self.vae_decoder: HunyuanVideoVAEDecoder = None
|
||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder']
|
||||
self.vae_encoder: HunyuanVideoVAEEncoder = None
|
||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder']
|
||||
self.vram_management = False
|
||||
|
||||
|
||||
@@ -37,6 +38,7 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
self.text_encoder_2 = model_manager.fetch_model("hunyuan_video_text_encoder_2")
|
||||
self.dit = model_manager.fetch_model("hunyuan_video_dit")
|
||||
self.vae_decoder = model_manager.fetch_model("hunyuan_video_vae_decoder")
|
||||
self.vae_encoder = model_manager.fetch_model("hunyuan_video_vae_encoder")
|
||||
self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
|
||||
|
||||
|
||||
@@ -70,6 +72,10 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
frames = [Image.fromarray(frame) for frame in frames]
|
||||
return frames
|
||||
|
||||
def encode_video(self, frames):
|
||||
# frames : (B, C, T, H, W)
|
||||
latents = self.vae_encoder(frames)
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
|
||||
Reference in New Issue
Block a user