From a3b4f235a05956075a6eefeaf7ed7e95a12ff267 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Fri, 11 Apr 2025 16:17:39 +0800 Subject: [PATCH] support reference image --- diffsynth/pipelines/flux_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index b98f34a..a165e13 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -702,7 +702,7 @@ def lets_dance_flux( # RoPE image_ids_ref = dit.prepare_image_ids(hidden_states_ref) idx = torch.arange(0, image_ids_ref.shape[0]).to(dtype=hidden_states.dtype, device=hidden_states.device) * 100 - image_rotary_emb_ref = reference_embedder(image_ids_ref, idx, dtype=hidden_states.dtype) + image_rotary_emb_ref = reference_embedder(image_ids_ref, idx, dtype=hidden_states.dtype, device=hidden_states.device) image_rotary_emb = torch.cat((image_rotary_emb, image_rotary_emb_ref), dim=2) # hidden_states original_hidden_states_length = hidden_states.shape[1]