Compare commits

..

4 Commits

Author SHA1 Message Date
mi804
b2d4bc8dd8 block wise controlnet 2025-08-12 13:10:47 +08:00
Artiprocher
c8ea3caf39 bugfix 2025-08-08 12:49:59 +08:00
Artiprocher
0d519ee08a bugfix 2025-08-08 12:47:04 +08:00
Artiprocher
6e13deb6de qwen-image controlnet 2025-08-08 11:29:23 +08:00
59 changed files with 387 additions and 1300 deletions

View File

@@ -87,16 +87,11 @@ image.save("image.jpg")
<summary>Model Overview</summary>
|Model ID|Inference|Low VRAM Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
|-|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./examples/qwen_image/model_inference/Qwen-Image.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|-|-|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|Model ID|Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./examples/qwen_image/model_inference/Qwen-Image.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
</details>
### FLUX Series
@@ -368,16 +363,6 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44
## Update History
- **August 18, 2025** We trained and open-sourced the Inpaint ControlNet model for Qwen-Image, [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint), which adopts a lightweight architectural design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py).
- **August 15, 2025** We open-sourced the [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset). This is an image dataset generated using the Qwen-Image model, with a total of 160,000 `1024 x 1024` images. It includes the general, English text rendering, and Chinese text rendering subsets. We provide caption, entity and control images annotations for each image. Developers can use this dataset to train models such as ControlNet and EliGen for the Qwen-Image model. We aim to promote technological development through open-source contributions!
- **August 13, 2025** We trained and open-sourced the ControlNet model for Qwen-Image, [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth), which adopts a lightweight architectural design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py).
- **August 12, 2025** We trained and open-sourced the ControlNet model for Qwen-Image, [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny), which adopts a lightweight architectural design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py).
- **August 11, 2025** We released another distilled acceleration model for Qwen-Image, [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA). It uses the same training process as [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full), but the model structure is changed to LoRA. This makes it work better with other open-source models.
- **August 7, 2025** We open-sourced the entity control LoRA of Qwen-Image, [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen). Qwen-Image-EliGen is able to achieve entity-level controlled text-to-image generation. See the [paper](https://arxiv.org/abs/2501.01097) for technical details. Training dataset: [EliGenTrainSet](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet).
- **August 5, 2025** We open-sourced the distilled acceleration model of Qwen-Image, [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full), achieving approximately 5x speedup.

View File

@@ -89,15 +89,11 @@ image.save("image.jpg")
<summary>模型总览</summary>
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./examples/qwen_image/model_inference/Qwen-Image.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|-|-|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|模型 ID|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./examples/qwen_image/model_inference/Qwen-Image.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
</details>
@@ -384,16 +380,6 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44
## 更新历史
- **2025年8月18日** 我们训练并开源了 Qwen-Image 的图像重绘 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint),模型结构采用了轻量化的设计,请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)。
- **2025年8月15日** 我们开源了 [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset) 数据集。这是一个使用 Qwen-Image 模型生成的图像数据集,共包含 160,000 张`1024 x 1024`图像。它包括通用、英文文本渲染和中文文本渲染子集。我们为每张图像提供了图像描述、实体和结构控制图像的标注。开发者可以使用这个数据集来训练 Qwen-Image 模型的 ControlNet 和 EliGen 等模型,我们旨在通过开源推动技术发展!
- **2025年8月13日** 我们训练并开源了 Qwen-Image 的 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth),模型结构采用了轻量化的设计,请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)。
- **2025年8月12日** 我们训练并开源了 Qwen-Image 的 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny),模型结构采用了轻量化的设计,请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)。
- **2025年8月11日** 我们开源了 Qwen-Image 的蒸馏加速模型 [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA),沿用了与 [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) 相同的训练流程,但模型结构修改为了 LoRA因此能够更好地与其他开源生态模型兼容。
- **2025年8月7日** 我们开源了 Qwen-Image 的实体控制 LoRA 模型 [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)。Qwen-Image-EliGen 能够实现实体级可控的文生图。技术细节请参见[论文](https://arxiv.org/abs/2501.01097)。训练数据集:[EliGenTrainSet](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)。
- **2025年8月5日** 我们开源了 Qwen-Image 的蒸馏加速模型 [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full),实现了约 5 倍加速。

View File

@@ -75,6 +75,7 @@ from ..models.nexus_gen import NexusGenAutoregressiveModel
from ..models.qwen_image_dit import QwenImageDiT
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
from ..models.qwen_image_vae import QwenImageVAE
from ..models.qwen_image_controlnet import QwenImageControlNet
from ..models.qwen_image_controlnet import QwenImageBlockWiseControlNet
model_loader_configs = [
@@ -168,8 +169,8 @@ model_loader_configs = [
(None, "0319a1cb19835fb510907dd3367c95ff", ["qwen_image_dit"], [QwenImageDiT], "civitai"),
(None, "8004730443f55db63092006dd9f7110e", ["qwen_image_text_encoder"], [QwenImageTextEncoder], "diffusers"),
(None, "ed4ea5824d55ec3107b09815e318123a", ["qwen_image_vae"], [QwenImageVAE], "diffusers"),
(None, "be2500a62936a43d5367a70ea001e25d", ["qwen_image_controlnet"], [QwenImageControlNet], "civitai"),
(None, "073bce9cf969e317e5662cd570c3e79c", ["qwen_image_blockwise_controlnet"], [QwenImageBlockWiseControlNet], "civitai"),
(None, "a9e54e480a628f0b956a688a81c33bab", ["qwen_image_blockwise_controlnet"], [QwenImageBlockWiseControlNet], "civitai"),
]
huggingface_model_loader_configs = [
# These configs are provided for detecting model type automatically.

View File

@@ -1,7 +1,98 @@
import torch
import torch.nn as nn
from .sd3_dit import RMSNorm
from .utils import hash_state_dict_keys
from .qwen_image_dit import QwenEmbedRope, QwenImageTransformerBlock
from ..vram_management import gradient_checkpoint_forward
from einops import rearrange
from .sd3_dit import TimestepEmbeddings, RMSNorm
class QwenImageControlNet(torch.nn.Module):
def __init__(
self,
num_layers: int = 60,
num_controlnet_layers: int = 6,
):
super().__init__()
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)
self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=True)
self.txt_norm = RMSNorm(3584, eps=1e-6)
self.img_in = nn.Linear(64 * 2, 3072)
self.txt_in = nn.Linear(3584, 3072)
self.transformer_blocks = nn.ModuleList(
[
QwenImageTransformerBlock(
dim=3072,
num_attention_heads=24,
attention_head_dim=128,
)
for _ in range(num_controlnet_layers)
]
)
self.proj_out = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for i in range(num_layers)])
self.num_layers = num_layers
self.num_controlnet_layers = num_controlnet_layers
self.align_map = {i: i // (num_layers // num_controlnet_layers) for i in range(num_layers)}
def forward(
self,
latents=None,
timestep=None,
prompt_emb=None,
prompt_emb_mask=None,
height=None,
width=None,
controlnet_conditioning=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
controlnet_conditioning = rearrange(controlnet_conditioning, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
image = torch.concat([image, controlnet_conditioning], dim=-1)
image = self.img_in(image)
text = self.txt_in(self.txt_norm(prompt_emb))
conditioning = self.time_text_embed(timestep, image.dtype)
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
outputs = []
for block in self.transformer_blocks:
text, image = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
image=image,
text=text,
temb=conditioning,
image_rotary_emb=image_rotary_emb,
)
outputs.append(image)
outputs_aligned = [self.proj_out[i](outputs[self.align_map[i]]) for i in range(self.num_layers)]
return outputs_aligned
@staticmethod
def state_dict_converter():
return QwenImageControlNetStateDictConverter()
class QwenImageControlNetStateDictConverter():
def __init__(self):
pass
def from_civitai(self, state_dict):
return state_dict
class BlockWiseControlBlock(torch.nn.Module):
@@ -13,7 +104,7 @@ class BlockWiseControlBlock(torch.nn.Module):
self.input_proj = nn.Linear(dim, dim)
self.act = nn.GELU()
self.output_proj = nn.Linear(dim, dim)
def forward(self, x, y):
x, y = self.x_rms(x), self.y_rms(y)
x = self.input_proj(x + y)
@@ -32,11 +123,10 @@ class QwenImageBlockWiseControlNet(torch.nn.Module):
self,
num_layers: int = 60,
in_dim: int = 64,
additional_in_dim: int = 0,
dim: int = 3072,
):
super().__init__()
self.img_in = nn.Linear(in_dim + additional_in_dim, dim)
self.img_in = nn.Linear(in_dim, dim)
self.controlnet_blocks = nn.ModuleList(
[
BlockWiseControlBlock(dim)
@@ -66,9 +156,4 @@ class QwenImageBlockWiseControlNetStateDictConverter():
pass
def from_civitai(self, state_dict):
hash_value = hash_state_dict_keys(state_dict)
extra_kwargs = {}
if hash_value == "a9e54e480a628f0b956a688a81c33bab":
# inpaint controlnet
extra_kwargs = {"additional_in_dim": 4}
return state_dict, extra_kwargs
return state_dict

View File

@@ -90,39 +90,8 @@ class QwenEmbedRope(nn.Module):
)
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
def _expand_pos_freqs_if_needed(self, video_fhw, txt_seq_lens):
if isinstance(video_fhw, list):
video_fhw = video_fhw[0]
_, height, width = video_fhw
if self.scale_rope:
max_vid_index = max(height // 2, width // 2)
else:
max_vid_index = max(height, width)
required_len = max_vid_index + max(txt_seq_lens)
cur_max_len = self.pos_freqs.shape[0]
if required_len <= cur_max_len:
return
new_max_len = math.ceil(required_len / 512) * 512
pos_index = torch.arange(new_max_len)
neg_index = torch.arange(new_max_len).flip(0) * -1 - 1
self.pos_freqs = torch.cat([
self.rope_params(pos_index, self.axes_dim[0], self.theta),
self.rope_params(pos_index, self.axes_dim[1], self.theta),
self.rope_params(pos_index, self.axes_dim[2], self.theta),
], dim=1)
self.neg_freqs = torch.cat([
self.rope_params(neg_index, self.axes_dim[0], self.theta),
self.rope_params(neg_index, self.axes_dim[1], self.theta),
self.rope_params(neg_index, self.axes_dim[2], self.theta),
], dim=1)
return
def forward(self, video_fhw, txt_seq_lens, device):
self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens)
if self.pos_freqs.device != device:
self.pos_freqs = self.pos_freqs.to(device)
self.neg_freqs = self.neg_freqs.to(device)

View File

@@ -10,7 +10,7 @@ from ..models import ModelManager, load_state_dict
from ..models.qwen_image_dit import QwenImageDiT
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
from ..models.qwen_image_vae import QwenImageVAE
from ..models.qwen_image_controlnet import QwenImageBlockWiseControlNet
from ..models.qwen_image_controlnet import QwenImageControlNet, QwenImageBlockWiseControlNet
from ..schedulers import FlowMatchScheduler
from ..utils import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
from ..lora import GeneralLoRALoader
@@ -19,30 +19,38 @@ from .flux_image_new import ControlNetInput
from ..vram_management import gradient_checkpoint_forward, enable_vram_management, AutoWrappedModule, AutoWrappedLinear
class QwenImageBlockwiseMultiControlNet(torch.nn.Module):
def __init__(self, models: list[QwenImageBlockWiseControlNet]):
class QwenImageMultiControlNet(torch.nn.Module):
def __init__(self, models: list[QwenImageControlNet]):
super().__init__()
if not isinstance(models, list):
models = [models]
self.models = torch.nn.ModuleList(models)
def preprocess(self, controlnet_inputs: list[ControlNetInput], conditionings: list[torch.Tensor], **kwargs):
processed_conditionings = []
for controlnet_input, conditioning in zip(controlnet_inputs, conditionings):
conditioning = rearrange(conditioning, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
model_output = self.models[controlnet_input.controlnet_id].process_controlnet_conditioning(conditioning)
processed_conditionings.append(model_output)
return processed_conditionings
def process_single_controlnet(self, controlnet_input: ControlNetInput, conditioning: torch.Tensor, **kwargs):
model = self.models[controlnet_input.controlnet_id]
res_stack = model(
controlnet_conditioning=conditioning,
processor_id=controlnet_input.processor_id,
**kwargs
)
res_stack = [res * controlnet_input.scale for res in res_stack]
return res_stack
def blockwise_forward(self, image, conditionings: list[torch.Tensor], controlnet_inputs: list[ControlNetInput], progress_id, num_inference_steps, block_id, **kwargs):
res = 0
def forward(self, conditionings: list[torch.Tensor], controlnet_inputs: list[ControlNetInput], progress_id, num_inference_steps, **kwargs):
res_stack = None
for controlnet_input, conditioning in zip(controlnet_inputs, conditionings):
progress = (num_inference_steps - 1 - progress_id) / max(num_inference_steps - 1, 1)
if progress > controlnet_input.start + (1e-4) or progress < controlnet_input.end - (1e-4):
continue
model_output = self.models[controlnet_input.controlnet_id].blockwise_forward(image, conditioning, block_id)
res = res + model_output * controlnet_input.scale
return res
res_stack_ = self.process_single_controlnet(controlnet_input, conditioning, **kwargs)
if res_stack is None:
res_stack = res_stack_
else:
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
return res_stack
class QwenImagePipeline(BasePipeline):
@@ -58,18 +66,17 @@ class QwenImagePipeline(BasePipeline):
self.text_encoder: QwenImageTextEncoder = None
self.dit: QwenImageDiT = None
self.vae: QwenImageVAE = None
self.blockwise_controlnet: QwenImageBlockwiseMultiControlNet = None
self.controlnet: QwenImageMultiControlNet = None
self.tokenizer: Qwen2Tokenizer = None
self.unit_runner = PipelineUnitRunner()
self.in_iteration_models = ("dit", "blockwise_controlnet")
self.in_iteration_models = ("dit", "controlnet", "blockwise_controlnet")
self.units = [
QwenImageUnit_ShapeChecker(),
QwenImageUnit_NoiseInitializer(),
QwenImageUnit_InputImageEmbedder(),
QwenImageUnit_Inpaint(),
QwenImageUnit_PromptEmbedder(),
QwenImageUnit_EntityControl(),
QwenImageUnit_BlockwiseControlNet(),
QwenImageUnit_ControlNet(),
]
self.model_fn = model_fn_qwen_image
@@ -194,23 +201,6 @@ class QwenImagePipeline(BasePipeline):
),
vram_limit=vram_limit,
)
if self.blockwise_controlnet is not None:
enable_vram_management(
self.blockwise_controlnet,
module_map = {
RMSNorm: AutoWrappedModule,
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
@staticmethod
@@ -235,7 +225,8 @@ class QwenImagePipeline(BasePipeline):
pipe.text_encoder = model_manager.fetch_model("qwen_image_text_encoder")
pipe.dit = model_manager.fetch_model("qwen_image_dit")
pipe.vae = model_manager.fetch_model("qwen_image_vae")
pipe.blockwise_controlnet = QwenImageBlockwiseMultiControlNet(model_manager.fetch_model("qwen_image_blockwise_controlnet", index="all"))
pipe.controlnet = QwenImageMultiControlNet(model_manager.fetch_model("qwen_image_controlnet", index="all"))
pipe.blockwise_controlnet = model_manager.fetch_model("qwen_image_blockwise_controlnet")
if tokenizer_config is not None and pipe.text_encoder is not None:
tokenizer_config.download_if_necessary()
from transformers import Qwen2Tokenizer
@@ -253,10 +244,6 @@ class QwenImagePipeline(BasePipeline):
# Image
input_image: Image.Image = None,
denoising_strength: float = 1.0,
# Inpaint
inpaint_mask: Image.Image = None,
inpaint_blur_size: int = None,
inpaint_blur_sigma: float = None,
# Shape
height: int = 1328,
width: int = 1328,
@@ -265,8 +252,8 @@ class QwenImagePipeline(BasePipeline):
rand_device: str = "cpu",
# Steps
num_inference_steps: int = 30,
# Blockwise ControlNet
blockwise_controlnet_inputs: list[ControlNetInput] = None,
# ControlNet
controlnet_inputs: list[ControlNetInput] = None,
# EliGen
eligen_entity_prompts: list[str] = None,
eligen_entity_masks: list[Image.Image] = None,
@@ -279,7 +266,6 @@ class QwenImagePipeline(BasePipeline):
tile_stride: int = 64,
# Progress bar
progress_bar_cmd = tqdm,
extra_prompt_emb = None,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=(height // 16) * (width // 16))
@@ -294,20 +280,16 @@ class QwenImagePipeline(BasePipeline):
inputs_shared = {
"cfg_scale": cfg_scale,
"input_image": input_image, "denoising_strength": denoising_strength,
"inpaint_mask": inpaint_mask, "inpaint_blur_size": inpaint_blur_size, "inpaint_blur_sigma": inpaint_blur_sigma,
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"enable_fp8_attention": enable_fp8_attention,
"num_inference_steps": num_inference_steps,
"blockwise_controlnet_inputs": blockwise_controlnet_inputs,
"controlnet_inputs": controlnet_inputs,
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
if extra_prompt_emb is not None:
inputs_posi["prompt_emb"] = torch.concat([inputs_posi["prompt_emb"], extra_prompt_emb], dim=1)
inputs_posi["prompt_emb_mask"] = torch.ones((1, inputs_posi["prompt_emb"].shape[1]), dtype=inputs_posi["prompt_emb_mask"].dtype, device=inputs_posi["prompt_emb_mask"].device)
# Denoise
self.load_models_to_device(self.in_iteration_models)
@@ -324,7 +306,7 @@ class QwenImagePipeline(BasePipeline):
noise_pred = noise_pred_posi
# Scheduler
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"])
# Decode
self.load_models_to_device(['vae'])
@@ -373,26 +355,7 @@ class QwenImageUnit_InputImageEmbedder(PipelineUnit):
return {"latents": noise, "input_latents": input_latents}
else:
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
return {"latents": latents, "input_latents": input_latents}
class QwenImageUnit_Inpaint(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("inpaint_mask", "height", "width", "inpaint_blur_size", "inpaint_blur_sigma"),
)
def process(self, pipe: QwenImagePipeline, inpaint_mask, height, width, inpaint_blur_size, inpaint_blur_sigma):
if inpaint_mask is None:
return {}
inpaint_mask = pipe.preprocess_image(inpaint_mask.convert("RGB").resize((width // 8, height // 8)), min_value=0, max_value=1)
inpaint_mask = inpaint_mask.mean(dim=1, keepdim=True)
if inpaint_blur_size is not None and inpaint_blur_sigma is not None:
from torchvision.transforms import GaussianBlur
blur = GaussianBlur(kernel_size=inpaint_blur_size * 2 + 1, sigma=inpaint_blur_sigma)
inpaint_mask = blur(inpaint_mask)
return {"inpaint_mask": inpaint_mask}
return {"latents": latents, "input_latents": None}
@@ -418,9 +381,7 @@ class QwenImageUnit_PromptEmbedder(PipelineUnit):
template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
drop_idx = 34
txt = [template.format(e) for e in prompt]
txt_tokens = pipe.tokenizer(txt, max_length=4096+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device)
if txt_tokens.input_ids.shape[1] >= 1024:
print(f"Warning!!! QwenImage model was trained on prompts up to 512 tokens. Current prompt requires {txt_tokens['input_ids'].shape[1] - drop_idx} tokens, which may lead to unpredictable behavior.")
txt_tokens = pipe.tokenizer(txt, max_length=1024+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device)
hidden_states = pipe.text_encoder(input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, output_hidden_states=True,)[-1]
split_hidden_states = self.extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
@@ -439,7 +400,7 @@ class QwenImageUnit_EntityControl(PipelineUnit):
def __init__(self):
super().__init__(
take_over=True,
onload_model_names=("text_encoder",)
onload_model_names=("text_encoder")
)
def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
@@ -514,11 +475,10 @@ class QwenImageUnit_EntityControl(PipelineUnit):
return inputs_shared, inputs_posi, inputs_nega
class QwenImageUnit_BlockwiseControlNet(PipelineUnit):
class QwenImageUnit_ControlNet(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("blockwise_controlnet_inputs", "tiled", "tile_size", "tile_stride"),
input_params=("controlnet_inputs", "tiled", "tile_size", "tile_stride"),
onload_model_names=("vae",)
)
@@ -537,12 +497,13 @@ class QwenImageUnit_BlockwiseControlNet(PipelineUnit):
image = Image.fromarray(image)
return image
def process(self, pipe: QwenImagePipeline, blockwise_controlnet_inputs: list[ControlNetInput], tiled, tile_size, tile_stride):
if blockwise_controlnet_inputs is None:
def process(self, pipe: QwenImagePipeline, controlnet_inputs: list[ControlNetInput], tiled, tile_size, tile_stride):
if controlnet_inputs is None:
return {}
return_key = "blockwise_controlnet_conditioning" if pipe.blockwise_controlnet is not None else "controlnet_conditionings"
pipe.load_models_to_device(self.onload_model_names)
conditionings = []
for controlnet_input in blockwise_controlnet_inputs:
for controlnet_input in controlnet_inputs:
image = controlnet_input.image
if controlnet_input.inpaint_mask is not None:
image = self.apply_controlnet_mask_on_image(pipe, image, controlnet_input.inpaint_mask)
@@ -553,21 +514,22 @@ class QwenImageUnit_BlockwiseControlNet(PipelineUnit):
if controlnet_input.inpaint_mask is not None:
image = self.apply_controlnet_mask_on_latents(pipe, image, controlnet_input.inpaint_mask)
conditionings.append(image)
return {"blockwise_controlnet_conditioning": conditionings}
return {return_key: conditionings}
def model_fn_qwen_image(
dit: QwenImageDiT = None,
blockwise_controlnet: QwenImageBlockwiseMultiControlNet = None,
controlnet: QwenImageMultiControlNet = None,
blockwise_controlnet: QwenImageBlockWiseControlNet = None,
latents=None,
timestep=None,
prompt_emb=None,
prompt_emb_mask=None,
height=None,
width=None,
controlnet_inputs=None,
controlnet_conditionings=None,
blockwise_controlnet_conditioning=None,
blockwise_controlnet_inputs=None,
progress_id=0,
num_inference_steps=1,
entity_prompt_emb=None,
@@ -578,6 +540,23 @@ def model_fn_qwen_image(
use_gradient_checkpointing_offload=False,
**kwargs
):
# ControlNet
if controlnet_conditionings is not None:
controlnet_extra_kwargs = {
"latents": latents,
"timestep": timestep,
"prompt_emb": prompt_emb,
"prompt_emb_mask": prompt_emb_mask,
"height": height,
"width": width,
"use_gradient_checkpointing": use_gradient_checkpointing,
"use_gradient_checkpointing_offload": use_gradient_checkpointing_offload,
}
res_stack = controlnet(
controlnet_conditionings, controlnet_inputs, progress_id, num_inference_steps,
**controlnet_extra_kwargs
)
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
timestep = timestep / 1000
@@ -596,11 +575,14 @@ def model_fn_qwen_image(
text = dit.txt_in(dit.txt_norm(prompt_emb))
image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
attention_mask = None
if blockwise_controlnet_conditioning is not None:
blockwise_controlnet_conditioning = blockwise_controlnet.preprocess(
blockwise_controlnet_inputs, blockwise_controlnet_conditioning)
if blockwise_controlnet_conditioning is not None:
blockwise_controlnet_conditioning = rearrange(
blockwise_controlnet_conditioning[0], "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2
)
blockwise_controlnet_conditioning = blockwise_controlnet.process_controlnet_conditioning(blockwise_controlnet_conditioning)
# blockwise_controlnet_conditioning =
for block_id, block in enumerate(dit.transformer_blocks):
text, image = gradient_checkpoint_forward(
block,
@@ -613,13 +595,11 @@ def model_fn_qwen_image(
attention_mask=attention_mask,
enable_fp8_attention=enable_fp8_attention,
)
if blockwise_controlnet_conditioning is not None:
image = image + blockwise_controlnet.blockwise_forward(
image=image, conditionings=blockwise_controlnet_conditioning,
controlnet_inputs=blockwise_controlnet_inputs, block_id=block_id,
progress_id=progress_id, num_inference_steps=num_inference_steps,
)
if blockwise_controlnet is not None:
image = image + blockwise_controlnet.blockwise_forward(image, blockwise_controlnet_conditioning, block_id)
if controlnet_conditionings is not None:
image = image + res_stack[block_id]
image = dit.norm_out(image, conditioning)
image = dit.proj_out(image)

View File

@@ -1021,10 +1021,6 @@ def model_fn_wan_video(
torch.ones((latents.shape[2] - 1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device) * timestep
]).flatten()
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0))
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
t_chunks = torch.chunk(t, get_sequence_parallel_world_size(), dim=1)
t_chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, t_chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in t_chunks]
t = t_chunks[get_sequence_parallel_rank()]
t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim))
else:
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))

View File

@@ -344,17 +344,8 @@ class DiffusionTrainingModule(torch.nn.Module):
lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules)
model = inject_adapter_in_model(lora_config, model)
return model
def mapping_lora_state_dict(self, state_dict):
new_state_dict = {}
for key, value in state_dict.items():
if "lora_A.weight" in key or "lora_B.weight" in key:
new_key = key.replace("lora_A.weight", "lora_A.default.weight").replace("lora_B.weight", "lora_B.default.weight")
new_state_dict[new_key] = value
return new_state_dict
def export_trainable_state_dict(self, state_dict, remove_prefix=None):
trainable_param_names = self.trainable_param_names()
state_dict = {name: param for name, param in state_dict.items() if name in trainable_param_names}
@@ -476,7 +467,6 @@ def wan_parser():
parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
parser.add_argument("--lora_checkpoint", type=str, default=None, help="Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.")
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
@@ -485,7 +475,6 @@ def wan_parser():
parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.")
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
return parser
@@ -509,7 +498,6 @@ def flux_parser():
parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
parser.add_argument("--lora_checkpoint", type=str, default=None, help="Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.")
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
parser.add_argument("--align_to_opensource_format", default=False, action="store_true", help="Whether to align the lora format to opensource format. Only for DiT's LoRA.")
parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
@@ -518,7 +506,6 @@ def flux_parser():
parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.")
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
return parser
@@ -543,13 +530,12 @@ def qwen_image_parser():
parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
parser.add_argument("--lora_checkpoint", type=str, default=None, help="Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.")
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
parser.add_argument("--align_to_opensource_format", default=False, action="store_true", help="Whether to align the lora format to opensource format. Only for DiT's LoRA.")
parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.")
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
return parser

View File

@@ -139,20 +139,6 @@ class BasePipeline(torch.nn.Module):
else:
model.eval()
model.requires_grad_(False)
def blend_with_mask(self, base, addition, mask):
return base * (1 - mask) + addition * mask
def step(self, scheduler, latents, progress_id, noise_pred, input_latents=None, inpaint_mask=None, **kwargs):
timestep = scheduler.timesteps[progress_id]
if inpaint_mask is not None:
noise_pred_expected = scheduler.return_to_timestep(scheduler.timesteps[progress_id], latents, input_latents)
noise_pred = self.blend_with_mask(noise_pred_expected, noise_pred, inpaint_mask)
latents_next = scheduler.step(noise_pred, timestep, latents)
return latents_next
@dataclass

View File

@@ -255,7 +255,6 @@ The script includes the following parameters:
* `--model_id_with_origin_paths`: Model ID with original paths, e.g., black-forest-labs/FLUX.1-dev:flux1-dev.safetensors. Separate with commas.
* Training
* `--learning_rate`: Learning rate.
* `--weight_decay`: Weight decay.
* `--num_epochs`: Number of epochs.
* `--output_path`: Save path.
* `--remove_prefix_in_ckpt`: Remove prefix in checkpoint.
@@ -266,7 +265,6 @@ The script includes the following parameters:
* `--lora_base_model`: Which model to add LoRA to.
* `--lora_target_modules`: Which layers to add LoRA to.
* `--lora_rank`: Rank of LoRA.
* `--lora_checkpoint`: Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.
* Extra Model Inputs
* `--extra_inputs`: Extra model inputs, separated by commas.
* VRAM Management

View File

@@ -255,7 +255,6 @@ FLUX 系列模型训练通过统一的 [`./model_training/train.py`](./model_tra
* `--model_id_with_origin_paths`: 带原始路径的模型 ID例如 black-forest-labs/FLUX.1-dev:flux1-dev.safetensors。用逗号分隔。
* 训练
* `--learning_rate`: 学习率。
* `--weight_decay`:权重衰减大小。
* `--num_epochs`: 轮数Epoch
* `--output_path`: 保存路径。
* `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。
@@ -266,7 +265,6 @@ FLUX 系列模型训练通过统一的 [`./model_training/train.py`](./model_tra
* `--lora_base_model`: LoRA 添加到哪个模型上。
* `--lora_target_modules`: LoRA 添加到哪一层上。
* `--lora_rank`: LoRA 的秩Rank
* `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径LoRA 将从此检查点加载。
* 额外模型输入
* `--extra_inputs`: 额外的模型输入,以逗号分隔。
* 显存管理

View File

@@ -1,5 +1,4 @@
import torch, os, json
from diffsynth import load_state_dict
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, flux_parser
from diffsynth.models.lora import FluxLoRAConverter
@@ -12,7 +11,7 @@ class FluxTrainingModule(DiffusionTrainingModule):
self,
model_paths=None, model_id_with_origin_paths=None,
trainable_models=None,
lora_base_model=None, lora_target_modules="a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp", lora_rank=32, lora_checkpoint=None,
lora_base_model=None, lora_target_modules="a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp", lora_rank=32,
use_gradient_checkpointing=True,
use_gradient_checkpointing_offload=False,
extra_inputs=None,
@@ -41,12 +40,6 @@ class FluxTrainingModule(DiffusionTrainingModule):
target_modules=lora_target_modules.split(","),
lora_rank=lora_rank
)
if lora_checkpoint is not None:
state_dict = load_state_dict(lora_checkpoint)
state_dict = self.mapping_lora_state_dict(state_dict)
load_result = model.load_state_dict(state_dict, strict=False)
if len(load_result[1]) > 0:
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
setattr(self.pipe, lora_base_model, model)
# Store other configs
@@ -113,7 +106,6 @@ if __name__ == "__main__":
lora_base_model=args.lora_base_model,
lora_target_modules=args.lora_target_modules,
lora_rank=args.lora_rank,
lora_checkpoint=args.lora_checkpoint,
use_gradient_checkpointing=args.use_gradient_checkpointing,
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
extra_inputs=args.extra_inputs,
@@ -123,7 +115,7 @@ if __name__ == "__main__":
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
state_dict_converter=FluxLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x:x,
)
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay)
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
launch_training_task(
dataset, model, model_logger, optimizer, scheduler,

View File

@@ -40,15 +40,11 @@ image.save("image.jpg")
## Model Overview
|Model ID|Inference|Low VRAM Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
|-|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./model_inference/Qwen-Image.py)|[code](./model_inference_low_vram/Qwen-Image.py)|[code](./model_training/full/Qwen-Image.sh)|[code](./model_training/validate_full/Qwen-Image.py)|[code](./model_training/lora/Qwen-Image.sh)|[code](./model_training/validate_lora/Qwen-Image.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./model_inference/Qwen-Image-Distill-Full.py)|[code](./model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](./model_training/full/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](./model_inference/Qwen-Image-Distill-LoRA.py)|[code](./model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|-|-|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./model_inference/Qwen-Image-EliGen.py)|[code](./model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](./model_training/lora/Qwen-Image-EliGen.sh)|[code](./model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](./model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](./model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](./model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|Model ID|Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image )|[code](./model_inference/Qwen-Image.py)|[code](./model_training/full/Qwen-Image.sh)|[code](./model_training/validate_full/Qwen-Image.py)|[code](./model_training/lora/Qwen-Image.sh)|[code](./model_training/validate_lora/Qwen-Image.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./model_inference/Qwen-Image-Distill-Full.py)|[code](./model_training/full/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./model_inference/Qwen-Image-EliGen.py)|-|-|[code](./model_training/lora/Qwen-Image-EliGen.sh)|[code](./model_training/validate_lora/Qwen-Image-EliGen.py)|
## Model Inference
@@ -178,13 +174,10 @@ After enabling VRAM management, the framework will automatically choose a memory
<summary>Inference Acceleration</summary>
* FP8 Quantization: Choose the appropriate quantization method based on your hardware and requirements.
* GPUs that do not support FP8 computation (e.g., A100, 4090, etc.): FP8 quantization will only reduce VRAM usage without speeding up inference. Code: [./model_inference_low_vram/Qwen-Image.py](./model_inference_low_vram/Qwen-Image.py)
* GPUs that do not support FP8 computation (e.g., A100, 4090, etc.): FP8 quantization will only reduce VRAM usage without speeding up inference. Code: [./model_inference_lor_vram/Qwen-Image.py](./model_inference_lor_vram/Qwen-Image.py)
* GPUs that support FP8 operations (e.g., H200, etc.): Please install [Flash Attention 3](https://github.com/Dao-AILab/flash-attention). Otherwise, FP8 acceleration will only apply to Linear layers.
* Faster inference but higher VRAM usage: Use [./accelerate/Qwen-Image-FP8.py](./accelerate/Qwen-Image-FP8.py)
* Slightly slower inference but lower VRAM usage: Use [./accelerate/Qwen-Image-FP8-offload.py](./accelerate/Qwen-Image-FP8-offload.py)
* Distillation acceleration: We trained two distillation models for fast inference at `cfg_scale=1` and `num_inference_steps=15`.
* [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full): Full distillation version. Better image quality but lower LoRA compatibility. Use [./model_inference/Qwen-Image-Distill-Full.py](./model_inference/Qwen-Image-Distill-Full.py).
* [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA): LoRA distillation version. Slightly lower image quality but better LoRA compatibility. Use [./model_inference/Qwen-Image-Distill-LoRA.py](./model_inference/Qwen-Image-Distill-LoRA.py).
</details>
@@ -238,7 +231,6 @@ The script includes the following parameters:
* `--tokenizer_path`: Tokenizer path. Leave empty to auto-download.
* Training
* `--learning_rate`: Learning rate.
* `--weight_decay`: Weight decay.
* `--num_epochs`: Number of epochs.
* `--output_path`: Save path.
* `--remove_prefix_in_ckpt`: Remove prefix in checkpoint.
@@ -249,13 +241,14 @@ The script includes the following parameters:
* `--lora_base_model`: Which model to add LoRA to.
* `--lora_target_modules`: Which layers to add LoRA to.
* `--lora_rank`: Rank of LoRA.
* `--lora_checkpoint`: Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.
* Extra Model Inputs
* `--extra_inputs`: Extra model inputs, separated by commas.
* VRAM Management
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
* Others
* `--align_to_opensource_format`: Whether to align DiT LoRA format with open-source version. Only works for LoRA training.
In addition, the training framework is built on [`accelerate`](https://huggingface.co/docs/accelerate/index). Run `accelerate config` before training to set GPU-related settings. For some training tasks (e.g., full training of 20B model), we provide suggested `accelerate` config files. Check the corresponding training script for details.

View File

@@ -40,15 +40,11 @@ image.save("image.jpg")
## 模型总览
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./model_inference/Qwen-Image.py)|[code](./model_inference_low_vram/Qwen-Image.py)|[code](./model_training/full/Qwen-Image.sh)|[code](./model_training/validate_full/Qwen-Image.py)|[code](./model_training/lora/Qwen-Image.sh)|[code](./model_training/validate_lora/Qwen-Image.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./model_inference/Qwen-Image-Distill-Full.py)|[code](./model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](./model_training/full/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](./model_inference/Qwen-Image-Distill-LoRA.py)|[code](./model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|-|-|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./model_inference/Qwen-Image-EliGen.py)|[code](./model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](./model_training/lora/Qwen-Image-EliGen.sh)|[code](./model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](./model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](./model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](./model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|模型 ID|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./model_inference/Qwen-Image.py)|[code](./model_training/full/Qwen-Image.sh)|[code](./model_training/validate_full/Qwen-Image.py)|[code](./model_training/lora/Qwen-Image.sh)|[code](./model_training/validate_lora/Qwen-Image.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./model_inference/Qwen-Image-Distill-Full.py)|[code](./model_training/full/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./model_inference/Qwen-Image-EliGen.py)|-|-|[code](./model_training/lora/Qwen-Image-EliGen.sh)|[code](./model_training/validate_lora/Qwen-Image-EliGen.py)|
## 模型推理
@@ -178,13 +174,10 @@ FP8 量化能够大幅度减少显存占用,但不会加速,部分模型在
<summary>推理加速</summary>
* FP8 量化:根据您的硬件与需求,请选择合适的量化方式
* GPU 不支持 FP8 计算(例如 A100、4090 等FP8 量化仅能降低显存占用,无法加速,代码:[./model_inference_low_vram/Qwen-Image.py](./model_inference_low_vram/Qwen-Image.py)
* GPU 不支持 FP8 计算(例如 A100、4090 等FP8 量化仅能降低显存占用,无法加速,代码:[./model_inference_lor_vram/Qwen-Image.py](./model_inference_lor_vram/Qwen-Image.py)
* GPU 支持 FP8 运算(例如 H200 等):请安装 [Flash Attention 3](https://github.com/Dao-AILab/flash-attention),否则 FP8 加速仅对 Linear 层生效
* 更快的速度,但更大的显存:请使用 [./accelerate/Qwen-Image-FP8.py](./accelerate/Qwen-Image-FP8.py)
* 稍慢的速度,但更小的显存:请使用 [./accelerate/Qwen-Image-FP8-offload.py](./accelerate/Qwen-Image-FP8-offload.py)
* 蒸馏加速:我们训练了两个蒸馏加速模型,可以在 `cfg_scale=1``num_inference_steps=15` 设置下进行快速推理
* [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full):全量蒸馏训练版本,更好的生成效果,稍差的 LoRA 兼容性,请使用 [./model_inference/Qwen-Image-Distill-Full.py](./model_inference/Qwen-Image-Distill-Full.py)
* [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)LoRA 蒸馏训练版本,稍差的生成效果,更好的 LoRA 兼容性,请使用 [./model_inference/Qwen-Image-Distill-LoRA.py](./model_inference/Qwen-Image-Distill-LoRA.py)
</details>
@@ -238,7 +231,6 @@ Qwen-Image 系列模型训练通过统一的 [`./model_training/train.py`](./mod
* `--tokenizer_path`: tokenizer 路径,留空将会自动下载。
* 训练
* `--learning_rate`: 学习率。
* `--weight_decay`:权重衰减大小。
* `--num_epochs`: 轮数Epoch
* `--output_path`: 保存路径。
* `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。
@@ -249,13 +241,14 @@ Qwen-Image 系列模型训练通过统一的 [`./model_training/train.py`](./mod
* `--lora_base_model`: LoRA 添加到哪个模型上。
* `--lora_target_modules`: LoRA 添加到哪一层上。
* `--lora_rank`: LoRA 的秩Rank
* `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径LoRA 将从此检查点加载。
* 额外模型输入
* `--extra_inputs`: 额外的模型输入,以逗号分隔。
* 显存管理
* `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。
* `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。
* `--gradient_accumulation_steps`: 梯度累积步数。
* 其他
* `--align_to_opensource_format`: 是否将 DiT LoRA 的格式与开源版本对齐,仅对 LoRA 训练生效。
此外,训练框架基于 [`accelerate`](https://huggingface.co/docs/accelerate/index) 构建,在开始训练前运行 `accelerate config` 可配置 GPU 的相关参数。对于部分模型训练(例如 20B 模型的全量训练)脚本,我们提供了建议的 `accelerate` 配置文件,可在对应的训练脚本中查看。

View File

@@ -1,31 +0,0 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
from PIL import Image
import torch
from modelscope import dataset_snapshot_download
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny", origin_file_pattern="model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/example_image_dataset",
local_dir="./data/example_image_dataset",
allow_file_pattern="canny/image_1.jpg"
)
controlnet_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1328, 1328))
prompt = "一只小狗,毛发光洁柔顺,眼神灵动,背景是樱花纷飞的春日庭院,唯美温馨。"
image = pipe(
prompt, seed=0,
blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)]
)
image.save("image.jpg")

View File

@@ -1,32 +0,0 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
from PIL import Image
import torch
from modelscope import dataset_snapshot_download
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth", origin_file_pattern="model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/example_image_dataset",
local_dir="./data/example_image_dataset",
allow_file_pattern="depth/image_1.jpg"
)
controlnet_image = Image.open("data/example_image_dataset/depth/image_1.jpg").resize((1328, 1328))
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(
prompt, seed=0,
blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)]
)
image.save("image.jpg")

