support qwen-image-edit-2511

This commit is contained in:
Artiprocher
2025-12-18 19:16:52 +08:00
parent 3cb5cec906
commit 4629d4cf9e
9 changed files with 241 additions and 4 deletions

View File

@@ -352,9 +352,38 @@ class QwenImageTransformerBlock(nn.Module):
self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
self.txt_mlp = QwenFeedForward(dim=dim, dim_out=dim) self.txt_mlp = QwenFeedForward(dim=dim, dim_out=dim)
def _modulate(self, x, mod_params): def _modulate(self, x, mod_params, index=None):
shift, scale, gate = mod_params.chunk(3, dim=-1) shift, scale, gate = mod_params.chunk(3, dim=-1)
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1) if index is not None:
# Assuming mod_params batch dim is 2*actual_batch (chunked into 2 parts)
# So shift, scale, gate have shape [2*actual_batch, d]
actual_batch = shift.size(0) // 2
shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:] # each: [actual_batch, d]
scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:]
gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:]
# index: [b, l] where b is actual batch size
# Expand to [b, l, 1] to match feature dimension
index_expanded = index.unsqueeze(-1) # [b, l, 1]
# Expand chunks to [b, 1, d] then broadcast to [b, l, d]
shift_0_exp = shift_0.unsqueeze(1) # [b, 1, d]
shift_1_exp = shift_1.unsqueeze(1) # [b, 1, d]
scale_0_exp = scale_0.unsqueeze(1)
scale_1_exp = scale_1.unsqueeze(1)
gate_0_exp = gate_0.unsqueeze(1)
gate_1_exp = gate_1.unsqueeze(1)
# Use torch.where to select based on index
shift_result = torch.where(index_expanded == 0, shift_0_exp, shift_1_exp)
scale_result = torch.where(index_expanded == 0, scale_0_exp, scale_1_exp)
gate_result = torch.where(index_expanded == 0, gate_0_exp, gate_1_exp)
else:
shift_result = shift.unsqueeze(1)
scale_result = scale.unsqueeze(1)
gate_result = gate.unsqueeze(1)
return x * (1 + scale_result) + shift_result, gate_result
def forward( def forward(
self, self,
@@ -364,13 +393,16 @@ class QwenImageTransformerBlock(nn.Module):
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
enable_fp8_attention = False, enable_fp8_attention = False,
modulate_index: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
if modulate_index is not None:
temb = torch.chunk(temb, 2, dim=0)[0]
txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
img_normed = self.img_norm1(image) img_normed = self.img_norm1(image)
img_modulated, img_gate = self._modulate(img_normed, img_mod_attn) img_modulated, img_gate = self._modulate(img_normed, img_mod_attn, index=modulate_index)
txt_normed = self.txt_norm1(text) txt_normed = self.txt_norm1(text)
txt_modulated, txt_gate = self._modulate(txt_normed, txt_mod_attn) txt_modulated, txt_gate = self._modulate(txt_normed, txt_mod_attn)
@@ -387,7 +419,7 @@ class QwenImageTransformerBlock(nn.Module):
text = text + txt_gate * txt_attn_out text = text + txt_gate * txt_attn_out
img_normed_2 = self.img_norm2(image) img_normed_2 = self.img_norm2(image)
img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp) img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp, index=modulate_index)
txt_normed_2 = self.txt_norm2(text) txt_normed_2 = self.txt_norm2(text)
txt_modulated_2, txt_gate_2 = self._modulate(txt_normed_2, txt_mod_mlp) txt_modulated_2, txt_gate_2 = self._modulate(txt_normed_2, txt_mod_mlp)

View File

