mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 08:40:47 +00:00
@@ -375,8 +375,7 @@ class FluxDiT(torch.nn.Module):
|
|||||||
return attention_mask
|
return attention_mask
|
||||||
|
|
||||||
|
|
||||||
def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids):
|
def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, repeat_dim):
|
||||||
repeat_dim = hidden_states.shape[1]
|
|
||||||
max_masks = 0
|
max_masks = 0
|
||||||
attention_mask = None
|
attention_mask = None
|
||||||
prompt_embs = [prompt_emb]
|
prompt_embs = [prompt_emb]
|
||||||
|
|||||||
@@ -762,7 +762,7 @@ def lets_dance_flux(
|
|||||||
hidden_states = dit.x_embedder(hidden_states)
|
hidden_states = dit.x_embedder(hidden_states)
|
||||||
|
|
||||||
if entity_prompt_emb is not None and entity_masks is not None:
|
if entity_prompt_emb is not None and entity_masks is not None:
|
||||||
prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
|
prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, 16)
|
||||||
else:
|
else:
|
||||||
prompt_emb = dit.context_embedder(prompt_emb)
|
prompt_emb = dit.context_embedder(prompt_emb)
|
||||||
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||||
|
|||||||
@@ -1233,7 +1233,7 @@ def model_fn_flux_image(
|
|||||||
|
|
||||||
# EliGen
|
# EliGen
|
||||||
if entity_prompt_emb is not None and entity_masks is not None:
|
if entity_prompt_emb is not None and entity_masks is not None:
|
||||||
prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
|
prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, latents.shape[1])
|
||||||
else:
|
else:
|
||||||
prompt_emb = dit.context_embedder(prompt_emb)
|
prompt_emb = dit.context_embedder(prompt_emb)
|
||||||
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||||
|
|||||||
Reference in New Issue
Block a user