From 9f6922bba968a236fe75af4f37ae5bc40b606af7 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Mon, 18 Aug 2025 16:07:45 +0800 Subject: [PATCH] support qwen-image-edit --- README.md | 3 + README_zh.md | 3 + diffsynth/models/qwen_image_dit.py | 4 +- diffsynth/pipelines/qwen_image.py | 77 ++++++++++++++++--- diffsynth/trainers/utils.py | 2 + examples/qwen_image/README.md | 3 + examples/qwen_image/README_zh.md | 3 + .../model_inference/Qwen-Image-Edit.py | 22 ++++++ .../Qwen-Image-Edit.py | 24 ++++++ .../model_training/full/Qwen-Image-Edit.sh | 14 ++++ .../model_training/lora/Qwen-Image-Edit.sh | 17 ++++ examples/qwen_image/model_training/train.py | 16 ++-- .../validate_full/Qwen-Image-Edit.py | 22 ++++++ .../validate_lora/Qwen-Image-Edit.py | 21 +++++ 14 files changed, 212 insertions(+), 19 deletions(-) create mode 100644 examples/qwen_image/model_inference/Qwen-Image-Edit.py create mode 100644 examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py create mode 100644 examples/qwen_image/model_training/full/Qwen-Image-Edit.sh create mode 100644 examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh create mode 100644 examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py create mode 100644 examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py diff --git a/README.md b/README.md index 71329ca..b6191c1 100644 --- a/README.md +++ b/README.md @@ -90,6 +90,7 @@ image.save("image.jpg") |Model ID|Inference|Low VRAM Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training| |-|-|-|-|-|-|-| 
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./examples/qwen_image/model_inference/Qwen-Image.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image.py)| +|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](./examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)| |[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)| |[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|-|-| 
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| @@ -367,6 +368,8 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44 ## Update History +- **August 19, 2025** 🔥 Qwen-Image-Edit is now open source. Welcome the new member to the image editing model family! + - **August 15, 2025** We open-sourced the [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset). This is an image dataset generated using the Qwen-Image model, with a total of 160,000 `1024 x 1024` images. It includes the general, English text rendering, and Chinese text rendering subsets. We provide caption, entity and control images annotations for each image. Developers can use this dataset to train models such as ControlNet and EliGen for the Qwen-Image model. We aim to promote technological development through open-source contributions! - **August 13, 2025** We trained and open-sourced the ControlNet model for Qwen-Image, [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth), which adopts a lightweight architectural design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py). 
diff --git a/README_zh.md b/README_zh.md index 1596482..75a58d6 100644 --- a/README_zh.md +++ b/README_zh.md @@ -92,6 +92,7 @@ image.save("image.jpg") |模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| |-|-|-|-|-|-|-| |[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./examples/qwen_image/model_inference/Qwen-Image.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image.py)| +|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](./examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)| |[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)| 
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|-|-| |[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| @@ -383,6 +384,8 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44 ## 更新历史 +- **2025年8月19日** 🔥 Qwen-Image-Edit 开源,欢迎图像编辑模型新成员! + - **2025年8月15日** 我们开源了 [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset) 数据集。这是一个使用 Qwen-Image 模型生成的图像数据集,共包含 160,000 张`1024 x 1024`图像。它包括通用、英文文本渲染和中文文本渲染子集。我们为每张图像提供了图像描述、实体和结构控制图像的标注。开发者可以使用这个数据集来训练 Qwen-Image 模型的 ControlNet 和 EliGen 等模型,我们旨在通过开源推动技术发展! 
- **2025年8月13日** 我们训练并开源了 Qwen-Image 的 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth),模型结构采用了轻量化的设计,请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)。 diff --git a/diffsynth/models/qwen_image_dit.py b/diffsynth/models/qwen_image_dit.py index b60b0e4..137e4cd 100644 --- a/diffsynth/models/qwen_image_dit.py +++ b/diffsynth/models/qwen_image_dit.py @@ -63,8 +63,8 @@ class QwenEmbedRope(nn.Module): super().__init__() self.theta = theta self.axes_dim = axes_dim - pos_index = torch.arange(1024) - neg_index = torch.arange(1024).flip(0) * -1 - 1 + pos_index = torch.arange(4096) + neg_index = torch.arange(4096).flip(0) * -1 - 1 self.pos_freqs = torch.cat([ self.rope_params(pos_index, self.axes_dim[0], self.theta), self.rope_params(pos_index, self.axes_dim[1], self.theta), diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py index 3b529b2..ff02775 100644 --- a/diffsynth/pipelines/qwen_image.py +++ b/diffsynth/pipelines/qwen_image.py @@ -52,7 +52,7 @@ class QwenImagePipeline(BasePipeline): device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16, ) - from transformers import Qwen2Tokenizer + from transformers import Qwen2Tokenizer, Qwen2VLProcessor self.scheduler = FlowMatchScheduler(sigma_min=0, sigma_max=1, extra_one_step=True, exponential_shift=True, exponential_shift_mu=0.8, shift_terminal=0.02) self.text_encoder: QwenImageTextEncoder = None @@ -60,6 +60,7 @@ class QwenImagePipeline(BasePipeline): self.vae: QwenImageVAE = None self.blockwise_controlnet: QwenImageBlockwiseMultiControlNet = None self.tokenizer: Qwen2Tokenizer = None + self.processor: Qwen2VLProcessor = None self.unit_runner = PipelineUnitRunner() self.in_iteration_models = ("dit", "blockwise_controlnet") self.units = [ @@ -69,6 +70,7 @@ class QwenImagePipeline(BasePipeline): 
QwenImageUnit_PromptEmbedder(), QwenImageUnit_EntityControl(), QwenImageUnit_BlockwiseControlNet(), + QwenImageUnit_EditImageEmbedder(), ] self.model_fn = model_fn_qwen_image @@ -218,6 +220,7 @@ class QwenImagePipeline(BasePipeline): device: Union[str, torch.device] = "cuda", model_configs: list[ModelConfig] = [], tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), + processor_config: ModelConfig = None, ): # Download and load models model_manager = ModelManager() @@ -239,6 +242,10 @@ class QwenImagePipeline(BasePipeline): tokenizer_config.download_if_necessary() from transformers import Qwen2Tokenizer pipe.tokenizer = Qwen2Tokenizer.from_pretrained(tokenizer_config.path) + if processor_config is not None: + processor_config.download_if_necessary() + from transformers import Qwen2VLProcessor + pipe.processor = Qwen2VLProcessor.from_pretrained(processor_config.path) return pipe @@ -266,6 +273,8 @@ class QwenImagePipeline(BasePipeline): eligen_entity_prompts: list[str] = None, eligen_entity_masks: list[Image.Image] = None, eligen_enable_on_negative: bool = False, + # Edit Image + edit_image: Image.Image = None, # FP8 enable_fp8_attention: bool = False, # Tile @@ -295,6 +304,7 @@ class QwenImagePipeline(BasePipeline): "blockwise_controlnet_inputs": blockwise_controlnet_inputs, "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, "eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative, + "edit_image": edit_image, } for unit in self.units: inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) @@ -366,13 +376,13 @@ class QwenImageUnit_InputImageEmbedder(PipelineUnit): return {"latents": latents, "input_latents": None} - class QwenImageUnit_PromptEmbedder(PipelineUnit): def __init__(self): super().__init__( seperate_cfg=True, input_params_posi={"prompt": 
"prompt"}, input_params_nega={"prompt": "negative_prompt"}, + input_params=("edit_image",), onload_model_names=("text_encoder",) ) @@ -383,18 +393,35 @@ class QwenImageUnit_PromptEmbedder(PipelineUnit): split_result = torch.split(selected, valid_lengths.tolist(), dim=0) return split_result - def process(self, pipe: QwenImagePipeline, prompt) -> dict: + def process(self, pipe: QwenImagePipeline, prompt, edit_image=None) -> dict: if pipe.text_encoder is not None: prompt = [prompt] - template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" - drop_idx = 34 + # If edit_image is None, use the default template for Qwen-Image, otherwise use the template for Qwen-Image-Edit + if edit_image is None: + template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + drop_idx = 34 + else: + template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + drop_idx = 64 txt = [template.format(e) for e in prompt] - txt_tokens = pipe.tokenizer(txt, max_length=4096+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device) - if txt_tokens.input_ids.shape[1] >= 1024: - print(f"Warning!!! QwenImage model was trained on prompts up to 512 tokens. 
Current prompt requires {txt_tokens['input_ids'].shape[1] - drop_idx} tokens, which may lead to unpredictable behavior.") - hidden_states = pipe.text_encoder(input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, output_hidden_states=True,)[-1] - - split_hidden_states = self.extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + + # Qwen-Image-Edit model + if pipe.processor is not None: + model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors="pt").to(pipe.device) + # Qwen-Image model + elif pipe.tokenizer is not None: + model_inputs = pipe.tokenizer(txt, max_length=4096+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device) + if model_inputs.input_ids.shape[1] >= 1024: + print(f"Warning!!! QwenImage model was trained on prompts up to 512 tokens. Current prompt requires {model_inputs['input_ids'].shape[1] - drop_idx} tokens, which may lead to unpredictable behavior.") + else: + assert False, "QwenImagePipeline requires either tokenizer or processor to be loaded." 
+ + if 'pixel_values' in model_inputs: + hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True,)[-1] + else: + hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, output_hidden_states=True,)[-1] + + split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] max_seq_len = max([e.size(0) for e in split_hidden_states]) @@ -528,6 +555,23 @@ class QwenImageUnit_BlockwiseControlNet(PipelineUnit): return {"blockwise_controlnet_conditioning": conditionings} +class QwenImageUnit_EditImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("edit_image", "height", "width", "tiled", "tile_size", "tile_stride"), + onload_model_names=("vae",) + ) + + def process(self, pipe: QwenImagePipeline, edit_image, height, width, tiled, tile_size, tile_stride): + if edit_image is None: + return {} + edit_image = edit_image.resize((width, height)) + pipe.load_models_to_device(['vae']) + edit_image = pipe.preprocess_image(edit_image).to(device=pipe.device, dtype=pipe.torch_dtype) + edit_latents = pipe.vae.encode(edit_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + return {"edit_latents": edit_latents} + + def model_fn_qwen_image( dit: QwenImageDiT = None, blockwise_controlnet: QwenImageBlockwiseMultiControlNet = None, @@ -544,6 +588,7 @@ def model_fn_qwen_image( entity_prompt_emb=None, entity_prompt_emb_mask=None, entity_masks=None, + edit_latents=None, enable_fp8_attention=False, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False, @@ -554,7 +599,13 @@ def model_fn_qwen_image( timestep = 
timestep / 1000 image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2) - + + if edit_latents is not None: + img_shapes[0] = (img_shapes[0][0] + edit_latents.shape[0], img_shapes[0][1], img_shapes[0][2]) + edit_image = rearrange(edit_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2) + image_seq_len = image.shape[1] + image = torch.cat([image, edit_image], dim=1) + image = dit.img_in(image) conditioning = dit.time_text_embed(timestep, image.dtype) @@ -593,6 +644,8 @@ def model_fn_qwen_image( image = dit.norm_out(image, conditioning) image = dit.proj_out(image) + if edit_latents is not None: + image = image[:, :image_seq_len] latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2) return latents diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index 0187065..4b1aad6 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -552,4 +552,6 @@ def qwen_image_parser(): parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.") parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.") parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.") + parser.add_argument("--edit_model", default=False, action="store_true", help="Whether to use Qwen-Image-Edit. If True, the model will be used for image editing.") + parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. 
If provided, the processor will be used for image editing.") return parser diff --git a/examples/qwen_image/README.md b/examples/qwen_image/README.md index 357ee7f..570d9cf 100644 --- a/examples/qwen_image/README.md +++ b/examples/qwen_image/README.md @@ -43,6 +43,7 @@ image.save("image.jpg") |Model ID|Inference|Low VRAM Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training| |-|-|-|-|-|-|-| |[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./model_inference/Qwen-Image.py)|[code](./model_inference_low_vram/Qwen-Image.py)|[code](./model_training/full/Qwen-Image.sh)|[code](./model_training/validate_full/Qwen-Image.py)|[code](./model_training/lora/Qwen-Image.sh)|[code](./model_training/validate_lora/Qwen-Image.py)| +|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](./model_inference/Qwen-Image-Edit.py)|[code](./model_inference_low_vram/Qwen-Image-Edit.py)|[code](./model_training/full/Qwen-Image-Edit.sh)|[code](./model_training/validate_full/Qwen-Image-Edit.py)|[code](./model_training/lora/Qwen-Image-Edit.sh)|[code](./model_training/validate_lora/Qwen-Image-Edit.py)| |[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./model_inference/Qwen-Image-Distill-Full.py)|[code](./model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](./model_training/full/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_lora/Qwen-Image-Distill-Full.py)| |[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](./model_inference/Qwen-Image-Distill-LoRA.py)|[code](./model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|-|-| 
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./model_inference/Qwen-Image-EliGen.py)|[code](./model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](./model_training/lora/Qwen-Image-EliGen.sh)|[code](./model_training/validate_lora/Qwen-Image-EliGen.py)| @@ -235,6 +236,8 @@ The script includes the following parameters: * `--model_paths`: Model paths to load. In JSON format. * `--model_id_with_origin_paths`: Model ID with original paths, e.g., Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors. Separate with commas. * `--tokenizer_path`: Tokenizer path. Leave empty to auto-download. + * `--edit_model`: Whether to use Qwen-Image-Edit. If True, the model will be used for image editing. + * `--processor_path`: Path to the processor of Qwen-Image-Edit. Leave empty to auto-download. * Training * `--learning_rate`: Learning rate. * `--weight_decay`: Weight decay. diff --git a/examples/qwen_image/README_zh.md b/examples/qwen_image/README_zh.md index 418bf25..26965b8 100644 --- a/examples/qwen_image/README_zh.md +++ b/examples/qwen_image/README_zh.md @@ -43,6 +43,7 @@ image.save("image.jpg") |模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| |-|-|-|-|-|-|-| |[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./model_inference/Qwen-Image.py)|[code](./model_inference_low_vram/Qwen-Image.py)|[code](./model_training/full/Qwen-Image.sh)|[code](./model_training/validate_full/Qwen-Image.py)|[code](./model_training/lora/Qwen-Image.sh)|[code](./model_training/validate_lora/Qwen-Image.py)| +|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](./model_inference/Qwen-Image-Edit.py)|[code](./model_inference_low_vram/Qwen-Image-Edit.py)|[code](./model_training/full/Qwen-Image-Edit.sh)|[code](./model_training/validate_full/Qwen-Image-Edit.py)|[code](./model_training/lora/Qwen-Image-Edit.sh)|[code](./model_training/validate_lora/Qwen-Image-Edit.py)| 
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./model_inference/Qwen-Image-Distill-Full.py)|[code](./model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](./model_training/full/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_lora/Qwen-Image-Distill-Full.py)| |[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](./model_inference/Qwen-Image-Distill-LoRA.py)|[code](./model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|-|-| |[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./model_inference/Qwen-Image-EliGen.py)|[code](./model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](./model_training/lora/Qwen-Image-EliGen.sh)|[code](./model_training/validate_lora/Qwen-Image-EliGen.py)| @@ -235,6 +236,8 @@ Qwen-Image 系列模型训练通过统一的 [`./model_training/train.py`](./mod * `--model_paths`: 要加载的模型路径。JSON 格式。 * `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors。用逗号分隔。 * `--tokenizer_path`: tokenizer 路径,留空将会自动下载。 + * `--edit_model`:是否使用 Qwen-Image-Edit。若为 True,则将使用该模型进行图像编辑。 + * `--processor_path`:Qwen-Image-Edit 的 processor 路径。留空则自动下载。 * 训练 * `--learning_rate`: 学习率。 * `--weight_decay`:权重衰减大小。 diff --git a/examples/qwen_image/model_inference/Qwen-Image-Edit.py b/examples/qwen_image/model_inference/Qwen-Image-Edit.py new file mode 100644 index 0000000..feabaff --- /dev/null +++ b/examples/qwen_image/model_inference/Qwen-Image-Edit.py @@ -0,0 +1,22 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + 
ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=None, + processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"), +) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt=prompt, seed=0, num_inference_steps=40, height=1024, width=1024) +image.save("image1.jpg") + +prompt = "将裙子改为粉色" +for seed in range(1, 10): + image = pipe(prompt, edit_image=image, seed=seed, num_inference_steps=40, height=1024, width=1024) + image.save(f"image2_{seed}.jpg") diff --git a/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py new file mode 100644 index 0000000..3c52157 --- /dev/null +++ b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py @@ -0,0 +1,24 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn), + ], + tokenizer_config=None, + processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"), +) +pipe.enable_vram_management() + +prompt = 
"精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt=prompt, seed=0, num_inference_steps=40, height=1024, width=1024) +image.save("image1.jpg") + +prompt = "将裙子改为粉色" +for seed in range(1, 10): + image = pipe(prompt, edit_image=image, seed=seed, num_inference_steps=40, height=1024, width=1024) + image.save(f"image2_{seed}.jpg") diff --git a/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh b/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh new file mode 100644 index 0000000..fe99f01 --- /dev/null +++ b/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh @@ -0,0 +1,14 @@ +accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \ + --edit_model \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_edit.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Qwen/Qwen-Image-Edit:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Qwen-Image-Edit_full" \ + --trainable_models "dit" \ + --use_gradient_checkpointing \ + --find_unused_parameters diff --git a/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh b/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh new file mode 100644 index 0000000..ddde13f --- /dev/null +++ b/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh @@ -0,0 +1,17 @@ +accelerate launch examples/qwen_image/model_training/train.py \ + --edit_model \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_edit.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Qwen/Qwen-Image-Edit:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Qwen-Image-Edit_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --find_unused_parameters diff --git a/examples/qwen_image/model_training/train.py b/examples/qwen_image/model_training/train.py index 5f26927..5239797 100644 --- a/examples/qwen_image/model_training/train.py +++ b/examples/qwen_image/model_training/train.py @@ -11,7 +11,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule): def __init__( self, model_paths=None, model_id_with_origin_paths=None, - tokenizer_path=None, + tokenizer_path=None, processor_path=None, edit_model=False, trainable_models=None, lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None, use_gradient_checkpointing=True, @@ -27,11 +27,15 @@ class QwenImageTrainingModule(DiffusionTrainingModule): if 
model_id_with_origin_paths is not None: model_id_with_origin_paths = model_id_with_origin_paths.split(",") model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths] - if tokenizer_path is not None: - self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=ModelConfig(tokenizer_path)) - else: - self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs) + if edit_model: + tokenizer_config = None + processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path) + else: + tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path) + processor_config = None + self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config) + # Reset training scheduler (do it in each training step) self.pipe.scheduler.set_timesteps(1000, training=True) @@ -115,6 +119,8 @@ if __name__ == "__main__": model_paths=args.model_paths, model_id_with_origin_paths=args.model_id_with_origin_paths, tokenizer_path=args.tokenizer_path, + processor_path=args.processor_path, + edit_model=args.edit_model, trainable_models=args.trainable_models, lora_base_model=args.lora_base_model, lora_target_modules=args.lora_target_modules, diff --git a/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py b/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py new file mode 100644 index 0000000..4d56ee8 --- /dev/null +++ b/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py @@ -0,0 +1,22 @@ +import torch +from PIL import Image +from diffsynth.pipelines.qwen_image 
import QwenImagePipeline, ModelConfig +from diffsynth import load_state_dict + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=None, + processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"), +) +pipe.dit.load_state_dict(load_state_dict("models/train/Qwen-Image-Edit_full/epoch-1.safetensors")) + +prompt = "将裙子改为粉色" +image = Image.open("data/example_image_dataset/edit/image1.jpg").resize((1024, 1024)) +image = pipe(prompt, edit_image=image, seed=0, num_inference_steps=40, height=1024, width=1024) +image.save("image.jpg") diff --git a/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py b/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py new file mode 100644 index 0000000..2576be3 --- /dev/null +++ b/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py @@ -0,0 +1,21 @@ +import torch +from PIL import Image +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=None, + processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"), +) +pipe.load_lora(pipe.dit,
"models/train/Qwen-Image-Edit_lora/epoch-4.safetensors") + +prompt = "将裙子改为粉色" +image = Image.open("data/example_image_dataset/edit/image1.jpg").resize((1024, 1024)) +image = pipe(prompt, edit_image=image, seed=0, num_inference_steps=40, height=1024, width=1024) +image.save("image.jpg")