rebuild base modules

This commit is contained in:
Artiprocher
2024-07-26 12:15:40 +08:00
parent 9471bff8a4
commit e3f8a576cf
76 changed files with 3253 additions and 3563 deletions

View File

@@ -22,6 +22,10 @@ def lets_dance(
device = "cuda",
vram_limit_level = 0,
):
# 0. Text embedding alignment (only for video processing)
if encoder_hidden_states.shape[0] != sample.shape[0]:
encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0], 1, 1, 1)
# 1. ControlNet
# This part will be repeated on overlapping frames if animatediff_batch_size > animatediff_stride.
# I leave it here because I intend to do something interesting on the ControlNets.
@@ -50,7 +54,7 @@ def lets_dance(
additional_res_stack = None
# 2. time
time_emb = unet.time_proj(timestep[None]).to(sample.dtype)
time_emb = unet.time_proj(timestep).to(sample.dtype)
time_emb = unet.time_embedding(time_emb)
# 3. pre-process
@@ -133,7 +137,7 @@ def lets_dance_xl(
vram_limit_level = 0,
):
# 2. time
t_emb = unet.time_proj(timestep[None]).to(sample.dtype)
t_emb = unet.time_proj(timestep).to(sample.dtype)
t_emb = unet.time_embedding(t_emb)
time_embeds = unet.add_time_proj(add_time_id)