mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 18:28:10 +00:00
hunyuanvideo_vae_encoder
This commit is contained in:
@@ -125,7 +125,7 @@ class EncoderCausal3D(nn.Module):
|
|||||||
|
|
||||||
self.gradient_checkpointing = gradient_checkpointing
|
self.gradient_checkpointing = gradient_checkpointing
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
def forward(self, hidden_states):
|
||||||
hidden_states = self.conv_in(hidden_states)
|
hidden_states = self.conv_in(hidden_states)
|
||||||
if self.training and self.gradient_checkpointing:
|
if self.training and self.gradient_checkpointing:
|
||||||
|
|
||||||
@@ -193,15 +193,11 @@ class HunyuanVideoVAEEncoder(nn.Module):
|
|||||||
)
|
)
|
||||||
self.quant_conv = nn.Conv3d(2 * out_channels, 2 * out_channels, kernel_size=1)
|
self.quant_conv = nn.Conv3d(2 * out_channels, 2 * out_channels, kernel_size=1)
|
||||||
|
|
||||||
def encode_video(self, latents, use_temporal_tiling=False, use_spatial_tiling=False, sample_ssize=256, sample_tsize=64):
|
def forward(self, images):
|
||||||
if use_temporal_tiling:
|
latents = self.encoder(images)
|
||||||
raise NotImplementedError
|
latents = self.quant_conv(latents)
|
||||||
if use_spatial_tiling:
|
# latents: (B C T H W)
|
||||||
raise NotImplementedError
|
return latents
|
||||||
# no tiling
|
|
||||||
latents = self.decoder(latents)
|
|
||||||
dec = self.quant_conv(latents)
|
|
||||||
return dec
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def state_dict_converter():
|
def state_dict_converter():
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from ..models import ModelManager, SD3TextEncoder1, HunyuanVideoVAEDecoder
|
from ..models import ModelManager, SD3TextEncoder1, HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder
|
||||||
from ..models.hunyuan_video_dit import HunyuanVideoDiT
|
from ..models.hunyuan_video_dit import HunyuanVideoDiT
|
||||||
from ..schedulers.flow_match import FlowMatchScheduler
|
from ..schedulers.flow_match import FlowMatchScheduler
|
||||||
from .base import BasePipeline
|
from .base import BasePipeline
|
||||||
@@ -21,7 +21,8 @@ class HunyuanVideoPipeline(BasePipeline):
|
|||||||
self.text_encoder_2: LlamaModel = None
|
self.text_encoder_2: LlamaModel = None
|
||||||
self.dit: HunyuanVideoDiT = None
|
self.dit: HunyuanVideoDiT = None
|
||||||
self.vae_decoder: HunyuanVideoVAEDecoder = None
|
self.vae_decoder: HunyuanVideoVAEDecoder = None
|
||||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder']
|
self.vae_encoder: HunyuanVideoVAEEncoder = None
|
||||||
|
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder']
|
||||||
self.vram_management = False
|
self.vram_management = False
|
||||||
|
|
||||||
|
|
||||||
@@ -70,6 +71,10 @@ class HunyuanVideoPipeline(BasePipeline):
|
|||||||
frames = [Image.fromarray(frame) for frame in frames]
|
frames = [Image.fromarray(frame) for frame in frames]
|
||||||
return frames
|
return frames
|
||||||
|
|
||||||
|
def encode_video(self, frames):
|
||||||
|
# frames : (B, C, T, H, W)
|
||||||
|
latents = self.vae_encoder(frames)
|
||||||
|
return latents
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def __call__(
|
def __call__(
|
||||||
|
|||||||
Reference in New Issue
Block a user