This commit is contained in:
Artiprocher
2024-10-23 14:24:41 +08:00
parent 07d70a6a56
commit aa054db1c7

View File

@@ -139,6 +139,8 @@ def lets_dance_xl(
# 0. Text embedding alignment (only for video processing)
if encoder_hidden_states.shape[0] != sample.shape[0]:
encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0], 1, 1, 1)
if add_text_embeds.shape[0] != sample.shape[0]:
add_text_embeds = add_text_embeds.repeat(sample.shape[0], 1)
# 1. ControlNet
controlnet_insert_block_id = 22
@@ -204,7 +206,7 @@ def lets_dance_xl(
batch_id_ = min(batch_id + unet_batch_size, sample.shape[0])
hidden_states, _, _, _ = block(
hidden_states_input[batch_id: batch_id_],
time_emb,
time_emb[batch_id: batch_id_],
text_emb[batch_id: batch_id_],
res_stack,
cross_frame_attention=cross_frame_attention,