mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-25 02:38:10 +00:00
support reference image
This commit is contained in:
@@ -702,7 +702,7 @@ def lets_dance_flux(
|
|||||||
# RoPE
|
# RoPE
|
||||||
image_ids_ref = dit.prepare_image_ids(hidden_states_ref)
|
image_ids_ref = dit.prepare_image_ids(hidden_states_ref)
|
||||||
idx = torch.arange(0, image_ids_ref.shape[0]).to(dtype=hidden_states.dtype, device=hidden_states.device) * 100
|
idx = torch.arange(0, image_ids_ref.shape[0]).to(dtype=hidden_states.dtype, device=hidden_states.device) * 100
|
||||||
image_rotary_emb_ref = reference_embedder(image_ids_ref, idx, dtype=hidden_states.dtype)
|
image_rotary_emb_ref = reference_embedder(image_ids_ref, idx, dtype=hidden_states.dtype, device=hidden_states.device)
|
||||||
image_rotary_emb = torch.cat((image_rotary_emb, image_rotary_emb_ref), dim=2)
|
image_rotary_emb = torch.cat((image_rotary_emb, image_rotary_emb_ref), dim=2)
|
||||||
# hidden_states
|
# hidden_states
|
||||||
original_hidden_states_length = hidden_states.shape[1]
|
original_hidden_states_length = hidden_states.shape[1]
|
||||||
|
|||||||
Reference in New Issue
Block a user