This commit is contained in:
Artiprocher
2023-12-30 21:01:24 +08:00
parent b9771db163
commit d24ddaacaa
19 changed files with 2252 additions and 34 deletions

View File

@@ -15,6 +15,7 @@ def lets_dance(
controlnet_frames = None,
unet_batch_size = 1,
controlnet_batch_size = 1,
cross_frame_attention = False,
tiled=False,
tile_size=64,
tile_stride=32,
@@ -86,7 +87,13 @@ def lets_dance(
tile_dtype=hidden_states.dtype
)
else:
hidden_states, _, _, _ = block(hidden_states_input[batch_id: batch_id_], time_emb, text_emb[batch_id: batch_id_], res_stack)
hidden_states, _, _, _ = block(
hidden_states_input[batch_id: batch_id_],
time_emb,
text_emb[batch_id: batch_id_],
res_stack,
cross_frame_attention=cross_frame_attention
)
hidden_states_output.append(hidden_states)
hidden_states = torch.concat(hidden_states_output, dim=0)
# 4.2 AnimateDiff