mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 07:18:14 +00:00
SVD VAE
This commit is contained in:
@@ -2,6 +2,7 @@ import torch
|
||||
from .sd_unet import ResnetBlock, DownSampler
|
||||
from .sd_vae_decoder import VAEAttentionBlock
|
||||
from .tiler import TileWorker
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class SDVAEEncoder(torch.nn.Module):
|
||||
@@ -73,6 +74,23 @@ class SDVAEEncoder(torch.nn.Module):
|
||||
|
||||
return hidden_states
|
||||
|
||||
def encode_video(self, sample, batch_size=8):
    """Encode a video tensor in chunks of frames.

    The input is sliced along its third axis (T in the "B C T H W" layout
    used below), each slice is folded into the batch axis, passed through
    this module's forward call, and unfolded back to "B C T H W". The
    per-chunk results are then joined along the T axis.

    Args:
        sample: Tensor indexed as (B, C, T, H, W).
        batch_size: Maximum number of frames per video taken in one chunk;
            the forward call therefore sees up to ``B * batch_size`` images.

    Returns:
        The encoded chunks concatenated along dim 2, laid out as
        (B, C, T, H, W).
    """
    num_videos, num_frames = sample.shape[0], sample.shape[2]
    encoded_chunks = []

    for start in range(0, num_frames, batch_size):
        # Clamp so the final chunk stops at the last frame.
        stop = min(start + batch_size, num_frames)

        # Fold the frame axis into the batch axis so the 2D encoder can run.
        frames = rearrange(sample[:, :, start:stop], "B C T H W -> (B T) C H W")

        # Encode, then restore the separate video and frame axes.
        latents = rearrange(self(frames), "(B T) C H W -> B C T H W", B=num_videos)
        encoded_chunks.append(latents)

    return torch.concat(encoded_chunks, dim=2)
|
||||
|
||||
def state_dict_converter(self):
    """Return a new SDVAEEncoderStateDictConverter instance."""
    converter = SDVAEEncoderStateDictConverter()
    return converter
|
||||
|
||||
|
||||
Reference in New Issue
Block a user