View File

@@ -1,33 +0,0 @@
import torch
from PIL import Image
from modelscope import dataset_snapshot_download
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint", origin_file_pattern="model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/example_image_dataset",
local_dir="./data/example_image_dataset",
allow_file_pattern="inpaint/*.jpg"
)
prompt = "a cat with sunglasses"
controlnet_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1328, 1328))
inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1328, 1328))
image = pipe(
prompt, seed=0,
input_image=controlnet_image, inpaint_mask=inpaint_mask,
blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image, inpaint_mask=inpaint_mask)],
num_inference_steps=40,
)
image.save("image.jpg")

View File

@@ -1,20 +0,0 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from modelscope import snapshot_download
import torch
snapshot_download("DiffSynth-Studio/Qwen-Image-Distill-LoRA", local_dir="models/DiffSynth-Studio/Qwen-Image-Distill-LoRA")
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-Distill-LoRA/model.safetensors")
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(prompt, seed=0, num_inference_steps=15, cfg_scale=1)
image.save("image.jpg")

View File

@@ -1,32 +0,0 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
from PIL import Image
import torch
from modelscope import dataset_snapshot_download
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny", origin_file_pattern="model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
pipe.enable_vram_management()
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/example_image_dataset",
local_dir="./data/example_image_dataset",
allow_file_pattern="canny/image_1.jpg"
)
controlnet_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1328, 1328))
prompt = "一只小狗,毛发光洁柔顺,眼神灵动,背景是樱花纷飞的春日庭院,唯美温馨。"
image = pipe(
prompt, seed=0,
blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)]
)
image.save("image.jpg")

