From 970403f78e8a2c383d6744ec7a136babde67fbdb Mon Sep 17 00:00:00 2001
From: mi804 <1576993271@qq.com>
Date: Wed, 6 Aug 2025 20:07:21 +0800
Subject: [PATCH] fix flux-eligen bug

---
 diffsynth/models/flux_dit.py          | 3 +--
 diffsynth/pipelines/flux_image.py     | 2 +-
 diffsynth/pipelines/flux_image_new.py | 2 +-
 3 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/diffsynth/models/flux_dit.py b/diffsynth/models/flux_dit.py
index 0ec9b07..411ac9c 100644
--- a/diffsynth/models/flux_dit.py
+++ b/diffsynth/models/flux_dit.py
@@ -375,8 +375,7 @@ class FluxDiT(torch.nn.Module):
         return attention_mask
 
 
-    def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids):
-        repeat_dim = hidden_states.shape[1]
+    def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, repeat_dim):
         max_masks = 0
         attention_mask = None
         prompt_embs = [prompt_emb]
diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py
index 90b196d..55a84c0 100644
--- a/diffsynth/pipelines/flux_image.py
+++ b/diffsynth/pipelines/flux_image.py
@@ -762,7 +762,7 @@ def lets_dance_flux(
     hidden_states = dit.x_embedder(hidden_states)
 
     if entity_prompt_emb is not None and entity_masks is not None:
-        prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
+        prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, 16)
     else:
         prompt_emb = dit.context_embedder(prompt_emb)
         image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
diff --git a/diffsynth/pipelines/flux_image_new.py b/diffsynth/pipelines/flux_image_new.py
index 9384624..63a9dff 100644
--- a/diffsynth/pipelines/flux_image_new.py
+++ b/diffsynth/pipelines/flux_image_new.py
@@ -1233,7 +1233,7 @@ def model_fn_flux_image(
 
     # EliGen
     if entity_prompt_emb is not None and entity_masks is not None:
-        prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
+        prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, latents.shape[1])
     else:
         prompt_emb = dit.context_embedder(prompt_emb)
         image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))