import math
from typing import Tuple, Optional, Union, List

import torch
import torch.nn as nn
from einops import rearrange

from .sd3_dit import TimestepEmbeddings, RMSNorm
from .flux_dit import AdaLayerNorm

# FlashAttention-3 is optional; without it every call goes through PyTorch SDPA.
try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False


def qwen_image_flash_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    num_heads: int,
    attention_mask=None,
    enable_fp8_attention: bool = False,
):
    """Run attention on ``(b, n, s, d)`` tensors and return ``(b, s, n*d)``.

    FlashAttention-3 is used only when it is importable *and* no attention
    mask is given; otherwise the call falls back to
    ``torch.nn.functional.scaled_dot_product_attention``.

    Args:
        q, k, v: Query/key/value laid out as ``(batch, heads, seq, head_dim)``.
        num_heads: Number of heads (the ``n`` axis of the inputs).
        attention_mask: Optional mask forwarded to SDPA as ``attn_mask``.
        enable_fp8_attention: When the FlashAttention-3 path is taken, cast
            q/k/v to ``float8_e4m3fn`` after dividing each by its global
            standard deviation. Ignored on the SDPA fallback.

    Returns:
        Attention output with heads merged into the last axis.
    """
    if FLASH_ATTN_3_AVAILABLE and attention_mask is None:
        if not enable_fp8_attention:
            q = rearrange(q, "b n s d -> b s n d", n=num_heads)
            k = rearrange(k, "b n s d -> b s n d", n=num_heads)
            v = rearrange(v, "b n s d -> b s n d", n=num_heads)
            x = flash_attn_interface.flash_attn_func(q, k, v)
            # Some flash-attn builds return a tuple; the output is element 0.
            if isinstance(x, tuple):
                x = x[0]
            x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
        else:
            origin_dtype = q.dtype
            # Normalise each tensor by its std before the fp8 cast; the q/k
            # factors are folded back into softmax_scale and the v factor is
            # re-applied to the output.
            q_std, k_std, v_std = q.std(), k.std(), v.std()
            q = (q / q_std).to(torch.float8_e4m3fn)
            k = (k / k_std).to(torch.float8_e4m3fn)
            v = (v / v_std).to(torch.float8_e4m3fn)
            q = rearrange(q, "b n s d -> b s n d", n=num_heads)
            k = rearrange(k, "b n s d -> b s n d", n=num_heads)
            v = rearrange(v, "b n s d -> b s n d", n=num_heads)
            x = flash_attn_interface.flash_attn_func(
                q, k, v,
                softmax_scale=q_std * k_std / math.sqrt(q.size(-1)),
            )
            if isinstance(x, tuple):
                x = x[0]
            x = x.to(origin_dtype) * v_std
            x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
    else:
        x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    return x


class ApproximateGELU(nn.Module):
    """Linear projection followed by the sigmoid GELU approximation ``x * sigmoid(1.702 * x)``."""

    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        return x * torch.sigmoid(1.702 * x)


def apply_rotary_emb_qwen(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]
):
    """Apply complex rotary embeddings to ``x``.

    The last axis of ``x`` is grouped into consecutive pairs, viewed as
    complex numbers, multiplied by ``freqs_cis`` and flattened back. The
    computation is done in float32 and the result is cast back to ``x``'s
    dtype.

    Note: ``flatten(3)`` assumes ``x`` is 4-D (as used in this file:
    ``(batch, heads, seq, head_dim)``).
    """
    x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
    return x_out.type_as(x)