View File

@@ -1,33 +0,0 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
from PIL import Image
import torch
from modelscope import dataset_snapshot_download
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth", origin_file_pattern="model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
pipe.enable_vram_management()
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/example_image_dataset",
local_dir="./data/example_image_dataset",
allow_file_pattern="depth/image_1.jpg"
)
controlnet_image = Image.open("data/example_image_dataset/depth/image_1.jpg").resize((1328, 1328))
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(
prompt, seed=0,
blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)]
)
image.save("image.jpg")

View File

@@ -1,34 +0,0 @@
import torch
from PIL import Image
from modelscope import dataset_snapshot_download
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint", origin_file_pattern="model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
pipe.enable_vram_management()
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/example_image_dataset",
local_dir="./data/example_image_dataset",
allow_file_pattern="inpaint/*.jpg"
)
prompt = "a cat with sunglasses"
controlnet_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1328, 1328))
inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1328, 1328))
image = pipe(
prompt, seed=0,
input_image=controlnet_image, inpaint_mask=inpaint_mask,
blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image, inpaint_mask=inpaint_mask)],
num_inference_steps=40,
)
image.save("image.jpg")

View File

@@ -1,22 +0,0 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from modelscope import snapshot_download
import torch
# Please do not use float8 on this model
snapshot_download("DiffSynth-Studio/Qwen-Image-Distill-LoRA", local_dir="models/DiffSynth-Studio/Qwen-Image-Distill-LoRA")
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
pipe.enable_vram_management()
pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-Distill-LoRA/model.safetensors")
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(prompt, seed=0, num_inference_steps=15, cfg_scale=1)
image.save("image.jpg")

