This commit is contained in:
Artiprocher
2024-03-13 12:12:51 +08:00
parent 79fb9fe6c4
commit 7fc384fb7f
2 changed files with 258 additions and 0 deletions

View File

@@ -2,6 +2,7 @@ import torch
from .sd_unet import ResnetBlock, DownSampler
from .sd_vae_decoder import VAEAttentionBlock
from .tiler import TileWorker
from einops import rearrange
class SDVAEEncoder(torch.nn.Module):
@@ -73,6 +74,23 @@ class SDVAEEncoder(torch.nn.Module):
return hidden_states
def encode_video(self, sample, batch_size=8):
    """Encode a video tensor by running the encoder on chunks of frames.

    The input is indexed as ``B C T H W`` (5-D; presumably batch, channel,
    time, height, width). Axis 2 is walked in windows of at most
    ``batch_size`` entries; each window is folded into the leading axis,
    passed through ``self(...)``, unfolded back to ``B C T H W`` and the
    per-window results are joined along axis 2.

    Args:
        sample: 5-D tensor laid out as ``B C T H W``.
        batch_size: Maximum number of axis-2 entries encoded per call.

    Returns:
        The encoded windows concatenated along axis 2.
    """
    num_videos, num_frames = sample.shape[0], sample.shape[2]

    def encode_window(start):
        # Clamp the window end so the final chunk may be shorter.
        stop = min(start + batch_size, num_frames)
        # Fold the frame axis into the batch axis for the 4-D encoder call.
        frames = rearrange(sample[:, :, start:stop], "B C T H W -> (B T) C H W")
        latents = self(frames)
        # Restore the separate frame axis using the known leading size.
        return rearrange(latents, "(B T) C H W -> B C T H W", B=num_videos)

    encoded_windows = [
        encode_window(start) for start in range(0, num_frames, batch_size)
    ]
    return torch.concat(encoded_windows, dim=2)
def state_dict_converter(self):
    """Return a new ``SDVAEEncoderStateDictConverter`` instance."""
    converter = SDVAEEncoderStateDictConverter()
    return converter