mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support qwen-image-edit-2511
This commit is contained in:
@@ -352,9 +352,38 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
||||
self.txt_mlp = QwenFeedForward(dim=dim, dim_out=dim)
|
||||
|
||||
def _modulate(self, x, mod_params):
|
||||
def _modulate(self, x, mod_params, index=None):
|
||||
shift, scale, gate = mod_params.chunk(3, dim=-1)
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
|
||||
if index is not None:
|
||||
# Assuming mod_params batch dim is 2*actual_batch (chunked into 2 parts)
|
||||
# So shift, scale, gate have shape [2*actual_batch, d]
|
||||
actual_batch = shift.size(0) // 2
|
||||
shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:] # each: [actual_batch, d]
|
||||
scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:]
|
||||
gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:]
|
||||
|
||||
# index: [b, l] where b is actual batch size
|
||||
# Expand to [b, l, 1] to match feature dimension
|
||||
index_expanded = index.unsqueeze(-1) # [b, l, 1]
|
||||
|
||||
# Expand chunks to [b, 1, d] then broadcast to [b, l, d]
|
||||
shift_0_exp = shift_0.unsqueeze(1) # [b, 1, d]
|
||||
shift_1_exp = shift_1.unsqueeze(1) # [b, 1, d]
|
||||
scale_0_exp = scale_0.unsqueeze(1)
|
||||
scale_1_exp = scale_1.unsqueeze(1)
|
||||
gate_0_exp = gate_0.unsqueeze(1)
|
||||
gate_1_exp = gate_1.unsqueeze(1)
|
||||
|
||||
# Use torch.where to select based on index
|
||||
shift_result = torch.where(index_expanded == 0, shift_0_exp, shift_1_exp)
|
||||
scale_result = torch.where(index_expanded == 0, scale_0_exp, scale_1_exp)
|
||||
gate_result = torch.where(index_expanded == 0, gate_0_exp, gate_1_exp)
|
||||
else:
|
||||
shift_result = shift.unsqueeze(1)
|
||||
scale_result = scale.unsqueeze(1)
|
||||
gate_result = gate.unsqueeze(1)
|
||||
|
||||
return x * (1 + scale_result) + shift_result, gate_result
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -364,13 +393,16 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
enable_fp8_attention = False,
|
||||
modulate_index: Optional[List[int]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
|
||||
if modulate_index is not None:
|
||||
temb = torch.chunk(temb, 2, dim=0)[0]
|
||||
txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
|
||||
|
||||
img_normed = self.img_norm1(image)
|
||||
img_modulated, img_gate = self._modulate(img_normed, img_mod_attn)
|
||||
img_modulated, img_gate = self._modulate(img_normed, img_mod_attn, index=modulate_index)
|
||||
|
||||
txt_normed = self.txt_norm1(text)
|
||||
txt_modulated, txt_gate = self._modulate(txt_normed, txt_mod_attn)
|
||||
@@ -387,7 +419,7 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
text = text + txt_gate * txt_attn_out
|
||||
|
||||
img_normed_2 = self.img_norm2(image)
|
||||
img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp)
|
||||
img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp, index=modulate_index)
|
||||
|
||||
txt_normed_2 = self.txt_norm2(text)
|
||||
txt_modulated_2, txt_gate_2 = self._modulate(txt_normed_2, txt_mod_mlp)
|
||||
|
||||
Reference in New Issue
Block a user