mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support eligenv2 and context_control
This commit is contained in:
@@ -467,6 +467,7 @@ class QwenImageDiT(torch.nn.Module):
|
||||
image_start = sum(seq_lens)
|
||||
image_end = total_seq_len
|
||||
cumsum = [0]
|
||||
single_image_seq = image_end - image_start
|
||||
for length in seq_lens:
|
||||
cumsum.append(cumsum[-1] + length)
|
||||
for i in range(N):
|
||||
@@ -474,6 +475,9 @@ class QwenImageDiT(torch.nn.Module):
|
||||
prompt_end = cumsum[i+1]
|
||||
image_mask = torch.sum(patched_masks[i], dim=-1) > 0
|
||||
image_mask = image_mask.unsqueeze(1).repeat(1, seq_lens[i], 1)
|
||||
# repeat image mask to match the single image sequence length
|
||||
repeat_time = single_image_seq // image_mask.shape[-1]
|
||||
image_mask = image_mask.repeat(1, 1, repeat_time)
|
||||
# prompt update with image
|
||||
attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
|
||||
# image update with prompt
|
||||
@@ -493,7 +497,8 @@ class QwenImageDiT(torch.nn.Module):
|
||||
attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1)
|
||||
|
||||
return all_prompt_emb, image_rotary_emb, attention_mask
|
||||
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
latents=None,
|
||||
|
||||
Reference in New Issue
Block a user