View File

@@ -1,129 +0,0 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch
from PIL import Image, ImageDraw, ImageFont
from modelscope import dataset_snapshot_download, snapshot_download
import random
def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False):
# Create a blank image for overlays
overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
colors = [
(165, 238, 173, 80),
(76, 102, 221, 80),
(221, 160, 77, 80),
(204, 93, 71, 80),
(145, 187, 149, 80),
(134, 141, 172, 80),
(157, 137, 109, 80),
(153, 104, 95, 80),
(165, 238, 173, 80),
(76, 102, 221, 80),
(221, 160, 77, 80),
(204, 93, 71, 80),
(145, 187, 149, 80),
(134, 141, 172, 80),
(157, 137, 109, 80),
(153, 104, 95, 80),
]
# Generate random colors for each mask
if use_random_colors:
colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
# Font settings
try:
font = ImageFont.truetype("wqy-zenhei.ttc", font_size) # Adjust as needed
except IOError:
font = ImageFont.load_default(font_size)
# Overlay each mask onto the overlay image
for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
# Convert mask to RGBA mode
mask_rgba = mask.convert('RGBA')
mask_data = mask_rgba.getdata()
new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
mask_rgba.putdata(new_data)
# Draw the mask prompt text on the mask
draw = ImageDraw.Draw(mask_rgba)
mask_bbox = mask.getbbox() # Get the bounding box of the mask
text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
# Alpha composite the overlay with this mask
overlay = Image.alpha_composite(overlay, mask_rgba)
# Composite the overlay onto the original image
result = Image.alpha_composite(image.convert('RGBA'), overlay)
# Save or display the resulting image
result.save(output_path)
return result
def example(pipe, seeds, example_id, global_prompt, entity_prompts):
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/qwen-image/example_{example_id}/*.png")
masks = [Image.open(f"./data/examples/eligen/qwen-image/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
negative_prompt = ""
for seed in seeds:
# generate image
image = pipe(
prompt=global_prompt,
cfg_scale=4.0,
negative_prompt=negative_prompt,
num_inference_steps=30,
seed=seed,
height=1024,
width=1024,
eligen_entity_prompts=entity_prompts,
eligen_entity_masks=masks,
)
image.save(f"eligen_example_{example_id}_{seed}.png")
visualize_masks(image, masks, entity_prompts, f"eligen_example_{example_id}_mask_{seed}.png")
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
pipe.enable_vram_management()
snapshot_download("DiffSynth-Studio/Qwen-Image-EliGen", local_dir="models/DiffSynth-Studio/Qwen-Image-EliGen", allow_file_pattern="model.safetensors")
pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-EliGen/model.safetensors")
# example 1
global_prompt = "A breathtaking beauty of Raja Ampat by the late-night moonlight , one beautiful woman from behind wearing a pale blue long dress with soft glow, sitting at the top of a cliff looking towards the beach,pastell light colors, a group of small distant birds flying in far sky, a boat sailing on the sea, best quality, realistic, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, maximalist style, photorealistic, concept art, sharp focus, harmony, serenity, tranquility, soft pastell colors,ambient occlusion, cozy ambient lighting, masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning, view from above\n"
entity_prompts = ["cliff", "sea", "moon", "sailing boat", "a seated beautiful woman", "pale blue long dress with soft glow"]
example(pipe, [0], 1, global_prompt, entity_prompts)
# example 2
global_prompt = "samurai girl wearing a kimono, she's holding a sword glowing with red flame, her long hair is flowing in the wind, she is looking at a small bird perched on the back of her hand. ultra realist style. maximum image detail. maximum realistic render."
entity_prompts = ["flowing hair", "sword glowing with red flame", "A cute bird", "yellow belt"]
example(pipe, [0], 2, global_prompt, entity_prompts)
# example 3
global_prompt = "Image of a neverending staircase up to a mysterious palace in the sky, The ancient palace stood majestically atop a mist-shrouded mountain, sunrise, two traditional monk walk in the stair looking at the sunrise, fog,see-through, best quality, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, photorealistic, concept art, harmony, serenity, tranquility, ambient occlusion, halation, cozy ambient lighting, dynamic lighting,masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning,"
entity_prompts = ["ancient palace", "stone staircase with railings", "a traditional monk", "a traditional monk"]
example(pipe, [27], 3, global_prompt, entity_prompts)
# example 4
global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'"
entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"]
example(pipe, [21], 4, global_prompt, entity_prompts)
# example 5
global_prompt = "A captivating, dramatic scene in a painting that exudes mystery and foreboding. A white sky, swirling blue clouds, and a crescent yellow moon illuminate a solitary woman standing near the water's edge. Her long dress flows in the wind, silhouetted against the eerie glow. The water mirrors the fiery sky and moonlight, amplifying the uneasy atmosphere."
entity_prompts = ["crescent yellow moon", "a solitary woman", "water", "swirling blue clouds"]
example(pipe, [0], 5, global_prompt, entity_prompts)
# example 7, same prompt with different seeds
seeds = range(5, 9)
global_prompt = "A beautiful asia woman wearing white dress, holding a mirror, with a forest background."
entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"]
example(pipe, seeds, 7, global_prompt, entity_prompts)

View File

@@ -0,0 +1,36 @@
accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config.yaml examples/qwen_image/model_training/train.py \
--dataset_base_path "" \
--dataset_metadata_path data/t2i_dataset_annotations/blip3o/blip3o_control_images_train_for_diffsynth.jsonl \
--data_file_keys "image,controlnet_image" \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_paths '[
[
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00001-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00002-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00003-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00004-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00005-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00006-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00007-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00008-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00009-of-00009.safetensors"
],
[
"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors"
],
"models/Qwen/Qwen-Image/vae/diffusion_pytorch_model.safetensors",
"models/DiffSynth-Studio/BlockWiseControlnet/model_init.safetensors"
]' \
--learning_rate 1e-3 \
--num_epochs 1000000 \
--remove_prefix_in_ckpt "pipe.blockwise_controlnet." \
--output_path "./models/train/Qwen-Image-BlockWiseControlNet_full_lr1e-3_wd1e-6" \
--trainable_models "blockwise_controlnet" \
--extra_inputs "controlnet_image" \
--use_gradient_checkpointing \
--dataset_num_workers 8 \
--save_steps 2000

View File

@@ -1,38 +0,0 @@
accelerate launch examples/qwen_image/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_canny.csv \
--data_file_keys "image,blockwise_controlnet_image" \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny:model.safetensors" \
--learning_rate 1e-4 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.blockwise_controlnet.models.0." \
--output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Canny_full" \
--trainable_models "blockwise_controlnet" \
--extra_inputs "blockwise_controlnet_image" \
--use_gradient_checkpointing \
--find_unused_parameters
# If you want to pre-train a Blockwise ControlNet from scratch,
# please run the following script to first generate the initialized model weights file,
# and then start training with a high learning rate (1e-3).
# python examples/qwen_image/model_training/scripts/Qwen-Image-Blockwise-ControlNet-Initialize.py
# accelerate launch examples/qwen_image/model_training/train.py \
# --dataset_base_path data/example_image_dataset \
# --dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_canny.csv \
# --data_file_keys "image,blockwise_controlnet_image" \
# --max_pixels 1048576 \
# --dataset_repeat 50 \
# --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
# --model_paths '["models/blockwise_controlnet.safetensors"]' \
# --learning_rate 1e-3 \
# --num_epochs 2 \
# --remove_prefix_in_ckpt "pipe.blockwise_controlnet.models.0." \
# --output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Canny_full" \
# --trainable_models "blockwise_controlnet" \
# --extra_inputs "blockwise_controlnet_image" \
# --use_gradient_checkpointing \
# --find_unused_parameters

View File

@@ -1,38 +0,0 @@
accelerate launch examples/qwen_image/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_depth.csv \
--data_file_keys "image,blockwise_controlnet_image" \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth:model.safetensors" \
--learning_rate 1e-4 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.blockwise_controlnet.models.0." \
--output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Depth_full" \
--trainable_models "blockwise_controlnet" \
--extra_inputs "blockwise_controlnet_image" \
--use_gradient_checkpointing \
--find_unused_parameters
# If you want to pre-train a Blockwise ControlNet from scratch,
# please run the following script to first generate the initialized model weights file,
# and then start training with a high learning rate (1e-3).
# python examples/qwen_image/model_training/scripts/Qwen-Image-Blockwise-ControlNet-Initialize.py
# accelerate launch examples/qwen_image/model_training/train.py \
# --dataset_base_path data/example_image_dataset \
# --dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_depth.csv \
# --data_file_keys "image,blockwise_controlnet_image" \
# --max_pixels 1048576 \
# --dataset_repeat 50 \
# --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
# --model_paths '["models/blockwise_controlnet.safetensors"]' \
# --learning_rate 1e-3 \
# --num_epochs 2 \
# --remove_prefix_in_ckpt "pipe.blockwise_controlnet.models.0." \
# --output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Depth_full" \
# --trainable_models "blockwise_controlnet" \
# --extra_inputs "blockwise_controlnet_image" \
# --use_gradient_checkpointing \
# --find_unused_parameters

View File

@@ -1,38 +0,0 @@
accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config.yaml examples/qwen_image/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_inpaint.csv \
--data_file_keys "image,blockwise_controlnet_image,blockwise_controlnet_inpaint_mask" \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint:model.safetensors" \
--learning_rate 1e-4 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.blockwise_controlnet.models.0." \
--output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Inpaint_full" \
--trainable_models "blockwise_controlnet" \
--extra_inputs "blockwise_controlnet_image,blockwise_controlnet_inpaint_mask" \
--use_gradient_checkpointing \
--find_unused_parameters
# If you want to pre-train a Inpaint Blockwise ControlNet from scratch,
# please run the following script to first generate the initialized model weights file,
# and then start training with a high learning rate (1e-3).
# python examples/qwen_image/model_training/scripts/Qwen-Image-Blockwise-ControlNet-Inpaint-Initialize.py
# accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config.yaml examples/qwen_image/model_training/train.py \
# --dataset_base_path data/example_image_dataset \
# --dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_inpaint.csv \
# --data_file_keys "image,blockwise_controlnet_image,blockwise_controlnet_inpaint_mask" \
# --max_pixels 1048576 \
# --dataset_repeat 50 \
# --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
# --model_paths '["models/blockwise_controlnet_inpaint.safetensors"]' \
# --learning_rate 1e-3 \
# --num_epochs 2 \
# --remove_prefix_in_ckpt "pipe.blockwise_controlnet.models.0." \
# --output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Inpaint_full" \
# --trainable_models "blockwise_controlnet" \
# --extra_inputs "blockwise_controlnet_image,blockwise_controlnet_inpaint_mask" \
# --use_gradient_checkpointing \
# --find_unused_parameters

View File

@@ -0,0 +1,35 @@
accelerate launch examples/qwen_image/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_controlnet_upscale.csv \
--data_file_keys "image,controlnet_image" \
--max_pixels 1048576 \
--dataset_repeat 80000 \
--model_paths '[
[
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00001-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00002-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00003-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00004-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00005-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00006-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00007-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00008-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00009-of-00009.safetensors"
],
[
"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors"
],
"models/Qwen/Qwen-Image/vae/diffusion_pytorch_model.safetensors",
"models/controlnet.safetensors"
]' \
--learning_rate 1e-5 \
--num_epochs 1000000 \
--remove_prefix_in_ckpt "pipe.controlnet.models.0." \
--output_path "./models/train/Qwen-Image-ControlNet_full" \
--trainable_models "controlnet" \
--extra_inputs "controlnet_image" \
--use_gradient_checkpointing \
--save_steps 100

View File

@@ -9,5 +9,4 @@ accelerate launch --config_file examples/qwen_image/model_training/full/accelera
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Qwen-Image-Distill-Full_full" \
--trainable_models "dit" \
--use_gradient_checkpointing \
--find_unused_parameters
--use_gradient_checkpointing

View File

@@ -9,5 +9,4 @@ accelerate launch --config_file examples/qwen_image/model_training/full/accelera
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Qwen-Image_full" \
--trainable_models "dit" \
--use_gradient_checkpointing \
--find_unused_parameters
--use_gradient_checkpointing

View File

@@ -1,5 +1,5 @@
# This script is for initializing a Qwen-Image-Blockwise-ControlNet
from diffsynth import hash_state_dict_keys
# This script is for initializing a Qwen-Image-ControlNet
from diffsynth import load_state_dict, hash_state_dict_keys
from diffsynth.models.qwen_image_controlnet import QwenImageBlockWiseControlNet
import torch
from safetensors.torch import save_file
@@ -10,4 +10,4 @@ controlnet.init_weight()
state_dict_controlnet = controlnet.state_dict()
print(hash_state_dict_keys(state_dict_controlnet))
save_file(state_dict_controlnet, "models/blockwise_controlnet.safetensors")
save_file(state_dict_controlnet, "models/DiffSynth-Studio/BlockWiseControlnet/model_init.safetensors")

View File

@@ -0,0 +1,34 @@
# This script is for initializing a Qwen-Image-ControlNet
from diffsynth import load_state_dict, hash_state_dict_keys
from diffsynth.pipelines.qwen_image import QwenImageControlNet
import torch
from safetensors.torch import save_file
state_dict_dit = {}
for i in range(1, 10):
state_dict_dit.update(load_state_dict(f"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-0000{i}-of-00009.safetensors", torch_dtype=torch.bfloat16, device="cuda"))
controlnet = QwenImageControlNet().to(dtype=torch.bfloat16, device="cuda")
state_dict_controlnet = controlnet.state_dict()
state_dict_init = {}
for k in state_dict_controlnet:
if k in state_dict_dit:
if state_dict_dit[k].shape == state_dict_controlnet[k].shape:
state_dict_init[k] = state_dict_dit[k]
elif k == "img_in.weight":
state_dict_init[k] = torch.concat(
[
state_dict_dit[k],
state_dict_dit[k],
],
dim=-1
)
else:
print("Zero Initialized:", k)
state_dict_init[k] = torch.zeros_like(state_dict_controlnet[k])
controlnet.load_state_dict(state_dict_init)
print(hash_state_dict_keys(state_dict_init))
save_file(state_dict_init, "models/controlnet.safetensors")

View File

@@ -1,17 +0,0 @@
accelerate launch examples/qwen_image/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_canny.csv \
--data_file_keys "image,blockwise_controlnet_image" \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny:model.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Canny_lora" \
--lora_base_model "dit" \
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
--lora_rank 32 \
--extra_inputs "blockwise_controlnet_image" \
--use_gradient_checkpointing \
--find_unused_parameters

View File

@@ -1,17 +0,0 @@
accelerate launch examples/qwen_image/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_depth.csv \
--data_file_keys "image,blockwise_controlnet_image" \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth:model.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Depth_lora" \
--lora_base_model "dit" \
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
--lora_rank 32 \
--extra_inputs "blockwise_controlnet_image" \
--use_gradient_checkpointing \
--find_unused_parameters

View File

@@ -1,17 +0,0 @@
accelerate launch examples/qwen_image/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_inpaint.csv \
--data_file_keys "image,blockwise_controlnet_image,blockwise_controlnet_inpaint_mask" \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint:model.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Inpaint_lora" \
--lora_base_model "dit" \
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
--lora_rank 32 \
--extra_inputs "blockwise_controlnet_image,blockwise_controlnet_inpaint_mask" \
--use_gradient_checkpointing \
--find_unused_parameters

View File

@@ -11,5 +11,5 @@ accelerate launch examples/qwen_image/model_training/train.py \
--lora_base_model "dit" \
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
--lora_rank 32 \
--use_gradient_checkpointing \
--find_unused_parameters
--align_to_opensource_format \
--use_gradient_checkpointing

View File

@@ -12,6 +12,7 @@ accelerate launch examples/qwen_image/model_training/train.py \
--lora_base_model "dit" \
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
--lora_rank 32 \
--align_to_opensource_format \
--extra_inputs "eligen_entity_masks,eligen_entity_prompts" \
--use_gradient_checkpointing \
--find_unused_parameters

View File

@@ -11,6 +11,7 @@ accelerate launch examples/qwen_image/model_training/train.py \
--lora_base_model "dit" \
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
--lora_rank 32 \
--align_to_opensource_format \
--use_gradient_checkpointing \
--dataset_num_workers 8 \
--find_unused_parameters

View File

@@ -1,12 +0,0 @@
# This script is for initializing a Inpaint Qwen-Image-ControlNet
import torch
from diffsynth import hash_state_dict_keys
from diffsynth.models.qwen_image_controlnet import QwenImageBlockWiseControlNet
from safetensors.torch import save_file
controlnet = QwenImageBlockWiseControlNet(additional_in_dim=4).to(dtype=torch.bfloat16, device="cuda")
controlnet.init_weight()
state_dict_controlnet = controlnet.state_dict()
print(hash_state_dict_keys(state_dict_controlnet))
save_file(state_dict_controlnet, "models/blockwise_controlnet_inpaint.safetensors")

View File

@@ -1,8 +1,7 @@
import torch, os, json
from diffsynth import load_state_dict
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from diffsynth.pipelines.flux_image_new import ControlNetInput
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, qwen_image_parser
from diffsynth.models.lora import QwenImageLoRAConverter
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -13,7 +12,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
model_paths=None, model_id_with_origin_paths=None,
tokenizer_path=None,
trainable_models=None,
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
lora_base_model=None, lora_target_modules="", lora_rank=32,
use_gradient_checkpointing=True,
use_gradient_checkpointing_offload=False,
extra_inputs=None,
@@ -45,12 +44,6 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
target_modules=lora_target_modules.split(","),
lora_rank=lora_rank
)
if lora_checkpoint is not None:
state_dict = load_state_dict(lora_checkpoint)
state_dict = self.mapping_lora_state_dict(state_dict)
load_result = model.load_state_dict(state_dict, strict=False)
if len(load_result[1]) > 0:
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
setattr(self.pipe, lora_base_model, model)
# Store other configs
@@ -80,18 +73,15 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
}
# Extra inputs
controlnet_input, blockwise_controlnet_input = {}, {}
controlnet_input = {}
for extra_input in self.extra_inputs:
if extra_input.startswith("blockwise_controlnet_"):
blockwise_controlnet_input[extra_input.replace("blockwise_controlnet_", "")] = data[extra_input]
elif extra_input.startswith("controlnet_"):
inputs_shared[extra_input] = data[extra_input]
if extra_input.startswith("controlnet_"):
controlnet_input[extra_input.replace("controlnet_", "")] = data[extra_input]
else:
inputs_shared[extra_input] = data[extra_input]
if len(controlnet_input) > 0:
inputs_shared["controlnet_inputs"] = [ControlNetInput(**controlnet_input)]
if len(blockwise_controlnet_input) > 0:
inputs_shared["blockwise_controlnet_inputs"] = [ControlNetInput(**blockwise_controlnet_input)]
# Pipeline units will automatically process the input parameters.
for unit in self.pipe.units:
@@ -119,13 +109,16 @@ if __name__ == "__main__":
lora_base_model=args.lora_base_model,
lora_target_modules=args.lora_target_modules,
lora_rank=args.lora_rank,
lora_checkpoint=args.lora_checkpoint,
use_gradient_checkpointing=args.use_gradient_checkpointing,
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
extra_inputs=args.extra_inputs,
)
model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt)
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay)
model_logger = ModelLogger(
args.output_path,
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
state_dict_converter=QwenImageLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x:x,
)
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=0.000001)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
launch_training_task(
dataset, model, model_logger, optimizer, scheduler,

View File

@@ -0,0 +1,38 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
from diffsynth import load_state_dict
import torch
from PIL import Image
from diffsynth.controlnets.processors import Annotator
import os
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
ModelConfig(path="models/DiffSynth-Studio/BlockWiseControlnet/model_init.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
state_dict = load_state_dict("models/train/Qwen-Image-BlockWiseControlNet_full_lr1e-3_wd1e-6/step-26000.safetensors")
pipe.blockwise_controlnet.load_state_dict(state_dict)
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = Image.open("test_image.jpg").convert("RGB").resize((1024, 1024))
canny_image = Annotator("canny")(image)
canny_image.save("canny_image_test.jpg")
controlnet_input = ControlNetInput(
image=canny_image,
scale=1.0,
processor_id="canny",
)
for seed in range(100, 200):
image = pipe(prompt, seed=seed, height=1024, width=1024, controlnet_inputs=[controlnet_input], num_inference_steps=30, cfg_scale=4.0)
image.save(f"test_image_controlnet_step2k_1_{seed}.jpg")

View File

@@ -1,31 +0,0 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
from PIL import Image
import torch
from modelscope import dataset_snapshot_download
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
ModelConfig(path="models/train/Qwen-Image-Blockwise-ControlNet-Canny_full/epoch-1.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/example_image_dataset",
local_dir="./data/example_image_dataset",
allow_file_pattern="canny/image_1.jpg"
)
controlnet_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1328, 1328))
prompt = "一只小狗,毛发光洁柔顺,眼神灵动,背景是樱花纷飞的春日庭院,唯美温馨。"
image = pipe(
prompt, seed=0,
blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)]
)
image.save("image.jpg")

View File

@@ -1,31 +0,0 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
from PIL import Image
import torch
from modelscope import dataset_snapshot_download
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
ModelConfig(path="models/train/Qwen-Image-Blockwise-ControlNet-Depth_full/epoch-1.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/example_image_dataset",
local_dir="./data/example_image_dataset",
allow_file_pattern="depth/image_1.jpg"
)
controlnet_image = Image.open("data/example_image_dataset/depth/image_1.jpg").resize((1328, 1328))
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(
prompt, seed=0,
blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)]
)
image.save("image.jpg")

View File

@@ -1,32 +0,0 @@
import torch
from PIL import Image
from modelscope import dataset_snapshot_download
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
ModelConfig(path="models/train/Qwen-Image-Blockwise-ControlNet-Inpaint_full/epoch-1.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/example_image_dataset",
local_dir="./data/example_image_dataset",
allow_file_pattern="inpaint/*.jpg"
)
prompt = "a cat with sunglasses"
controlnet_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1024, 1024))
inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1024, 1024))
image = pipe(
prompt, seed=0,
blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image, inpaint_mask=inpaint_mask)],
height=1024, width=1024,
num_inference_steps=40,
)
image.save("image.jpg")

View File

@@ -1,32 +0,0 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
from PIL import Image
import torch
from modelscope import dataset_snapshot_download
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny", origin_file_pattern="model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.dit, "models/train/Qwen-Image-Blockwise-ControlNet-Canny_lora/epoch-4.safetensors")
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/example_image_dataset",
local_dir="./data/example_image_dataset",
allow_file_pattern="canny/image_1.jpg"
)
controlnet_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1328, 1328))
prompt = "一只小狗,毛发光洁柔顺,眼神灵动,背景是樱花纷飞的春日庭院,唯美温馨。"
image = pipe(
prompt, seed=0,
blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)]
)
image.save("image.jpg")

View File

@@ -1,33 +0,0 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
from PIL import Image
import torch
from modelscope import dataset_snapshot_download
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth", origin_file_pattern="model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.dit, "models/train/Qwen-Image-Blockwise-ControlNet-Depth_lora/epoch-4.safetensors")
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/example_image_dataset",
local_dir="./data/example_image_dataset",
allow_file_pattern="depth/image_1.jpg"
)
controlnet_image = Image.open("data/example_image_dataset/depth/image_1.jpg").resize((1328, 1328))
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(
prompt, seed=0,
blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)]
)
image.save("image.jpg")

View File

@@ -1,34 +0,0 @@
import torch
from PIL import Image
from modelscope import dataset_snapshot_download
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint", origin_file_pattern="model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.dit, "models/train/Qwen-Image-Blockwise-ControlNet-Inpaint_lora/epoch-4.safetensors")
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/example_image_dataset",
local_dir="./data/example_image_dataset",
allow_file_pattern="inpaint/*.jpg"
)
prompt = "a cat with sunglasses"
controlnet_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1024, 1024))
inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1024, 1024))
image = pipe(
prompt, seed=0,
blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image, inpaint_mask=inpaint_mask)],
height=1024, width=1024,
num_inference_steps=40,
)
image.save("image.jpg")

View File

@@ -288,7 +288,6 @@ The script includes the following parameters:
* `--min_timestep_boundary`: Minimum value of the timestep interval, ranging from 0 to 1. Default is 1. This needs to be manually set only when training mixed models with multiple DiTs, for example, [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B).
* Training
* `--learning_rate`: Learning rate.
* `--weight_decay`: Weight decay.
* `--num_epochs`: Number of epochs.
* `--output_path`: Output save path.
* `--remove_prefix_in_ckpt`: Remove prefix in ckpt.
@@ -299,7 +298,6 @@ The script includes the following parameters:
* `--lora_base_model`: Which model LoRA is added to.
* `--lora_target_modules`: Which layers LoRA is added to.
* `--lora_rank`: Rank of LoRA.
* `--lora_checkpoint`: Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.
* Extra Inputs
* `--extra_inputs`: Additional model inputs, comma-separated.
* VRAM Management

View File

@@ -290,7 +290,6 @@ Wan 系列模型训练通过统一的 [`./model_training/train.py`](./model_trai
* `--min_timestep_boundary`: Timestep 区间最小值,范围为 01默认为 1仅在多 DiT 的混合模型训练中需要手动设置,例如 [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)。
* 训练
* `--learning_rate`: 学习率。
* `--weight_decay`:权重衰减大小。
* `--num_epochs`: 轮数Epoch
* `--output_path`: 保存路径。
* `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。
@@ -301,7 +300,6 @@ Wan 系列模型训练通过统一的 [`./model_training/train.py`](./model_trai
* `--lora_base_model`: LoRA 添加到哪个模型上。
* `--lora_target_modules`: LoRA 添加到哪一层上。
* `--lora_rank`: LoRA 的秩Rank
* `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径LoRA 将从此检查点加载。
* 额外模型输入
* `--extra_inputs`: 额外的模型输入,以逗号分隔。
* 显存管理

View File

@@ -28,6 +28,5 @@ video = pipe(
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
seed=0, tiled=True,
input_image=input_image,
switch_DiT_boundary=0.9,
)
save_video(video, "video1.mp4", fps=15, quality=5)

View File

@@ -13,9 +13,8 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
--trainable_models "dit" \
--extra_inputs "input_image" \
--use_gradient_checkpointing_offload \
--max_timestep_boundary 0.358 \
--min_timestep_boundary 0
# boundary corresponds to timesteps [900, 1000]
--max_timestep_boundary 1 \
--min_timestep_boundary 0.875
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
--dataset_base_path data/example_video_dataset \
@@ -32,6 +31,5 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
--trainable_models "dit" \
--extra_inputs "input_image" \
--use_gradient_checkpointing_offload \
--max_timestep_boundary 1 \
--min_timestep_boundary 0.358
# boundary corresponds to timesteps [0, 900)
--max_timestep_boundary 0.875 \
--min_timestep_boundary 0

View File

@@ -11,9 +11,8 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Wan2.2-T2V-A14B_high_noise_full" \
--trainable_models "dit" \
--max_timestep_boundary 0.417 \
--min_timestep_boundary 0
# boundary corresponds to timesteps [875, 1000]
--max_timestep_boundary 1 \
--min_timestep_boundary 0.875
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
--dataset_base_path data/example_video_dataset \
@@ -28,6 +27,5 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Wan2.2-T2V-A14B_low_noise_full" \
--trainable_models "dit" \
--max_timestep_boundary 1 \
--min_timestep_boundary 0.417
# boundary corresponds to timesteps [0, 875)
--max_timestep_boundary 0.875 \
--min_timestep_boundary 0

View File

@@ -14,9 +14,8 @@ accelerate launch examples/wanvideo/model_training/train.py \
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
--lora_rank 32 \
--extra_inputs "input_image" \
--max_timestep_boundary 0.358 \
--min_timestep_boundary 0
# boundary corresponds to timesteps [900, 1000]
--max_timestep_boundary 1 \
--min_timestep_boundary 0.875
accelerate launch examples/wanvideo/model_training/train.py \
--dataset_base_path data/example_video_dataset \
@@ -34,6 +33,5 @@ accelerate launch examples/wanvideo/model_training/train.py \
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
--lora_rank 32 \
--extra_inputs "input_image" \
--max_timestep_boundary 1 \
--min_timestep_boundary 0.358
# boundary corresponds to timesteps [0, 900)
--max_timestep_boundary 0.875 \
--min_timestep_boundary 0

View File

@@ -13,9 +13,8 @@ accelerate launch examples/wanvideo/model_training/train.py \
--lora_base_model "dit" \
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
--lora_rank 32 \
--max_timestep_boundary 0.417 \
--min_timestep_boundary 0
# boundary corresponds to timesteps [875, 1000]
--max_timestep_boundary 1 \
--min_timestep_boundary 0.875
accelerate launch examples/wanvideo/model_training/train.py \
@@ -33,6 +32,5 @@ accelerate launch examples/wanvideo/model_training/train.py \
--lora_base_model "dit" \
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
--lora_rank 32 \
--max_timestep_boundary 1 \
--min_timestep_boundary 0.417
# boundary corresponds to timesteps [0, 875)
--max_timestep_boundary 0.875 \
--min_timestep_boundary 0

View File

@@ -1,5 +1,4 @@
import torch, os, json
from diffsynth import load_state_dict
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from diffsynth.trainers.utils import DiffusionTrainingModule, VideoDataset, ModelLogger, launch_training_task, wan_parser
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -11,7 +10,7 @@ class WanTrainingModule(DiffusionTrainingModule):
self,
model_paths=None, model_id_with_origin_paths=None,
trainable_models=None,
lora_base_model=None, lora_target_modules="q,k,v,o,ffn.0,ffn.2", lora_rank=32, lora_checkpoint=None,
lora_base_model=None, lora_target_modules="q,k,v,o,ffn.0,ffn.2", lora_rank=32,
use_gradient_checkpointing=True,
use_gradient_checkpointing_offload=False,
extra_inputs=None,
@@ -42,12 +41,6 @@ class WanTrainingModule(DiffusionTrainingModule):
target_modules=lora_target_modules.split(","),
lora_rank=lora_rank
)
if lora_checkpoint is not None:
state_dict = load_state_dict(lora_checkpoint)
state_dict = self.mapping_lora_state_dict(state_dict)
load_result = model.load_state_dict(state_dict, strict=False)
if len(load_result[1]) > 0:
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
setattr(self.pipe, lora_base_model, model)
# Store other configs
@@ -119,7 +112,6 @@ if __name__ == "__main__":
lora_base_model=args.lora_base_model,
lora_target_modules=args.lora_target_modules,
lora_rank=args.lora_rank,
lora_checkpoint=args.lora_checkpoint,
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
extra_inputs=args.extra_inputs,
max_timestep_boundary=args.max_timestep_boundary,
@@ -129,7 +121,7 @@ if __name__ == "__main__":
args.output_path,
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt
)
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay)
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
launch_training_task(
dataset, model, model_logger, optimizer, scheduler,

View File

@@ -1,119 +0,0 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, QwenImageUnit_PromptEmbedder, load_state_dict
import torch, os
from tqdm import tqdm
from diffsynth.models.svd_unet import TemporalTimesteps
from einops import rearrange, repeat
class ValueEncoder(torch.nn.Module):
def __init__(self, dim_in=256, dim_out=3584, value_emb_length=32):
super().__init__()
self.value_emb = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0)
self.positional_emb = torch.nn.Parameter(torch.randn(1, value_emb_length, dim_out))
self.proj_value = torch.nn.Linear(dim_in, dim_out)
self.proj_out = torch.nn.Linear(dim_out, dim_out)
self.value_emb_length = value_emb_length
def forward(self, value):
value = value * 1
emb = self.value_emb(value).to(value.dtype)
emb = self.proj_value(emb)
emb = repeat(emb, "b d -> b s d", s=self.value_emb_length)
emb = emb + self.positional_emb.to(dtype=emb.dtype, device=emb.device)
emb = torch.nn.functional.silu(emb)
emb = self.proj_out(emb)
return emb
class TextInterpolationModel(torch.nn.Module):
def __init__(self, dim_in=256, dim_out=3584, value_emb_length=32, num_heads=32):
super().__init__()
self.to_q = ValueEncoder(dim_in=dim_in, dim_out=dim_out, value_emb_length=value_emb_length)
self.xk_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
self.yk_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
self.xv_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
self.yv_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
self.to_k = torch.nn.Linear(dim_out, dim_out, bias=False)
self.to_v = torch.nn.Linear(dim_out, dim_out, bias=False)
self.to_out = torch.nn.Linear(dim_out, dim_out)
self.num_heads = num_heads
def forward(self, value, x, y):
q = self.to_q(value)
k = self.to_k(torch.concat([x + self.xk_emb, y + self.yk_emb], dim=1))
v = self.to_v(torch.concat([x + self.xv_emb, y + self.yv_emb], dim=1))
q = rearrange(q, 'b s (h d) -> b h s d', h=self.num_heads)
k = rearrange(k, 'b s (h d) -> b h s d', h=self.num_heads)
v = rearrange(v, 'b s (h d) -> b h s d', h=self.num_heads)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
out = rearrange(out, 'b h s d -> b s (h d)')
out = self.to_out(out)
return out
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
unit = QwenImageUnit_PromptEmbedder()
dataset_prompt = [
(
"超级黑暗的画面,整体在黑暗中,暗无天日,暗淡无光,阴森黑暗,几乎全黑",
"超级明亮的画面,爆闪,相机过曝,整个画面都是白色的眩光,几乎全是白色",
),
]
dataset_tensors = []
for prompt_x, prompt_y in tqdm(dataset_prompt):
with torch.no_grad():
x = unit.process(pipe, prompt_x)["prompt_emb"]
y = unit.process(pipe, prompt_y)["prompt_emb"]
dataset_tensors.append((x, y))
model = TextInterpolationModel().to(dtype=torch.bfloat16, device="cuda")
model.load_state_dict(load_state_dict("models/interpolate.pth"))
def sample_tokens(emb, p):
perm = torch.randperm(emb.shape[1])[:max(0, int(emb.shape[1]*p))]
return emb[:, perm]
def loss_fn(x, y):
s, l = x.shape[1], y.shape[1]
x = repeat(x, "b s d -> b s l d", l=l)
y = repeat(y, "b l d -> b s l d", s=s)
d = torch.square(x - y).mean(dim=-1)
loss_x = d.min(dim=1).values.mean()
loss_y = d.min(dim=2).values.mean()
return loss_x + loss_y
def get_target(x, y, p):
x = sample_tokens(x, 1-p)
y = sample_tokens(y, p)
return torch.concat([x, y], dim=1)
name = "brightness"
for i in range(6):
v = i/5
with torch.no_grad():
data_id = 0
x, y = dataset_tensors[data_id]
x, y = x.to("cuda"), y.to("cuda")
value = torch.tensor([v], dtype=torch.bfloat16, device="cuda")
value_emb = model(value, x, y)
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(prompt, seed=0, num_inference_steps=40, extra_prompt_emb=value_emb)
os.makedirs(f"data/qwen_image_value/{name}", exist_ok=True)
image.save(f"data/qwen_image_value/{name}/image_{v}.jpg")

View File

@@ -1,121 +0,0 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, QwenImageUnit_PromptEmbedder
import torch
from tqdm import tqdm
from diffsynth.models.svd_unet import TemporalTimesteps
from einops import rearrange, repeat
class ValueEncoder(torch.nn.Module):
def __init__(self, dim_in=256, dim_out=3584, value_emb_length=32):
super().__init__()
self.value_emb = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0)
self.positional_emb = torch.nn.Parameter(torch.randn(1, value_emb_length, dim_out))
self.proj_value = torch.nn.Linear(dim_in, dim_out)
self.proj_out = torch.nn.Linear(dim_out, dim_out)
self.value_emb_length = value_emb_length
def forward(self, value):
value = value * 1
emb = self.value_emb(value).to(value.dtype)
emb = self.proj_value(emb)
emb = repeat(emb, "b d -> b s d", s=self.value_emb_length)
emb = emb + self.positional_emb.to(dtype=emb.dtype, device=emb.device)
emb = torch.nn.functional.silu(emb)
emb = self.proj_out(emb)
return emb
class TextInterpolationModel(torch.nn.Module):
def __init__(self, dim_in=256, dim_out=3584, value_emb_length=32, num_heads=32):
super().__init__()
self.to_q = ValueEncoder(dim_in=dim_in, dim_out=dim_out, value_emb_length=value_emb_length)
self.xk_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
self.yk_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
self.xv_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
self.yv_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
self.to_k = torch.nn.Linear(dim_out, dim_out, bias=False)
self.to_v = torch.nn.Linear(dim_out, dim_out, bias=False)
self.to_out = torch.nn.Linear(dim_out, dim_out)
self.num_heads = num_heads
def forward(self, value, x, y):
q = self.to_q(value)
k = self.to_k(torch.concat([x + self.xk_emb, y + self.yk_emb], dim=1))
v = self.to_v(torch.concat([x + self.xv_emb, y + self.yv_emb], dim=1))
q = rearrange(q, 'b s (h d) -> b h s d', h=self.num_heads)
k = rearrange(k, 'b s (h d) -> b h s d', h=self.num_heads)
v = rearrange(v, 'b s (h d) -> b h s d', h=self.num_heads)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
out = rearrange(out, 'b h s d -> b s (h d)')
out = self.to_out(out)
return out
def sample_tokens(emb, p):
perm = torch.randperm(emb.shape[1])[:max(0, int(emb.shape[1]*p))]
return emb[:, perm]
def loss_fn(x, y):
s, l = x.shape[1], y.shape[1]
x = repeat(x, "b s d -> b s l d", l=l)
y = repeat(y, "b l d -> b s l d", s=s)
d = torch.square(x - y).mean(dim=-1)
loss_x = d.min(dim=1).values.mean()
loss_y = d.min(dim=2).values.mean()
return loss_x + loss_y
def get_target(x, y, p):
x = sample_tokens(x, 1-p)
y = sample_tokens(y, p)
return torch.concat([x, y], dim=1)
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
unit = QwenImageUnit_PromptEmbedder()
dataset_prompt = [
(
"超级黑暗的画面,整体在黑暗中,暗无天日,暗淡无光,阴森黑暗,几乎全黑",
"超级明亮的画面,爆闪,相机过曝,整个画面都是白色的眩光,几乎全是白色",
),
]
dataset_tensors = []
for prompt_x, prompt_y in tqdm(dataset_prompt):
with torch.no_grad():
x = unit.process(pipe, prompt_x)["prompt_emb"].to(dtype=torch.float32, device="cpu")
y = unit.process(pipe, prompt_y)["prompt_emb"].to(dtype=torch.float32, device="cpu")
dataset_tensors.append((x, y))
model = TextInterpolationModel().to(dtype=torch.float32, device="cuda")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
for step_id, step in enumerate(tqdm(range(100000))):
optimizer.zero_grad()
data_id = torch.randint(0, len(dataset_tensors), size=(1,)).item()
x, y = dataset_tensors[data_id]
x, y = x.to("cuda"), y.to("cuda")
value = torch.rand((1,), dtype=torch.float32, device="cuda")
out = model(value, x, y)
loss = loss_fn(out, x) * (1 - value) + loss_fn(out, y) * value
loss.backward()
optimizer.step()
if (step_id + 1) % 1000 == 0:
print(loss)
torch.save(model.state_dict(), f"models/interpolate_{step+1}.pth")