From 3f410b0b77474bfb39ce57345cd789315b1bceab Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Wed, 18 Dec 2024 19:03:04 +0800 Subject: [PATCH] hunyuanvideo_vae_encoder --- diffsynth/models/hunyuan_video_vae_encoder.py | 16 ++++++---------- diffsynth/pipelines/hunyuan_video.py | 9 +++++++-- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/diffsynth/models/hunyuan_video_vae_encoder.py b/diffsynth/models/hunyuan_video_vae_encoder.py index 4221186..ec7fd14 100644 --- a/diffsynth/models/hunyuan_video_vae_encoder.py +++ b/diffsynth/models/hunyuan_video_vae_encoder.py @@ -125,7 +125,7 @@ class EncoderCausal3D(nn.Module): self.gradient_checkpointing = gradient_checkpointing - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + def forward(self, hidden_states): hidden_states = self.conv_in(hidden_states) if self.training and self.gradient_checkpointing: @@ -193,15 +193,11 @@ class HunyuanVideoVAEEncoder(nn.Module): ) self.quant_conv = nn.Conv3d(2 * out_channels, 2 * out_channels, kernel_size=1) - def encode_video(self, latents, use_temporal_tiling=False, use_spatial_tiling=False, sample_ssize=256, sample_tsize=64): - if use_temporal_tiling: - raise NotImplementedError - if use_spatial_tiling: - raise NotImplementedError - # no tiling - latents = self.decoder(latents) - dec = self.quant_conv(latents) - return dec + def forward(self, images): + latents = self.encoder(images) + latents = self.quant_conv(latents) + # latents: (B C T H W) + return latents @staticmethod def state_dict_converter(): diff --git a/diffsynth/pipelines/hunyuan_video.py b/diffsynth/pipelines/hunyuan_video.py index 0df8942..bf65606 100644 --- a/diffsynth/pipelines/hunyuan_video.py +++ b/diffsynth/pipelines/hunyuan_video.py @@ -1,4 +1,4 @@ -from ..models import ModelManager, SD3TextEncoder1, HunyuanVideoVAEDecoder +from ..models import ModelManager, SD3TextEncoder1, HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder from ..models.hunyuan_video_dit import HunyuanVideoDiT from ..schedulers.flow_match import FlowMatchScheduler from .base import BasePipeline @@ -21,7 +21,8 @@ class HunyuanVideoPipeline(BasePipeline): self.text_encoder_2: LlamaModel = None self.dit: HunyuanVideoDiT = None self.vae_decoder: HunyuanVideoVAEDecoder = None - self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder'] + self.vae_encoder: HunyuanVideoVAEEncoder = None + self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder'] self.vram_management = False @@ -70,6 +71,10 @@ class HunyuanVideoPipeline(BasePipeline): frames = [Image.fromarray(frame) for frame in frames] return frames + def encode_video(self, frames): + # frames : (B, C, T, H, W) + latents = self.vae_encoder(frames) + return latents @torch.no_grad() def __call__(