hunyuanvideo_vae_encoder

This commit is contained in:
mi804
2024-12-18 19:03:04 +08:00
parent 8e06cac0df
commit 3f410b0b77
2 changed files with 13 additions and 12 deletions

View File

@@ -125,7 +125,7 @@ class EncoderCausal3D(nn.Module):
self.gradient_checkpointing = gradient_checkpointing self.gradient_checkpointing = gradient_checkpointing
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: def forward(self, hidden_states):
hidden_states = self.conv_in(hidden_states) hidden_states = self.conv_in(hidden_states)
if self.training and self.gradient_checkpointing: if self.training and self.gradient_checkpointing:
@@ -193,15 +193,11 @@ class HunyuanVideoVAEEncoder(nn.Module):
) )
self.quant_conv = nn.Conv3d(2 * out_channels, 2 * out_channels, kernel_size=1) self.quant_conv = nn.Conv3d(2 * out_channels, 2 * out_channels, kernel_size=1)
def encode_video(self, latents, use_temporal_tiling=False, use_spatial_tiling=False, sample_ssize=256, sample_tsize=64): def forward(self, images):
if use_temporal_tiling: latents = self.encoder(images)
raise NotImplementedError latents = self.quant_conv(latents)
if use_spatial_tiling: # latents: (B C T H W)
raise NotImplementedError return latents
# no tiling
latents = self.decoder(latents)
dec = self.quant_conv(latents)
return dec
@staticmethod @staticmethod
def state_dict_converter(): def state_dict_converter():

View File

@@ -1,4 +1,4 @@
from ..models import ModelManager, SD3TextEncoder1, HunyuanVideoVAEDecoder from ..models import ModelManager, SD3TextEncoder1, HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder
from ..models.hunyuan_video_dit import HunyuanVideoDiT from ..models.hunyuan_video_dit import HunyuanVideoDiT
from ..schedulers.flow_match import FlowMatchScheduler from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline from .base import BasePipeline
@@ -21,7 +21,8 @@ class HunyuanVideoPipeline(BasePipeline):
self.text_encoder_2: LlamaModel = None self.text_encoder_2: LlamaModel = None
self.dit: HunyuanVideoDiT = None self.dit: HunyuanVideoDiT = None
self.vae_decoder: HunyuanVideoVAEDecoder = None self.vae_decoder: HunyuanVideoVAEDecoder = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder'] self.vae_encoder: HunyuanVideoVAEEncoder = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder']
self.vram_management = False self.vram_management = False
@@ -70,6 +71,10 @@ class HunyuanVideoPipeline(BasePipeline):
frames = [Image.fromarray(frame) for frame in frames] frames = [Image.fromarray(frame) for frame in frames]
return frames return frames
def encode_video(self, frames):
# frames : (B, C, T, H, W)
latents = self.vae_encoder(frames)
return latents
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(