support video-to-video-translation

This commit is contained in:
Artiprocher
2023-12-21 17:11:58 +08:00
parent f7f4c1038e
commit c1453281df
20 changed files with 1659 additions and 427 deletions

View File

@@ -279,7 +279,7 @@ class SDUNet(torch.nn.Module):
self.conv_act = torch.nn.SiLU()
self.conv_out = torch.nn.Conv2d(320, 4, kernel_size=3, padding=1)
def forward(self, sample, timestep, encoder_hidden_states, tiled=False, tile_size=64, tile_stride=8, **kwargs):
def forward(self, sample, timestep, encoder_hidden_states, tiled=False, tile_size=64, tile_stride=8, additional_res_stack=None, **kwargs):
# 1. time
time_emb = self.time_proj(timestep[None]).to(sample.dtype)
time_emb = self.time_embedding(time_emb)
@@ -293,6 +293,10 @@ class SDUNet(torch.nn.Module):
# 3. blocks
for i, block in enumerate(self.blocks):
if additional_res_stack is not None and i==31:
hidden_states += additional_res_stack.pop()
res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)]
additional_res_stack = None
if tiled:
hidden_states, time_emb, text_emb, res_stack = self.tiled_inference(
block, hidden_states, time_emb, text_emb, res_stack,