@@ -4,6 +4,7 @@ from typing import Union
from tqdm import tqdm from tqdm import tqdm
from einops import rearrange from einops import rearrange
import numpy as np import numpy as np
from math import prod
from ..diffusion import FlowMatchScheduler from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward from ..core import ModelConfig, gradient_checkpoint_forward
@@ -125,6 +126,8 @@ class QwenImagePipeline(BasePipeline):
edit_image: Image.Image = None, edit_image: Image.Image = None,
edit_image_auto_resize: bool = True, edit_image_auto_resize: bool = True,
edit_rope_interpolation: bool = False, edit_rope_interpolation: bool = False,
# Qwen-Image-Edit-2511
zero_cond_t: bool = False,
# In-context control # In-context control
context_image: Image.Image = None, context_image: Image.Image = None,
# Tile # Tile
@@ -156,6 +159,7 @@ class QwenImagePipeline(BasePipeline):
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative, "eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative,
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize, "edit_rope_interpolation": edit_rope_interpolation, "edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize, "edit_rope_interpolation": edit_rope_interpolation,
"context_image": context_image, "context_image": context_image,
"zero_cond_t": zero_cond_t,
} }
for unit in self.units: for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -678,6 +682,7 @@ def model_fn_qwen_image(
use_gradient_checkpointing=False, use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False, use_gradient_checkpointing_offload=False,
edit_rope_interpolation=False, edit_rope_interpolation=False,
zero_cond_t=False,
**kwargs **kwargs
): ):
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)] img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
@@ -698,6 +703,15 @@ def model_fn_qwen_image(
image = torch.cat([image] + edit_image, dim=1) image = torch.cat([image] + edit_image, dim=1)
image = dit.img_in(image) image = dit.img_in(image)
if zero_cond_t:
timestep = torch.cat([timestep, timestep * 0], dim=0)
modulate_index = torch.tensor(
[[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in [img_shapes]],
device=timestep.device,
dtype=torch.int,
)
else:
modulate_index = None
conditioning = dit.time_text_embed(timestep, image.dtype) conditioning = dit.time_text_embed(timestep, image.dtype)
if entity_prompt_emb is not None: if entity_prompt_emb is not None:
@@ -728,6 +742,7 @@ def model_fn_qwen_image(
image_rotary_emb=image_rotary_emb, image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask, attention_mask=attention_mask,
enable_fp8_attention=enable_fp8_attention, enable_fp8_attention=enable_fp8_attention,
modulate_index=modulate_index,
) )
if blockwise_controlnet_conditioning is not None: if blockwise_controlnet_conditioning is not None:
image_slice = image[:, :image_seq_len].clone() image_slice = image[:, :image_seq_len].clone()
@@ -738,6 +753,8 @@ def model_fn_qwen_image(
) )
image[:, :image_seq_len] = image_slice + controlnet_output image[:, :image_seq_len] = image_slice + controlnet_output
if zero_cond_t:
conditioning = conditioning.chunk(2, dim=0)[0]
image = dit.norm_out(image, conditioning) image = dit.norm_out(image, conditioning)
image = dit.proj_out(image) image = dit.proj_out(image)
image = image[:, :image_seq_len] image = image[:, :image_seq_len]

View File

@@ -0,0 +1,44 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from modelscope import dataset_snapshot_download
from PIL import Image
import torch
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image-Edit-2511", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
)
dataset_snapshot_download(
"DiffSynth-Studio/example_image_dataset",
allow_file_pattern="qwen_image_edit/*",
local_dir="data/example_image_dataset",
)
prompt = "生成这两个人的合影"
edit_image = [
Image.open("data/example_image_dataset/qwen_image_edit/image1.jpg"),
Image.open("data/example_image_dataset/qwen_image_edit/image2.jpg"),
]
image = pipe(
prompt,
edit_image=edit_image,
seed=1,
num_inference_steps=40,
height=1152,
width=896,
edit_image_auto_resize=True,
zero_cond_t=True, # This is a special parameter introduced by Qwen-Image-Edit-2511
)
image.save("image.jpg")
# Qwen-Image-Edit-2511 is a multi-image editing model.
# Please use a list to input `edit_image`, even if the input contains only one image.
# edit_image = [Image.open("image.jpg")]
# Please do not input the image directly.
# edit_image = Image.open("image.jpg")

View File

@@ -0,0 +1,54 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from modelscope import dataset_snapshot_download
from PIL import Image
import torch
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": torch.float8_e4m3fn,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image-Edit-2511", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
)
dataset_snapshot_download(
"DiffSynth-Studio/example_image_dataset",
allow_file_pattern="qwen_image_edit/*",
local_dir="data/example_image_dataset",
)
prompt = "生成这两个人的合影"
edit_image = [
Image.open("data/example_image_dataset/qwen_image_edit/image1.jpg"),
Image.open("data/example_image_dataset/qwen_image_edit/image2.jpg"),
]
image = pipe(
prompt,
edit_image=edit_image,
seed=1,
num_inference_steps=40,
height=1152,
width=896,
edit_image_auto_resize=True,
zero_cond_t=True, # This is a special parameter introduced by Qwen-Image-Edit-2511
)
image.save("image.jpg")
# Qwen-Image-Edit-2511 is a multi-image editing model.
# Please use a list to input `edit_image`, even if the input contains only one image.
# edit_image = [Image.open("image.jpg")]
# Please do not input the image directly.
# edit_image = Image.open("image.jpg")

