support reference image

This commit is contained in:
Artiprocher
2025-04-11 16:14:24 +08:00
parent 9e78bf5e89
commit a572254a1d
4 changed files with 15 additions and 9 deletions

View File

@@ -694,7 +694,7 @@ def lets_dance_flux(
prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
else:
prompt_emb = dit.context_embedder(prompt_emb)
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1), device=hidden_states.device)
attention_mask = None
# Reference images