mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 15:48:20 +00:00
hunyuanvideo_vae_encoder
This commit is contained in:
@@ -125,7 +125,7 @@ class EncoderCausal3D(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = gradient_checkpointing
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
@@ -193,15 +193,11 @@ class HunyuanVideoVAEEncoder(nn.Module):
|
||||
)
|
||||
self.quant_conv = nn.Conv3d(2 * out_channels, 2 * out_channels, kernel_size=1)
|
||||
|
||||
def encode_video(self, latents, use_temporal_tiling=False, use_spatial_tiling=False, sample_ssize=256, sample_tsize=64):
    """Encode a full video tensor in one pass (no tiling).

    Both tiled code paths are unimplemented: requesting temporal or
    spatial tiling raises NotImplementedError.  `sample_ssize` and
    `sample_tsize` are reserved for those paths and are currently unused.

    NOTE(review): this method feeds `latents` through ``self.decoder``
    even though the enclosing module appears to be an encoder — confirm
    the attribute name against the class definition before relying on it.

    Returns the output of ``self.quant_conv`` applied to the decoded
    tensor.
    """
    # Tiling is not supported yet; either flag aborts identically.
    if use_temporal_tiling or use_spatial_tiling:
        raise NotImplementedError
    # Single-shot path: run the module, then the 1x1 quant conv.
    return self.quant_conv(self.decoder(latents))
|
||||
def forward(self, images):
    """Encode ``images`` and project through the 1x1 quant conv.

    Output layout per the original annotation: (B, C, T, H, W).
    """
    encoded = self.encoder(images)
    # latents: (B C T H W)
    return self.quant_conv(encoded)
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
|
||||
Reference in New Issue
Block a user