Merge branch 'main' into fp8

This commit is contained in:
Zhongjie Duan
2025-08-07 16:40:44 +08:00
committed by GitHub
10 changed files with 731 additions and 11 deletions

View File

@@ -12,7 +12,7 @@ except ModuleNotFoundError:
FLASH_ATTN_3_AVAILABLE = False
def qwen_image_flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, enable_fp8_attention: bool = False):
def qwen_image_flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, attention_mask = None, enable_fp8_attention: bool = False):
if FLASH_ATTN_3_AVAILABLE:
if not enable_fp8_attention:
q = rearrange(q, "b n s d -> b s n d", n=num_heads)
@@ -35,7 +35,7 @@ def qwen_image_flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
x = x.to(origin_dtype) * v_std
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
else:
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
return x
@@ -193,6 +193,7 @@ class QwenDoubleStreamAttention(nn.Module):
image: torch.FloatTensor,
text: torch.FloatTensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
enable_fp8_attention: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image)
@@ -221,7 +222,7 @@ class QwenDoubleStreamAttention(nn.Module):
joint_k = torch.cat([txt_k, img_k], dim=2)
joint_v = torch.cat([txt_v, img_v], dim=2)
joint_attn_out = qwen_image_flash_attention(joint_q, joint_k, joint_v, num_heads=joint_q.shape[1], enable_fp8_attention=enable_fp8_attention).to(joint_q.dtype)
joint_attn_out = qwen_image_flash_attention(joint_q, joint_k, joint_v, num_heads=joint_q.shape[1], attention_mask=attention_mask, enable_fp8_attention=enable_fp8_attention).to(joint_q.dtype)
txt_attn_output = joint_attn_out[:, :seq_txt, :]
img_attn_output = joint_attn_out[:, seq_txt:, :]
@@ -278,6 +279,7 @@ class QwenImageTransformerBlock(nn.Module):
text: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
enable_fp8_attention = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -294,7 +296,8 @@ class QwenImageTransformerBlock(nn.Module):
image=img_modulated,
text=txt_modulated,
image_rotary_emb=image_rotary_emb,
enable_fp8_attention=enable_fp8_attention,
attention_mask: Optional[torch.Tensor] = None,
enable_fp8_attention = False,
)
image = image + img_gate * img_attn_out
@@ -344,6 +347,69 @@ class QwenImageDiT(torch.nn.Module):
self.proj_out = nn.Linear(3072, 64)
def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, entity_prompt_emb_mask, entity_masks, height, width, image, img_shapes):
    """Prepare text embeddings, rotary embeddings and the attention mask for entity control.

    The text sequence is laid out as ``[entity_0, ..., entity_{K-1}, global prompt]``
    and the image tokens follow it, so the mask covers
    ``sum(text lengths) + image.shape[1]`` positions per side.

    Masking rules:
      * a prompt segment and an image token may attend to each other only if
        the image token's 2x2 patch has a positive sum in that segment's
        spatial mask (the global prompt uses an all-ones mask);
      * tokens of different prompt segments never attend to each other;
      * every other pair (same-segment text, image-image) is left unmasked.

    Args:
        latents: Only ``shape[1]``, ``device`` and ``dtype`` are read.
        prompt_emb: Global prompt embedding; concatenated along dim 1 after
            ``self.txt_norm`` / ``self.txt_in``.
        prompt_emb_mask: Mask of the global prompt. Its sum over dim 1 must be
            a single value (``.item()`` is called), i.e. one sample only.
        entity_prompt_emb: List of per-entity prompt embeddings.
        entity_prompt_emb_mask: List of per-entity prompt masks, same
            single-sample restriction as ``prompt_emb_mask``.
        entity_masks: 5-D tensor indexed as ``[batch, entity, channel, h, w]``;
            ``h`` and ``w`` must equal ``2 * (height // 16)`` and
            ``2 * (width // 16)`` for the patch rearrangement to succeed.
            At least one entity is required.
        height: Used as ``height // 16`` patch rows.
        width: Used as ``width // 16`` patch columns.
        image: Only ``shape[1]`` is read (number of image tokens).
        img_shapes: Forwarded unchanged to ``self.pos_embed``.

    Returns:
        Tuple ``(all_prompt_emb, image_rotary_emb, attention_mask)`` where
        ``attention_mask`` has shape ``(batch, 1, L, L)`` in ``latents.dtype``
        on ``latents.device``, holding ``0`` for allowed pairs and ``-inf``
        for blocked pairs.
    """
    device = latents.device

    # Text embeddings: entity prompts first, global prompt last, joined on the
    # sequence axis.
    all_prompt_emb = torch.cat(
        [self.txt_in(self.txt_norm(emb)) for emb in entity_prompt_emb + [prompt_emb]],
        dim=1,
    )

    # Rotary embeddings: keep element [0] of the global call and concatenate
    # element [1] of every call in the same order as the text segments.
    # NOTE(review): [0]/[1] are presumably the image/text frequencies
    # returned by pos_embed - confirm against its definition.
    txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
    image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=device)
    entity_rotary_emb = [
        self.pos_embed(img_shapes, emb_mask.sum(dim=1).tolist(), device=device)[1]
        for emb_mask in entity_prompt_emb_mask
    ]
    txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0)
    image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb)

    # Spatial masks: one per entity plus an all-ones mask for the global prompt.
    entity_masks = entity_masks.repeat(1, 1, latents.shape[1], 1, 1)
    region_masks = [entity_masks[:, idx] for idx in range(entity_masks.shape[1])]
    region_masks.append(torch.ones_like(region_masks[0]).to(device=device, dtype=latents.dtype))
    batch_size = region_masks[0].shape[0]

    # int() keeps slice bounds and tensor sizes integral even when the prompt
    # masks are floating point.
    seq_lens = [int(emb_mask.sum(dim=1).item()) for emb_mask in entity_prompt_emb_mask]
    seq_lens.append(int(prompt_emb_mask.sum(dim=1).item()))

    # offsets[k]:offsets[k + 1] is the token range of text segment k.
    offsets = [0]
    for length in seq_lens:
        offsets.append(offsets[-1] + length)
    image_start = offsets[-1]
    total_seq_len = image_start + image.shape[1]

    # True = attention allowed. Allocated directly on the target device.
    allowed = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool, device=device)

    # Block all text-to-text attention first; each segment's own diagonal
    # block is re-enabled in the loop below.
    allowed[:, :image_start, :image_start] = False

    for start, end, region in zip(offsets, offsets[1:], region_masks):
        allowed[:, start:end, start:end] = True

        # An image token is covered when any value in its 2x2 patch
        # (across channels) makes the patch sum positive.
        patched = rearrange(region, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
        covered = (torch.sum(patched, dim=-1) > 0).to(device)

        # Prompt rows attending to image columns ...
        allowed[:, start:end, image_start:] = covered.unsqueeze(1).expand(-1, end - start, -1)
        # ... and image rows attending to prompt columns.
        allowed[:, image_start:, start:end] = covered.unsqueeze(2).expand(-1, -1, end - start)

    # Additive mask: 0 where allowed, -inf where blocked. The extra axis at
    # dim 1 is presumably broadcast over attention heads by the caller.
    attention_mask = torch.zeros_like(allowed, dtype=latents.dtype)
    attention_mask = attention_mask.masked_fill(~allowed, float("-inf")).unsqueeze(1)
    return all_prompt_emb, image_rotary_emb, attention_mask
def forward(
self,
latents=None,