mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
fix flux-eligen bug
This commit is contained in:
@@ -375,8 +375,7 @@ class FluxDiT(torch.nn.Module):
|
||||
return attention_mask
|
||||
|
||||
|
||||
def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids):
|
||||
repeat_dim = hidden_states.shape[1]
|
||||
def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, repeat_dim):
|
||||
max_masks = 0
|
||||
attention_mask = None
|
||||
prompt_embs = [prompt_emb]
|
||||
|
||||
@@ -762,7 +762,7 @@ def lets_dance_flux(
|
||||
hidden_states = dit.x_embedder(hidden_states)
|
||||
|
||||
if entity_prompt_emb is not None and entity_masks is not None:
|
||||
prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
|
||||
prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, 16)
|
||||
else:
|
||||
prompt_emb = dit.context_embedder(prompt_emb)
|
||||
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||
|
||||
@@ -1233,7 +1233,7 @@ def model_fn_flux_image(
|
||||
|
||||
# EliGen
|
||||
if entity_prompt_emb is not None and entity_masks is not None:
|
||||
prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
|
||||
prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, latents.shape[1])
|
||||
else:
|
||||
prompt_emb = dit.context_embedder(prompt_emb)
|
||||
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||
|
||||
Reference in New Issue
Block a user