support reference image

This commit is contained in:
Artiprocher
2025-04-11 16:17:39 +08:00
parent a572254a1d
commit a3b4f235a0

View File

@@ -702,7 +702,7 @@ def lets_dance_flux(
# RoPE # RoPE
image_ids_ref = dit.prepare_image_ids(hidden_states_ref) image_ids_ref = dit.prepare_image_ids(hidden_states_ref)
idx = torch.arange(0, image_ids_ref.shape[0]).to(dtype=hidden_states.dtype, device=hidden_states.device) * 100 idx = torch.arange(0, image_ids_ref.shape[0]).to(dtype=hidden_states.dtype, device=hidden_states.device) * 100
image_rotary_emb_ref = reference_embedder(image_ids_ref, idx, dtype=hidden_states.dtype) image_rotary_emb_ref = reference_embedder(image_ids_ref, idx, dtype=hidden_states.dtype, device=hidden_states.device)
image_rotary_emb = torch.cat((image_rotary_emb, image_rotary_emb_ref), dim=2) image_rotary_emb = torch.cat((image_rotary_emb, image_rotary_emb_ref), dim=2)
# hidden_states # hidden_states
original_hidden_states_length = hidden_states.shape[1] original_hidden_states_length = hidden_states.shape[1]