View File

@@ -0,0 +1,16 @@
accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
--data_file_keys "image,edit_image" \
--extra_inputs "edit_image" \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "Qwen/Qwen-Image-Edit-2511:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Qwen-Image-Edit-2511_full" \
--trainable_models "dit" \
--use_gradient_checkpointing \
--find_unused_parameters \
--zero_cond_t # This is a special parameter introduced by Qwen-Image-Edit-2511. Please enable it for this model.

View File

@@ -0,0 +1,19 @@
accelerate launch examples/qwen_image/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
--data_file_keys "image,edit_image" \
--extra_inputs "edit_image" \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "Qwen/Qwen-Image-Edit-2511:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Qwen-Image-Edit-2511_lora" \
--lora_base_model "dit" \
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
--lora_rank 32 \
--use_gradient_checkpointing \
--dataset_num_workers 8 \
--find_unused_parameters \
--zero_cond_t # This is a special parameter introduced by Qwen-Image-Edit-2511. Please enable it for this model.

View File

@@ -20,6 +20,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
offload_models=None, offload_models=None,
device="cpu", device="cpu",
task="sft", task="sft",
zero_cond_t=False,
): ):
super().__init__() super().__init__()
# Load models # Load models
@@ -43,6 +44,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else [] self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
self.fp8_models = fp8_models self.fp8_models = fp8_models
self.task = task self.task = task
self.zero_cond_t = zero_cond_t
self.task_to_loss = { self.task_to_loss = {
"sft:data_process": lambda pipe, *args: args, "sft:data_process": lambda pipe, *args: args,
"direct_distill:data_process": lambda pipe, *args: args, "direct_distill:data_process": lambda pipe, *args: args,
@@ -68,6 +70,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
"use_gradient_checkpointing": self.use_gradient_checkpointing, "use_gradient_checkpointing": self.use_gradient_checkpointing,
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload, "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
"edit_image_auto_resize": True, "edit_image_auto_resize": True,
"zero_cond_t": self.zero_cond_t,
} }
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared) inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
return inputs_shared, inputs_posi, inputs_nega return inputs_shared, inputs_posi, inputs_nega
@@ -87,6 +90,7 @@ def qwen_image_parser():
parser = add_image_size_config(parser) parser = add_image_size_config(parser)
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.") parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.") parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.")
parser.add_argument("--zero_cond_t", default=False, action="store_true", help="A special parameter introduced by Qwen-Image-Edit-2511. Please enable it for this model.")
return parser return parser
@@ -130,6 +134,7 @@ if __name__ == "__main__":
offload_models=args.offload_models, offload_models=args.offload_models,
task=args.task, task=args.task,
device=accelerator.device, device=accelerator.device,
zero_cond_t=args.zero_cond_t,
) )
model_logger = ModelLogger( model_logger = ModelLogger(
args.output_path, args.output_path,

View File

@@ -0,0 +1,26 @@
import torch
from PIL import Image
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from diffsynth import load_state_dict
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image-Edit-2511", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=None,
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
)
state_dict = load_state_dict("models/train/Qwen-Image-Edit-2511_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
prompt = "Change the color of the dress in Figure 1 to the color shown in Figure 2."
images = [
Image.open("data/example_image_dataset/edit/image1.jpg").resize((1024, 1024)),
Image.open("data/example_image_dataset/edit/image_color.jpg").resize((1024, 1024)),
]
image = pipe(prompt, edit_image=images, seed=123, num_inference_steps=40, height=1024, width=1024, zero_cond_t=True)
image.save("image.jpg")

View File

@@ -0,0 +1,24 @@
import torch
from PIL import Image
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image-Edit-2511", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=None,
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
)
pipe.load_lora(pipe.dit, "models/train/Qwen-Image-Edit-2511_lora/epoch-4.safetensors")
prompt = "Change the color of the dress in Figure 1 to the color shown in Figure 2."
images = [
Image.open("data/example_image_dataset/edit/image1.jpg").resize((1024, 1024)),
Image.open("data/example_image_dataset/edit/image_color.jpg").resize((1024, 1024)),
]
image = pipe(prompt, edit_image=images, seed=123, num_inference_steps=40, height=1024, width=1024, zero_cond_t=True)
image.save("image.jpg")