support eligenv2 and context_control

This commit is contained in:
mi804
2025-08-20 22:48:34 +08:00
parent 9ec0652339
commit 5e6f9f89f1
12 changed files with 371 additions and 2 deletions

View File

@@ -467,6 +467,7 @@ class QwenImageDiT(torch.nn.Module):
image_start = sum(seq_lens)
image_end = total_seq_len
cumsum = [0]
single_image_seq = image_end - image_start
for length in seq_lens:
cumsum.append(cumsum[-1] + length)
for i in range(N):
@@ -474,6 +475,9 @@ class QwenImageDiT(torch.nn.Module):
prompt_end = cumsum[i+1]
image_mask = torch.sum(patched_masks[i], dim=-1) > 0
image_mask = image_mask.unsqueeze(1).repeat(1, seq_lens[i], 1)
# repeat image mask to match the single image sequence length
repeat_time = single_image_seq // image_mask.shape[-1]
image_mask = image_mask.repeat(1, 1, repeat_time)
# prompt update with image
attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
# image update with prompt
@@ -493,7 +497,8 @@ class QwenImageDiT(torch.nn.Module):
attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1)
return all_prompt_emb, image_rotary_emb, attention_mask
def forward(
self,
latents=None,