Merge pull request #756 from mi804/flux-eligen

fix flux-eligen bug
This commit is contained in:
Zhongjie Duan
2025-08-06 20:09:00 +08:00
committed by GitHub
3 changed files with 3 additions and 4 deletions

View File

@@ -375,8 +375,7 @@ class FluxDiT(torch.nn.Module):
         return attention_mask

-    def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids):
-        repeat_dim = hidden_states.shape[1]
+    def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, repeat_dim):
         max_masks = 0
         attention_mask = None
         prompt_embs = [prompt_emb]

View File

@@ -762,7 +762,7 @@ def lets_dance_flux(
     hidden_states = dit.x_embedder(hidden_states)
     if entity_prompt_emb is not None and entity_masks is not None:
-        prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
+        prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, 16)
     else:
         prompt_emb = dit.context_embedder(prompt_emb)
         image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))

View File

@@ -1233,7 +1233,7 @@ def model_fn_flux_image(
     # EliGen
     if entity_prompt_emb is not None and entity_masks is not None:
-        prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
+        prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, latents.shape[1])
     else:
         prompt_emb = dit.context_embedder(prompt_emb)
         image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))