support qwen-image-edit-2511

This commit is contained in:
Artiprocher
2025-12-18 19:16:52 +08:00
parent 3cb5cec906
commit 4629d4cf9e
9 changed files with 241 additions and 4 deletions

View File

@@ -352,9 +352,38 @@ class QwenImageTransformerBlock(nn.Module):
self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
self.txt_mlp = QwenFeedForward(dim=dim, dim_out=dim)
def _modulate(self, x, mod_params):
def _modulate(self, x, mod_params, index=None):
shift, scale, gate = mod_params.chunk(3, dim=-1)
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
if index is not None:
# Assuming mod_params batch dim is 2*actual_batch (chunked into 2 parts)
# So shift, scale, gate have shape [2*actual_batch, d]
actual_batch = shift.size(0) // 2
shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:] # each: [actual_batch, d]
scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:]
gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:]
# index: [b, l] where b is actual batch size
# Expand to [b, l, 1] to match feature dimension
index_expanded = index.unsqueeze(-1) # [b, l, 1]
# Expand chunks to [b, 1, d] then broadcast to [b, l, d]
shift_0_exp = shift_0.unsqueeze(1) # [b, 1, d]
shift_1_exp = shift_1.unsqueeze(1) # [b, 1, d]
scale_0_exp = scale_0.unsqueeze(1)
scale_1_exp = scale_1.unsqueeze(1)
gate_0_exp = gate_0.unsqueeze(1)
gate_1_exp = gate_1.unsqueeze(1)
# Use torch.where to select based on index
shift_result = torch.where(index_expanded == 0, shift_0_exp, shift_1_exp)
scale_result = torch.where(index_expanded == 0, scale_0_exp, scale_1_exp)
gate_result = torch.where(index_expanded == 0, gate_0_exp, gate_1_exp)
else:
shift_result = shift.unsqueeze(1)
scale_result = scale.unsqueeze(1)
gate_result = gate.unsqueeze(1)
return x * (1 + scale_result) + shift_result, gate_result
def forward(
self,
@@ -364,13 +393,16 @@ class QwenImageTransformerBlock(nn.Module):
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
enable_fp8_attention = False,
modulate_index: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
if modulate_index is not None:
temb = torch.chunk(temb, 2, dim=0)[0]
txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
img_normed = self.img_norm1(image)
img_modulated, img_gate = self._modulate(img_normed, img_mod_attn)
img_modulated, img_gate = self._modulate(img_normed, img_mod_attn, index=modulate_index)
txt_normed = self.txt_norm1(text)
txt_modulated, txt_gate = self._modulate(txt_normed, txt_mod_attn)
@@ -387,7 +419,7 @@ class QwenImageTransformerBlock(nn.Module):
text = text + txt_gate * txt_attn_out
img_normed_2 = self.img_norm2(image)
img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp)
img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp, index=modulate_index)
txt_normed_2 = self.txt_norm2(text)
txt_modulated_2, txt_gate_2 = self._modulate(txt_normed_2, txt_mod_mlp)

View File

@@ -4,6 +4,7 @@ from typing import Union
from tqdm import tqdm
from einops import rearrange
import numpy as np
from math import prod
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
@@ -125,6 +126,8 @@ class QwenImagePipeline(BasePipeline):
edit_image: Image.Image = None,
edit_image_auto_resize: bool = True,
edit_rope_interpolation: bool = False,
# Qwen-Image-Edit-2511
zero_cond_t: bool = False,
# In-context control
context_image: Image.Image = None,
# Tile
@@ -156,6 +159,7 @@ class QwenImagePipeline(BasePipeline):
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative,
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize, "edit_rope_interpolation": edit_rope_interpolation,
"context_image": context_image,
"zero_cond_t": zero_cond_t,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -678,6 +682,7 @@ def model_fn_qwen_image(
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
edit_rope_interpolation=False,
zero_cond_t=False,
**kwargs
):
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
@@ -698,6 +703,15 @@ def model_fn_qwen_image(
image = torch.cat([image] + edit_image, dim=1)
image = dit.img_in(image)
if zero_cond_t:
timestep = torch.cat([timestep, timestep * 0], dim=0)
modulate_index = torch.tensor(
[[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in [img_shapes]],
device=timestep.device,
dtype=torch.int,
)
else:
modulate_index = None
conditioning = dit.time_text_embed(timestep, image.dtype)
if entity_prompt_emb is not None:
@@ -728,6 +742,7 @@ def model_fn_qwen_image(
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
enable_fp8_attention=enable_fp8_attention,
modulate_index=modulate_index,
)
if blockwise_controlnet_conditioning is not None:
image_slice = image[:, :image_seq_len].clone()
@@ -738,6 +753,8 @@ def model_fn_qwen_image(
)
image[:, :image_seq_len] = image_slice + controlnet_output
if zero_cond_t:
conditioning = conditioning.chunk(2, dim=0)[0]
image = dit.norm_out(image, conditioning)
image = dit.proj_out(image)
image = image[:, :image_seq_len]