class QwenEmbedRope(nn.Module):
    """Precomputed 3-axis (frame, height, width) rotary frequency tables.

    Two tables of 1024 positions are built at construction time: one for
    positions ``0..1023`` and one for positions ``-1024..-1``. ``forward``
    assembles per-token frequencies for an image grid and for text tokens.
    """

    def __init__(self, theta: int, axes_dim: list[int], scale_rope=False):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        pos_index = torch.arange(1024)
        neg_index = torch.arange(1024).flip(0) * -1 - 1  # [-1024, ..., -1]
        self.pos_freqs = torch.cat(
            [self.rope_params(pos_index, dim, self.theta) for dim in self.axes_dim[:3]],
            dim=1,
        )
        self.neg_freqs = torch.cat(
            [self.rope_params(neg_index, dim, self.theta) for dim in self.axes_dim[:3]],
            dim=1,
        )
        # Maps "frame_height_width" -> flattened per-token image frequencies.
        self.rope_cache = {}
        self.scale_rope = scale_rope

    def rope_params(self, index, dim, theta=10000):
        """Return complex unit-magnitude frequencies of shape ``(len(index), dim // 2)``.

        Args:
            index: 1D tensor of token position indices, e.g. ``[0, 1, 2, 3]``.
            dim: Even number of channels assigned to this axis.
            theta: Base of the geometric frequency progression.
        """
        assert dim % 2 == 0
        freqs = torch.outer(
            index,
            1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))
        )
        freqs = torch.polar(torch.ones_like(freqs), freqs)
        return freqs

    def forward(self, video_fhw, txt_seq_lens, device):
        """Build rotary frequencies for image tokens and text tokens.

        Args:
            video_fhw: ``(frame, height, width)`` or a list whose first element
                is such a triple (only the first element is used).
            txt_seq_lens: Iterable of text lengths; the maximum is used.
            device: Device the frequency tables are moved to.

        Returns:
            ``(vid_freqs, txt_freqs)`` where ``vid_freqs`` has
            ``frame * height * width`` rows and ``txt_freqs`` is a slice of the
            positive table starting right after the largest image index.
        """
        if self.pos_freqs.device != device:
            self.pos_freqs = self.pos_freqs.to(device)
            self.neg_freqs = self.neg_freqs.to(device)

        if isinstance(video_fhw, list):
            video_fhw = video_fhw[0]
        frame, height, width = video_fhw
        rope_key = f"{frame}_{height}_{width}"

        if rope_key not in self.rope_cache:
            seq_lens = frame * height * width
            split_sizes = [x // 2 for x in self.axes_dim]
            freqs_pos = self.pos_freqs.split(split_sizes, dim=1)
            freqs_neg = self.neg_freqs.split(split_sizes, dim=1)

            freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
            if self.scale_rope:
                # Centre the grid: the first part of each axis uses negative
                # positions, the remainder uses positions starting at 0.
                freqs_height = torch.cat(
                    [
                        freqs_neg[1][-(height - height//2):],
                        freqs_pos[1][:height//2]
                    ],
                    dim=0
                )
                freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
                freqs_width = torch.cat(
                    [
                        freqs_neg[2][-(width - width//2):],
                        freqs_pos[2][:width//2]
                    ],
                    dim=0
                )
                freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
            else:
                freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
                freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)

            freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
            self.rope_cache[rope_key] = freqs.clone().contiguous()

        vid_freqs = self.rope_cache[rope_key]
        # The cache key does not include the device, so an entry built before
        # the tables were moved would otherwise come back on the old device.
        if vid_freqs.device != self.pos_freqs.device:
            vid_freqs = vid_freqs.to(self.pos_freqs.device)
            self.rope_cache[rope_key] = vid_freqs

        if self.scale_rope:
            max_vid_index = max(height // 2, width // 2)
        else:
            max_vid_index = max(height, width)

        max_len = max(txt_seq_lens)
        txt_freqs = self.pos_freqs[max_vid_index: max_vid_index + max_len, ...]

        return vid_freqs, txt_freqs


class QwenFeedForward(nn.Module):
    """MLP: ``ApproximateGELU(dim -> 4*dim)`` -> ``Dropout`` -> ``Linear(4*dim -> dim_out)``."""

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        dropout: float = 0.0,
    ):
        """
        Args:
            dim: Input feature size.
            dim_out: Output feature size; defaults to ``dim`` when ``None``.
            dropout: Dropout probability between the two projections.
        """
        super().__init__()
        inner_dim = int(dim * 4)
        # The declared default is None, which nn.Linear cannot accept.
        if dim_out is None:
            dim_out = dim
        self.net = nn.ModuleList([])
        self.net.append(ApproximateGELU(dim, inner_dim))
        self.net.append(nn.Dropout(dropout))
        self.net.append(nn.Linear(inner_dim, dim_out))

    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Apply the sub-modules in order; extra arguments are accepted and ignored."""
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states


class QwenDoubleStreamAttention(nn.Module):
    """Joint attention over a text stream and an image stream.

    Each stream has its own q/k/v projections, q/k RMSNorm and output
    projection. Attention itself runs over the concatenation
    ``[text, image]`` along the sequence axis.
    """

    def __init__(
        self,
        dim_a,
        dim_b,
        num_heads,
        head_dim,
    ):
        """
        Args:
            dim_a: Feature size of the image stream.
            dim_b: Feature size of the text stream.
            num_heads: Number of attention heads.
            head_dim: Per-head size used by the q/k RMSNorm layers.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim

        self.to_q = nn.Linear(dim_a, dim_a)
        self.to_k = nn.Linear(dim_a, dim_a)
        self.to_v = nn.Linear(dim_a, dim_a)
        self.norm_q = RMSNorm(head_dim, eps=1e-6)
        self.norm_k = RMSNorm(head_dim, eps=1e-6)

        self.add_q_proj = nn.Linear(dim_b, dim_b)
        self.add_k_proj = nn.Linear(dim_b, dim_b)
        self.add_v_proj = nn.Linear(dim_b, dim_b)
        self.norm_added_q = RMSNorm(head_dim, eps=1e-6)
        self.norm_added_k = RMSNorm(head_dim, eps=1e-6)

        self.to_out = torch.nn.Sequential(nn.Linear(dim_a, dim_a))
        self.to_add_out = nn.Linear(dim_b, dim_b)

    def forward(
        self,
        image: torch.FloatTensor,
        text: torch.FloatTensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        enable_fp8_attention: bool = False,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Attend jointly over text and image tokens.

        Args:
            image: Image tokens ``(batch, seq_img, dim_a)``.
            text: Text tokens ``(batch, seq_txt, dim_b)``.
            image_rotary_emb: Optional ``(img_freqs, txt_freqs)`` pair applied
                to the respective q/k tensors.
            attention_mask: Optional mask over the joint ``[text, image]``
                sequence.
            enable_fp8_attention: Forwarded to ``qwen_image_flash_attention``.

        Returns:
            ``(img_attn_output, txt_attn_output)``.
        """
        img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image)
        txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text)
        seq_txt = txt_q.shape[1]

        img_q = rearrange(img_q, 'b s (h d) -> b h s d', h=self.num_heads)
        img_k = rearrange(img_k, 'b s (h d) -> b h s d', h=self.num_heads)
        img_v = rearrange(img_v, 'b s (h d) -> b h s d', h=self.num_heads)

        txt_q = rearrange(txt_q, 'b s (h d) -> b h s d', h=self.num_heads)
        txt_k = rearrange(txt_k, 'b s (h d) -> b h s d', h=self.num_heads)
        txt_v = rearrange(txt_v, 'b s (h d) -> b h s d', h=self.num_heads)

        img_q, img_k = self.norm_q(img_q), self.norm_k(img_k)
        txt_q, txt_k = self.norm_added_q(txt_q), self.norm_added_k(txt_k)

        if image_rotary_emb is not None:
            img_freqs, txt_freqs = image_rotary_emb
            img_q = apply_rotary_emb_qwen(img_q, img_freqs)
            img_k = apply_rotary_emb_qwen(img_k, img_freqs)
            txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs)
            txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs)

        # Text tokens come first in the joint sequence.
        joint_q = torch.cat([txt_q, img_q], dim=2)
        joint_k = torch.cat([txt_k, img_k], dim=2)
        joint_v = torch.cat([txt_v, img_v], dim=2)

        joint_attn_out = qwen_image_flash_attention(
            joint_q, joint_k, joint_v,
            num_heads=joint_q.shape[1],
            attention_mask=attention_mask,
            enable_fp8_attention=enable_fp8_attention,
        ).to(joint_q.dtype)

        txt_attn_output = joint_attn_out[:, :seq_txt, :]
        img_attn_output = joint_attn_out[:, seq_txt:, :]

        img_attn_output = self.to_out(img_attn_output)
        txt_attn_output = self.to_add_out(txt_attn_output)

        return img_attn_output, txt_attn_output


class QwenImageTransformerBlock(nn.Module):
    """Double-stream transformer block with timestep-conditioned modulation.

    Both streams get (shift, scale, gate) triples for the attention sub-layer
    and for the MLP sub-layer, produced from ``temb`` by ``img_mod`` /
    ``txt_mod``.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        eps: float = 1e-6,
    ):
        super().__init__()

        self.dim = dim
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim

        self.img_mod = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim),
        )
        self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.attn = QwenDoubleStreamAttention(
            dim_a=dim,
            dim_b=dim,
            num_heads=num_attention_heads,
            head_dim=attention_head_dim,
        )
        self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.img_mlp = QwenFeedForward(dim=dim, dim_out=dim)

        self.txt_mod = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim, bias=True),
        )
        self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.txt_mlp = QwenFeedForward(dim=dim, dim_out=dim)

    def _modulate(self, x, mod_params):
        """Split ``mod_params`` into (shift, scale, gate) and return ``(x * (1 + scale) + shift, gate)``."""
        shift, scale, gate = mod_params.chunk(3, dim=-1)
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)

    def forward(
        self,
        image: torch.Tensor,
        text: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        enable_fp8_attention = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one block.

        Returns:
            ``(text, image)`` -- note that text comes first in the returned
            tuple even though ``image`` is the first parameter.
        """
        img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each
        txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each

        # Attention sub-layer.
        img_normed = self.img_norm1(image)
        img_modulated, img_gate = self._modulate(img_normed, img_mod_attn)

        txt_normed = self.txt_norm1(text)
        txt_modulated, txt_gate = self._modulate(txt_normed, txt_mod_attn)

        img_attn_out, txt_attn_out = self.attn(
            image=img_modulated,
            text=txt_modulated,
            image_rotary_emb=image_rotary_emb,
            attention_mask=attention_mask,
            enable_fp8_attention=enable_fp8_attention,
        )

        image = image + img_gate * img_attn_out
        text = text + txt_gate * txt_attn_out

        # MLP sub-layer.
        img_normed_2 = self.img_norm2(image)
        img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp)

        txt_normed_2 = self.txt_norm2(text)
        txt_modulated_2, txt_gate_2 = self._modulate(txt_normed_2, txt_mod_mlp)

        img_mlp_out = self.img_mlp(img_modulated_2)
        txt_mlp_out = self.txt_mlp(txt_modulated_2)

        image = image + img_gate_2 * img_mlp_out
        text = text + txt_gate_2 * txt_mlp_out

        return text, image


class QwenImageDiT(torch.nn.Module):
    """Qwen-Image diffusion transformer: patch embedding, a stack of
    double-stream blocks, and an output projection back to latent patches."""

    def __init__(
        self,
        num_layers: int = 60,
    ):
        super().__init__()

        self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)

        self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=True)
        self.txt_norm = RMSNorm(3584, eps=1e-6)

        self.img_in = nn.Linear(64, 3072)
        self.txt_in = nn.Linear(3584, 3072)

        self.transformer_blocks = nn.ModuleList(
            [
                QwenImageTransformerBlock(
                    dim=3072,
                    num_attention_heads=24,
                    attention_head_dim=128,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_out = AdaLayerNorm(3072, single=True)
        self.proj_out = nn.Linear(3072, 64)

    def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, entity_prompt_emb_mask, entity_masks, height, width, image, img_shapes):
        """Build joint prompt embeddings, rotary embeddings and an attention mask
        for entity-level (regional) prompting.

        Entity prompts are placed before the global prompt in the text
        sequence. Each prompt may only attend to / be attended by the image
        patches covered by its mask, and different prompts do not attend to
        each other.

        NOTE(review): sequence lengths are read with ``.item()``, so this
        appears to require batch size 1 -- confirm with callers.

        Returns:
            ``(all_prompt_emb, image_rotary_emb, attention_mask)`` where the
            mask is additive (0 = keep, -inf = block) with a singleton head
            axis.
        """
        # prompt_emb
        all_prompt_emb = entity_prompt_emb + [prompt_emb]
        all_prompt_emb = [self.txt_in(self.txt_norm(local_prompt_emb)) for local_prompt_emb in all_prompt_emb]
        all_prompt_emb = torch.cat(all_prompt_emb, dim=1)

        # image_rotary_emb
        txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
        entity_seq_lens = [emb_mask.sum(dim=1).tolist() for emb_mask in entity_prompt_emb_mask]
        entity_rotary_emb = [
            self.pos_embed(img_shapes, entity_seq_len, device=latents.device)[1]
            for entity_seq_len in entity_seq_lens
        ]
        txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0)
        image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb)

        # attention_mask
        repeat_dim = latents.shape[1]
        max_masks = entity_masks.shape[1]
        entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
        entity_masks = [entity_masks[:, i] for i in range(max_masks)]
        # The global prompt sees the whole image.
        global_mask = torch.ones_like(entity_masks[0]).to(device=latents.device, dtype=latents.dtype)
        entity_masks = entity_masks + [global_mask]

        N = len(entity_masks)
        batch_size = entity_masks[0].shape[0]
        seq_lens = [mask_.sum(dim=1).item() for mask_ in entity_prompt_emb_mask] + [prompt_emb_mask.sum(dim=1).item()]
        total_seq_len = sum(seq_lens) + image.shape[1]

        patched_masks = [
            rearrange(mask, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
            for mask in entity_masks
        ]

        attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device)

        # prompt-image attention mask
        image_start = sum(seq_lens)
        image_end = total_seq_len
        cumsum = [0]
        for length in seq_lens:
            cumsum.append(cumsum[-1] + length)
        for i in range(N):
            prompt_start = cumsum[i]
            prompt_end = cumsum[i+1]
            # A patch belongs to prompt i if any of its mask values is > 0.
            image_mask = torch.sum(patched_masks[i], dim=-1) > 0
            image_mask = image_mask.unsqueeze(1).repeat(1, seq_lens[i], 1)
            # prompt update with image
            attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
            # image update with prompt
            attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
        # prompt-prompt attention mask, let the prompt tokens not attend to each other
        for i in range(N):
            for j in range(N):
                if i == j:
                    continue
                start_i, end_i = cumsum[i], cumsum[i+1]
                start_j, end_j = cumsum[j], cumsum[j+1]
                attention_mask[:, start_i:end_i, start_j:end_j] = False

        # Convert boolean mask to additive form: False -> -inf, True -> 0.
        attention_mask = attention_mask.float()
        attention_mask[attention_mask == 0] = float('-inf')
        attention_mask[attention_mask == 1] = 0
        attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1)

        return all_prompt_emb, image_rotary_emb, attention_mask

    def forward(
        self,
        latents=None,
        timestep=None,
        prompt_emb=None,
        prompt_emb_mask=None,
        height=None,
        width=None,
    ):
        """Run the DiT on latents and return the prediction in latent layout.

        Args:
            latents: ``(B, C, H_lat, W_lat)`` latent tensor; ``H_lat`` and
                ``W_lat`` must be even (2x2 patches).
            timestep: Timestep passed to ``time_text_embed``.
            prompt_emb: Text encoder hidden states.
            prompt_emb_mask: Mask whose per-row sum gives the text length.
            height, width: Pixel-space size; the patch grid is
                ``height // 16`` x ``width // 16``. When omitted the grid is
                derived from ``latents`` (``H_lat // 2`` x ``W_lat // 2``).

        Returns:
            Tensor with the same ``(B, C, H_lat, W_lat)`` layout as ``latents``.
        """
        # NOTE(review): the first element (used as the rope "frame" count) is
        # the batch size, which only lines up with the token count for B == 1.
        img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
        txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()

        # height/width default to None; fall back to the latent grid so the
        # declared defaults do not crash on `None // 16`.
        grid_h = latents.shape[2] // 2 if height is None else height // 16
        grid_w = latents.shape[3] // 2 if width is None else width // 16

        # NOTE(review): patches are flattened as (P Q C) here while
        # process_entity_masks uses (C P Q) for masks (harmless there, since
        # it only sums over that axis) -- confirm against the checkpoint.
        image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (P Q C)", H=grid_h, W=grid_w, P=2, Q=2)
        image = self.img_in(image)
        text = self.txt_in(self.txt_norm(prompt_emb))

        conditioning = self.time_text_embed(timestep, image.dtype)

        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)

        for block in self.transformer_blocks:
            text, image = block(
                image=image,
                text=text,
                temb=conditioning,
                image_rotary_emb=image_rotary_emb,
            )

        image = self.norm_out(image, conditioning)
        image = self.proj_out(image)

        # Undo the patchify so the output matches the input latent layout.
        # Previously this result was computed and then discarded.
        latents = rearrange(image, "B (H W) (P Q C) -> B C (H P) (W Q)", H=grid_h, W=grid_w, P=2, Q=2)
        return latents

    @staticmethod
    def state_dict_converter():
        """Return the state-dict converter for this model."""
        return QwenImageDiTStateDictConverter()


class QwenImageDiTStateDictConverter():
    """Identity converter: checkpoint keys are used as-is."""

    def __init__(self):
        pass

    def from_civitai(self, state_dict):
        """Return ``state_dict`` unchanged."""
        return state_dict