diff --git a/README.md b/README.md
index 5e008e0..93aa836 100644
--- a/README.md
+++ b/README.md
@@ -33,6 +33,10 @@ We believe that a well-developed open-source code framework can lower the thresh
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
+- **February 2, 2026**: The first document in the Research Tutorial series is now available, guiding you through training a small 0.1B text-to-image model from scratch. For details, see the [documentation](/docs/en/Research_Tutorial/train_from_scratch.md) and [model](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel). We hope DiffSynth-Studio can evolve into a more powerful training framework for Diffusion models.
+
+- **January 27, 2026**: [Z-Image](https://modelscope.cn/models/Tongyi-MAI/Z-Image) is released, and our [Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L) model is released alongside it. You can try it in [ModelScope Studios](https://modelscope.cn/studios/DiffSynth-Studio/Z-Image-i2L). For details, see the [documentation](/docs/en/Model_Details/Z-Image.md).
+
- **January 19, 2026**: Added support for [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) and [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/FLUX2.md) and [example code](/examples/flux2/) are now available.
- **January 12, 2026**: We trained and open-sourced a text-guided image layer separation model ([Model Link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)). Given an input image and a textual description, the model isolates the image layer corresponding to the described content. For more details, please refer to our blog post ([Chinese version](https://modelscope.cn/learn/4938), [English version](https://huggingface.co/blog/kelseye/qwen-image-layered-control)).
@@ -269,9 +273,14 @@ image.save("image.jpg")
Example code for Z-Image is available at: [/examples/z_image/](/examples/z_image/)
-| Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
+|Model ID|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|
+|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|
+|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
+|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
+|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
+|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|
@@ -410,6 +419,7 @@ Example code for Qwen-Image is available at: [/examples/qwen_image/](/examples/q
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
+|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
diff --git a/README_zh.md b/README_zh.md
index a1619a5..2aee367 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -33,6 +33,10 @@ DiffSynth 目前包括两个开源项目:
> 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 负责,因此新功能的开发进展会比较缓慢,issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。
+- **2026年2月2日** Research Tutorial 的第一篇文档上线,带你从零开始训练一个 0.1B 的小型文生图模型,详见[文档](/docs/zh/Research_Tutorial/train_from_scratch.md)、[模型](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel),我们希望 DiffSynth-Studio 能够成为一个更强大的 Diffusion 模型训练框架。
+
+- **2026年1月27日** [Z-Image](https://modelscope.cn/models/Tongyi-MAI/Z-Image) 发布,我们的 [Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L) 模型同步发布,在[魔搭创空间](https://modelscope.cn/studios/DiffSynth-Studio/Z-Image-i2L)可直接体验,详见[文档](/docs/zh/Model_Details/Z-Image.md)。
+
- **2026年1月19日** 新增对 [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) 和 [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) 模型的支持,包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/FLUX2.md)和[示例代码](/examples/flux2/)现已可用。
- **2026年1月12日** 我们训练并开源了一个文本引导的图层拆分模型([模型链接](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)),这一模型输入一张图与一段文本描述,模型会将图像中与文本描述相关的图层拆分出来。更多细节请阅读我们的 blog([中文版](https://modelscope.cn/learn/4938)、[英文版](https://huggingface.co/blog/kelseye/qwen-image-layered-control))。
@@ -271,7 +275,12 @@ Z-Image 的示例代码位于:[/examples/z_image/](/examples/z_image/)
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
+|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|
+|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
+|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
+|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
+|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|
@@ -410,6 +419,7 @@ Qwen-Image 的示例代码位于:[/examples/qwen_image/](/examples/qwen_image/
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
+|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py
index f427fa6..9ff7ea6 100644
--- a/diffsynth/configs/model_configs.py
+++ b/diffsynth/configs/model_configs.py
@@ -589,6 +589,14 @@ z_image_series = [
"model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel",
"extra_kwargs": {"compress_dim": 128},
},
+ {
+ # Example: ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors")
+ "model_hash": "1392adecee344136041e70553f875f31",
+ "model_name": "z_image_text_encoder",
+ "model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
+ "extra_kwargs": {"model_size": "0.6B"},
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter",
+ },
]
ltx2_series = [
diff --git a/diffsynth/core/loader/config.py b/diffsynth/core/loader/config.py
index 88b46a0..d4ce83c 100644
--- a/diffsynth/core/loader/config.py
+++ b/diffsynth/core/loader/config.py
@@ -1,5 +1,5 @@
import torch, glob, os
-from typing import Optional, Union
+from typing import Optional, Union, Dict
from dataclasses import dataclass
from modelscope import snapshot_download
from huggingface_hub import snapshot_download as hf_snapshot_download
@@ -23,13 +23,14 @@ class ModelConfig:
computation_device: Optional[Union[str, torch.device]] = None
computation_dtype: Optional[torch.dtype] = None
clear_parameters: bool = False
+    state_dict: Optional[Dict[str, torch.Tensor]] = None
def check_input(self):
if self.path is None and self.model_id is None:
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. `skip_download=True` only supports the first one.""")
def parse_original_file_pattern(self):
- if self.origin_file_pattern is None or self.origin_file_pattern == "":
+ if self.origin_file_pattern in [None, "", "./"]:
return "*"
elif self.origin_file_pattern.endswith("/"):
return self.origin_file_pattern + "*"
@@ -98,7 +99,7 @@ class ModelConfig:
if self.require_downloading():
self.download()
if self.path is None:
- if self.origin_file_pattern is None or self.origin_file_pattern == "":
+ if self.origin_file_pattern in [None, "", "./"]:
self.path = os.path.join(self.local_model_path, self.model_id)
else:
self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))
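The two hunks above extend the set of patterns that mean "take the whole repo" to include `"./"`. The normalization can be sketched in isolation; this is a hypothetical standalone helper mirroring `ModelConfig.parse_original_file_pattern`, under the assumption that `None`, `""`, and `"./"` should all resolve to the catch-all glob:

```python
def parse_pattern(origin_file_pattern):
    # Treat None, "", and "./" as "download everything in the repo".
    if origin_file_pattern in [None, "", "./"]:
        return "*"
    # A trailing slash means "everything inside this directory".
    elif origin_file_pattern.endswith("/"):
        return origin_file_pattern + "*"
    # Otherwise the pattern is used as-is (e.g. a single file name).
    else:
        return origin_file_pattern

print(parse_pattern("./"))                 # "*"
print(parse_pattern("vae/"))               # "vae/*"
print(parse_pattern("model.safetensors"))  # "model.safetensors"
```

The resulting glob is then joined onto the local model directory, which is why `"./"` previously fell through to the directory-joining branch and produced a malformed path.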
diff --git a/diffsynth/core/loader/file.py b/diffsynth/core/loader/file.py
index 8f66961..67d8815 100644
--- a/diffsynth/core/loader/file.py
+++ b/diffsynth/core/loader/file.py
@@ -2,16 +2,25 @@ from safetensors import safe_open
import torch, hashlib
-def load_state_dict(file_path, torch_dtype=None, device="cpu"):
+def load_state_dict(file_path, torch_dtype=None, device="cpu", pin_memory=False, verbose=0):
if isinstance(file_path, list):
state_dict = {}
for file_path_ in file_path:
- state_dict.update(load_state_dict(file_path_, torch_dtype, device))
- return state_dict
- if file_path.endswith(".safetensors"):
- return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
+ state_dict.update(load_state_dict(file_path_, torch_dtype, device, pin_memory=pin_memory, verbose=verbose))
else:
- return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
+ if verbose >= 1:
+ print(f"Loading file [started]: {file_path}")
+ if file_path.endswith(".safetensors"):
+ state_dict = load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
+ else:
+ state_dict = load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
+            # When the state dict stays in CPU memory, `pin_memory=True` makes a later `model.to("cuda")` transfer faster.
+ if pin_memory:
+ for i in state_dict:
+ state_dict[i] = state_dict[i].pin_memory()
+ if verbose >= 1:
+ print(f"Loading file [done]: {file_path}")
+ return state_dict
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
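The refactored `load_state_dict` above keeps its recursive list branch: per-file dicts are merged in order, so later files overwrite duplicate keys. A minimal pure-Python sketch of that merge behavior, with the safetensors/bin loaders stubbed out by a hypothetical `load_one`:

```python
def load_one(file_path):
    # Stand-in for the real safetensors/bin loaders; returns fake parameters.
    fake_files = {
        "a.safetensors": {"layer.weight": 1, "layer.bias": 2},
        "b.safetensors": {"layer.bias": 20, "head.weight": 3},
    }
    return fake_files[file_path]

def load_many(file_path):
    # Mirrors the list branch: merge per-file dicts, later files win on clashes.
    if isinstance(file_path, list):
        state_dict = {}
        for file_path_ in file_path:
            state_dict.update(load_one(file_path_))
        return state_dict
    return load_one(file_path)

merged = load_many(["a.safetensors", "b.safetensors"])
print(merged)  # {'layer.weight': 1, 'layer.bias': 20, 'head.weight': 3}
```

The real function additionally pins each tensor after loading when `pin_memory=True`, which only matters for host-to-GPU transfer speed and does not change the merge semantics shown here.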
diff --git a/diffsynth/core/loader/model.py b/diffsynth/core/loader/model.py
index 56fa7d3..1f920ab 100644
--- a/diffsynth/core/loader/model.py
+++ b/diffsynth/core/loader/model.py
@@ -5,7 +5,7 @@ from .file import load_state_dict
import torch
-def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None):
+def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None):
config = {} if config is None else config
# Why do we use `skip_model_initialization`?
# It skips the random initialization of model parameters,
@@ -20,7 +20,7 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
dtype = [d for d in dtypes if d != "disk"][0]
if vram_config["offload_device"] != "disk":
- state_dict = DiskMap(path, device, torch_dtype=dtype)
+ if state_dict is None: state_dict = DiskMap(path, device, torch_dtype=dtype)
if state_dict_converter is not None:
state_dict = state_dict_converter(state_dict)
else:
@@ -35,7 +35,9 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
# Sometimes a model file contains multiple models,
# and DiskMap can load only the parameters of a single model,
# avoiding the need to load all parameters in the file.
- if use_disk_map:
+ if state_dict is not None:
+        pass  # A state dict was supplied directly (e.g. via `ModelConfig.state_dict`); use it as-is.
+ elif use_disk_map:
state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
else:
state_dict = load_state_dict(path, torch_dtype, device)
diff --git a/diffsynth/diffusion/base_pipeline.py b/diffsynth/diffusion/base_pipeline.py
index 723bac0..7d41cac 100644
--- a/diffsynth/diffusion/base_pipeline.py
+++ b/diffsynth/diffusion/base_pipeline.py
@@ -296,6 +296,7 @@ class BasePipeline(torch.nn.Module):
vram_config=vram_config,
vram_limit=vram_limit,
clear_parameters=model_config.clear_parameters,
+ state_dict=model_config.state_dict,
)
return model_pool
diff --git a/diffsynth/diffusion/flow_match.py b/diffsynth/diffusion/flow_match.py
index 9b31466..208fb1e 100644
--- a/diffsynth/diffusion/flow_match.py
+++ b/diffsynth/diffusion/flow_match.py
@@ -4,7 +4,7 @@ from typing_extensions import Literal
class FlowMatchScheduler():
- def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2"] = "FLUX.1"):
+ def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning"] = "FLUX.1"):
self.set_timesteps_fn = {
"FLUX.1": FlowMatchScheduler.set_timesteps_flux,
"Wan": FlowMatchScheduler.set_timesteps_wan,
@@ -12,6 +12,7 @@ class FlowMatchScheduler():
"FLUX.2": FlowMatchScheduler.set_timesteps_flux2,
"Z-Image": FlowMatchScheduler.set_timesteps_z_image,
"LTX-2": FlowMatchScheduler.set_timesteps_ltx2,
+ "Qwen-Image-Lightning": FlowMatchScheduler.set_timesteps_qwen_image_lightning,
}.get(template, FlowMatchScheduler.set_timesteps_flux)
self.num_train_timesteps = 1000
@@ -71,6 +72,28 @@ class FlowMatchScheduler():
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
+ @staticmethod
+ def set_timesteps_qwen_image_lightning(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None):
+ sigma_min = 0.0
+ sigma_max = 1.0
+ num_train_timesteps = 1000
+ base_shift = math.log(3)
+ max_shift = math.log(3)
+ # Sigmas
+ sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
+ sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
+ # Mu
+ if exponential_shift_mu is not None:
+ mu = exponential_shift_mu
+ elif dynamic_shift_len is not None:
+ mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len, base_shift=base_shift, max_shift=max_shift)
+ else:
+ mu = 0.8
+ sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
+ # Timesteps
+ timesteps = sigmas * num_train_timesteps
+ return sigmas, timesteps
+
@staticmethod
def compute_empirical_mu(image_seq_len, num_steps):
a1, b1 = 8.73809524e-05, 1.89833333
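The new `set_timesteps_qwen_image_lightning` schedule can be reproduced with plain Python floats in place of torch tensors; this sketch keeps the same arithmetic (linear sigmas followed by the exponential time shift `sigma' = e^mu / (e^mu + (1/sigma - 1))`) and assumes the default `mu = 0.8`, ignoring the dynamic-shift branch:

```python
import math

def lightning_sigmas(num_inference_steps=4, denoising_strength=1.0, mu=0.8):
    num_train_timesteps = 1000
    sigma_start = denoising_strength  # sigma_min = 0.0, sigma_max = 1.0
    # linspace from sigma_start down to 0 over steps+1 points, dropping the final 0
    sigmas = [
        sigma_start - sigma_start * i / num_inference_steps
        for i in range(num_inference_steps)
    ]
    # Exponential time shift toward high-noise timesteps
    sigmas = [math.exp(mu) / (math.exp(mu) + (1 / s - 1)) for s in sigmas]
    timesteps = [s * num_train_timesteps for s in sigmas]
    return sigmas, timesteps

sigmas, timesteps = lightning_sigmas(num_inference_steps=4)
print(sigmas[0])  # 1.0 (the shift maps sigma=1 to itself)
```

Note the slicing `[:-1]` in the original: the terminal sigma of 0 is excluded before the shift, which also avoids the division by zero in `1 / sigmas`.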
diff --git a/diffsynth/diffusion/loss.py b/diffsynth/diffusion/loss.py
index ae44bb6..14fdfd3 100644
--- a/diffsynth/diffusion/loss.py
+++ b/diffsynth/diffusion/loss.py
@@ -13,9 +13,16 @@ def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
+ if "first_frame_latents" in inputs:
+ inputs["latents"][:, :, 0:1] = inputs["first_frame_latents"]
+
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep)
+ if "first_frame_latents" in inputs:
+ noise_pred = noise_pred[:, :, 1:]
+ training_target = training_target[:, :, 1:]
+
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * pipe.scheduler.training_weight(timestep)
return loss
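The loss change above fixes the first frame's latents before the model call and then excludes that frame from the MSE, so the model is never penalized for a frame it cannot alter (as in image-to-video conditioning). A pure-Python sketch of the masking step, modeling frames as flat lists of floats rather than the real `[:, :, 1:]` tensor slice:

```python
def masked_mse(noise_pred, training_target, has_first_frame):
    # Mirrors the `[:, :, 1:]` slicing: drop frame 0 from both sides of the loss.
    if has_first_frame:
        noise_pred = noise_pred[1:]
        training_target = training_target[1:]
    n = sum(len(frame) for frame in noise_pred)
    return sum(
        (p - t) ** 2
        for frame_p, frame_t in zip(noise_pred, training_target)
        for p, t in zip(frame_p, frame_t)
    ) / n

pred = [[0.0, 0.0], [1.0, 1.0]]    # frame 0, frame 1
target = [[9.0, 9.0], [1.0, 2.0]]
print(masked_mse(pred, target, True))   # 0.5: only frame 1 contributes
```

Without the mask, the large (and uninformative) frame-0 error would dominate the loss even though that frame was overwritten by `first_frame_latents` before prediction.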
diff --git a/diffsynth/models/model_loader.py b/diffsynth/models/model_loader.py
index 16d72dd..6a58c89 100644
--- a/diffsynth/models/model_loader.py
+++ b/diffsynth/models/model_loader.py
@@ -29,7 +29,7 @@ class ModelPool:
module_map = None
return module_map
- def load_model_file(self, config, path, vram_config, vram_limit=None):
+ def load_model_file(self, config, path, vram_config, vram_limit=None, state_dict=None):
model_class = self.import_model_class(config["model_class"])
model_config = config.get("extra_kwargs", {})
if "state_dict_converter" in config:
@@ -43,6 +43,7 @@ class ModelPool:
state_dict_converter,
use_disk_map=True,
vram_config=vram_config, module_map=module_map, vram_limit=vram_limit,
+ state_dict=state_dict,
)
return model
@@ -59,7 +60,7 @@ class ModelPool:
}
return vram_config
- def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False):
+ def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False, state_dict=None):
print(f"Loading models from: {json.dumps(path, indent=4)}")
if vram_config is None:
vram_config = self.default_vram_config()
@@ -67,7 +68,7 @@ class ModelPool:
loaded = False
for config in MODEL_CONFIGS:
if config["model_hash"] == model_hash:
- model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit)
+ model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit, state_dict=state_dict)
if clear_parameters: self.clear_parameters(model)
self.model.append(model)
model_name = config["model_name"]
diff --git a/diffsynth/models/z_image_text_encoder.py b/diffsynth/models/z_image_text_encoder.py
index 4d6271d..6f3e6c0 100644
--- a/diffsynth/models/z_image_text_encoder.py
+++ b/diffsynth/models/z_image_text_encoder.py
@@ -6,6 +6,36 @@ class ZImageTextEncoder(torch.nn.Module):
def __init__(self, model_size="4B"):
super().__init__()
config_dict = {
+ "0.6B": Qwen3Config(**{
+ "architectures": [
+ "Qwen3ForCausalLM"
+ ],
+ "attention_bias": False,
+ "attention_dropout": 0.0,
+ "bos_token_id": 151643,
+ "eos_token_id": 151645,
+ "head_dim": 128,
+ "hidden_act": "silu",
+ "hidden_size": 1024,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "max_position_embeddings": 40960,
+ "max_window_layers": 28,
+ "model_type": "qwen3",
+ "num_attention_heads": 16,
+ "num_hidden_layers": 28,
+ "num_key_value_heads": 8,
+ "rms_norm_eps": 1e-06,
+ "rope_scaling": None,
+ "rope_theta": 1000000,
+ "sliding_window": None,
+ "tie_word_embeddings": True,
+ "torch_dtype": "bfloat16",
+ "transformers_version": "4.51.0",
+ "use_cache": True,
+ "use_sliding_window": False,
+ "vocab_size": 151936
+ }),
"4B": Qwen3Config(**{
"architectures": [
"Qwen3ForCausalLM"
diff --git a/docs/en/Model_Details/Qwen-Image.md b/docs/en/Model_Details/Qwen-Image.md
index 08b8a35..af3a942 100644
--- a/docs/en/Model_Details/Qwen-Image.md
+++ b/docs/en/Model_Details/Qwen-Image.md
@@ -85,6 +85,7 @@ graph LR;
| [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) |
| [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) |
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
+|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
| [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
diff --git a/docs/en/Model_Details/Z-Image.md b/docs/en/Model_Details/Z-Image.md
index 3673a52..677db21 100644
--- a/docs/en/Model_Details/Z-Image.md
+++ b/docs/en/Model_Details/Z-Image.md
@@ -50,9 +50,14 @@ image.save("image.jpg")
## Model Overview
-| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
-| - | - | - | - | - | - | - |
-| [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) | [code](/examples/z_image/model_inference/Z-Image-Turbo.py) | [code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py) | [code](/examples/z_image/model_training/full/Z-Image-Turbo.sh) | [code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py) | [code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh) | [code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py) |
+|Model ID|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
+|-|-|-|-|-|-|-|
+|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|
+|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
+|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
+|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
+|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
+|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|
Special Training Scripts:
@@ -75,6 +80,9 @@ Input parameters for `ZImagePipeline` inference include:
* `seed`: Random seed. Default is `None`, meaning completely random.
* `rand_device`: Computing device for generating random Gaussian noise matrix, default is `"cpu"`. When set to `cuda`, different GPUs will produce different generation results.
* `num_inference_steps`: Number of inference steps, default value is 8.
+* `controlnet_inputs`: Inputs for ControlNet models.
+* `edit_image`: Input image(s) for image-editing models; multiple images are supported.
+* `positive_only_lora`: LoRA weights applied only to the positive-prompt branch.
If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above.
diff --git a/docs/en/README.md b/docs/en/README.md
index 39ae439..e968637 100644
--- a/docs/en/README.md
+++ b/docs/en/README.md
@@ -77,7 +77,7 @@ This section introduces the independent core module `diffsynth.core` in `DiffSyn
This section introduces how to use `DiffSynth-Studio` to train new models, helping researchers explore new model technologies.
-* Training models from scratch 【coming soon】
+* [Training models from scratch](/docs/en/Research_Tutorial/train_from_scratch.md)
* Inference improvement techniques 【coming soon】
* Designing controllable generation models 【coming soon】
* Creating new training paradigms 【coming soon】
diff --git a/docs/en/Research_Tutorial/train_from_scratch.md b/docs/en/Research_Tutorial/train_from_scratch.md
new file mode 100644
index 0000000..2a63f82
--- /dev/null
+++ b/docs/en/Research_Tutorial/train_from_scratch.md
@@ -0,0 +1,476 @@
+# Training Models from Scratch
+
+DiffSynth-Studio's training engine supports training foundation models from scratch. This article introduces how to train a small text-to-image model with only 0.1B parameters from scratch.
+
+## 1. Building Model Architecture
+
+### 1.1 Diffusion Model
+
+From UNet [[1]](https://arxiv.org/abs/1505.04597) [[2]](https://arxiv.org/abs/2112.10752) to DiT [[3]](https://arxiv.org/abs/2212.09748) [[4]](https://arxiv.org/abs/2403.03206), mainstream Diffusion model architectures have gone through several evolutions. Typically, a Diffusion model's inputs include:
+
+* Image tensor (`latents`): the encoding of the image, produced by the VAE and partially noised
+* Text tensor (`prompt_embeds`): the encoding of the text, produced by the text encoder
+* Timestep (`timestep`): a scalar marking the current stage of the Diffusion process
+
+The model's output is a tensor with the same shape as the image tensor, representing the denoising direction predicted by the model. For the underlying theory, see [Basic Principles of Diffusion Models](/docs/en/Training/Understanding_Diffusion_models.md). In this article, we build `AAADiT`, a DiT model with only 0.1B parameters.
+
+
+**Model Architecture Code**
+
+```python
+import torch, accelerate
+from PIL import Image
+from typing import Union
+from tqdm import tqdm
+from einops import rearrange, repeat
+
+from transformers import AutoProcessor, AutoTokenizer
+from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model
+from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task
+from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit
+from diffsynth.models.general_modules import TimestepEmbeddings
+from diffsynth.models.z_image_text_encoder import ZImageTextEncoder
+from diffsynth.models.flux2_vae import Flux2VAE
+
+
+class AAAPositionalEmbedding(torch.nn.Module):
+ def __init__(self, height=16, width=16, dim=1024):
+ super().__init__()
+ self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))
+ self.text_emb = torch.nn.Parameter(torch.randn((dim,)))
+
+ def forward(self, image, text):
+ height, width = image.shape[-2:]
+ image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)
+ image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear")
+ image_emb = rearrange(image_emb, "B C H W -> B (H W) C")
+ text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)
+ text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1])
+ emb = torch.concat([image_emb, text_emb], dim=1)
+ return emb
+
+
+class AAABlock(torch.nn.Module):
+ def __init__(self, dim=1024, num_heads=32):
+ super().__init__()
+ self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)
+ self.to_q = torch.nn.Linear(dim, dim)
+ self.to_k = torch.nn.Linear(dim, dim)
+ self.to_v = torch.nn.Linear(dim, dim)
+ self.to_out = torch.nn.Linear(dim, dim)
+ self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False)
+ self.ff = torch.nn.Sequential(
+ torch.nn.Linear(dim, dim*3),
+ torch.nn.SiLU(),
+ torch.nn.Linear(dim*3, dim),
+ )
+ self.to_gate = torch.nn.Linear(dim, dim * 2)
+ self.num_heads = num_heads
+
+ def attention(self, emb, pos_emb):
+ emb = self.norm_attn(emb + pos_emb)
+ q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)
+ emb = attention_forward(
+ q, k, v,
+ q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)",
+ dims={"n": self.num_heads},
+ )
+ emb = self.to_out(emb)
+ return emb
+
+ def feed_forward(self, emb, pos_emb):
+ emb = self.norm_mlp(emb + pos_emb)
+ emb = self.ff(emb)
+ return emb
+
+ def forward(self, emb, pos_emb, t_emb):
+ gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)
+ emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)
+ emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)
+ return emb
+
+
+class AAADiT(torch.nn.Module):
+ def __init__(self, dim=1024):
+ super().__init__()
+ self.pos_embedder = AAAPositionalEmbedding(dim=dim)
+ self.timestep_embedder = TimestepEmbeddings(256, dim)
+ self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim))
+ self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim))
+ self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)])
+ self.proj_out = torch.nn.Linear(dim, 128)
+
+ def forward(
+ self,
+ latents,
+ prompt_embeds,
+ timestep,
+ use_gradient_checkpointing=False,
+ use_gradient_checkpointing_offload=False,
+ ):
+ pos_emb = self.pos_embedder(latents, prompt_embeds)
+ t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1)
+ image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C"))
+ text = self.text_embedder(prompt_embeds)
+ emb = torch.concat([image, text], dim=1)
+ for block_id, block in enumerate(self.blocks):
+ emb = gradient_checkpoint_forward(
+ block,
+ use_gradient_checkpointing=use_gradient_checkpointing,
+ use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
+ emb=emb,
+ pos_emb=pos_emb,
+ t_emb=t_emb,
+ )
+ emb = emb[:, :latents.shape[-1] * latents.shape[-2]]
+ emb = self.proj_out(emb)
+ emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1])
+ return emb
+```
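To make the data flow in `AAADiT.forward` concrete, here is a torch-only sketch of the tensor shapes, with the transformer blocks elided. The dimensions are toy values for illustration; the real model uses `dim=1024` and 10 blocks.

```python
import torch

# Shape walkthrough of AAADiT.forward (toy sizes, blocks elided).
dim, latent_channels, text_dim, text_len = 64, 128, 1024, 8
H = W = 16  # a 256x256 image maps to a 16x16 latent grid (16x downscale)

latents = torch.randn(1, latent_channels, H, W)
prompt_embeds = torch.randn(1, text_len, text_dim)

image_embedder = torch.nn.Linear(latent_channels, dim)
text_embedder = torch.nn.Linear(text_dim, dim)
proj_out = torch.nn.Linear(dim, latent_channels)

# "B C H W -> B (H W) C": flatten the latent grid into a token sequence
image = image_embedder(latents.flatten(2).transpose(1, 2))  # (1, H*W, dim)
text = text_embedder(prompt_embeds)                         # (1, text_len, dim)
emb = torch.cat([image, text], dim=1)                       # (1, H*W + text_len, dim)

# ... the AAABlock layers process the joint image+text sequence here ...

# Keep only the image tokens and project them back to latent channels
emb = emb[:, :H * W]
out = proj_out(emb).transpose(1, 2).reshape(1, latent_channels, H, W)
assert out.shape == latents.shape
```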
+
+
+
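A quick sanity check of the "only 0.1B parameters" figure: counting just the dominant linear weights in the 10 `AAABlock`s (biases, norms, and the small embedding modules add a few percent more):

```python
# Back-of-the-envelope parameter count for AAADiT (dim=1024, 10 blocks).
dim = 1024
attn = 4 * dim * dim          # to_q, to_k, to_v, to_out
ff = 2 * dim * (3 * dim)      # the two linear layers in the MLP
gate = dim * (2 * dim)        # to_gate
per_block = attn + ff + gate  # ~12.6M parameters per block
total = 10 * per_block
print(f"~{total / 1e9:.2f}B parameters")  # → ~0.13B parameters
```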
+### 1.2 Encoder-Decoder Models
+
+Besides the Diffusion model used for denoising, we also need two other models:
+
+* Text Encoder: Used to encode text into tensors. We adopt the [Qwen/Qwen3-0.6B](https://modelscope.cn/models/Qwen/Qwen3-0.6B) model.
+* VAE Encoder-Decoder: The encoder part is used to encode images into tensors, and the decoder part is used to decode image tensors into images. We adopt the VAE model from [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B).
+
+The architectures of these two models are already integrated in DiffSynth-Studio, located at [/diffsynth/models/z_image_text_encoder.py](/diffsynth/models/z_image_text_encoder.py) and [/diffsynth/models/flux2_vae.py](/diffsynth/models/flux2_vae.py), so we don't need to modify any code.
+
+## 2. Building Pipeline
+
+We introduced how to build a model Pipeline in the document [Integrating Pipeline](/docs/en/Developer_Guide/Building_a_Pipeline.md). For the model in this article, we also need to build a Pipeline to connect the text encoder, Diffusion model, and VAE encoder-decoder.
+
+
+**Pipeline Code**
+
+```python
+class AAAImagePipeline(BasePipeline):
+ def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
+ super().__init__(
+ device=device, torch_dtype=torch_dtype,
+ height_division_factor=16, width_division_factor=16,
+ )
+ self.scheduler = FlowMatchScheduler("FLUX.2")
+ self.text_encoder: ZImageTextEncoder = None
+ self.dit: AAADiT = None
+ self.vae: Flux2VAE = None
+ self.tokenizer: AutoProcessor = None
+ self.in_iteration_models = ("dit",)
+ self.units = [
+ AAAUnit_PromptEmbedder(),
+ AAAUnit_NoiseInitializer(),
+ AAAUnit_InputImageEmbedder(),
+ ]
+ self.model_fn = model_fn_aaa
+
+ @staticmethod
+ def from_pretrained(
+ torch_dtype: torch.dtype = torch.bfloat16,
+ device: Union[str, torch.device] = "cuda",
+ model_configs: list[ModelConfig] = [],
+ tokenizer_config: ModelConfig = None,
+ vram_limit: float = None,
+ ):
+ # Initialize pipeline
+ pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)
+ model_pool = pipe.download_and_load_models(model_configs, vram_limit)
+
+ # Fetch models
+ pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
+ pipe.dit = model_pool.fetch_model("aaa_dit")
+ pipe.vae = model_pool.fetch_model("flux2_vae")
+ if tokenizer_config is not None:
+ tokenizer_config.download_if_necessary()
+ pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
+
+ # VRAM Management
+ pipe.vram_management_enabled = pipe.check_vram_management_state()
+ return pipe
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ # Prompt
+ prompt: str,
+ negative_prompt: str = "",
+ cfg_scale: float = 1.0,
+ # Image
+ input_image: Image.Image = None,
+ denoising_strength: float = 1.0,
+ # Shape
+ height: int = 1024,
+ width: int = 1024,
+ # Randomness
+ seed: int = None,
+ rand_device: str = "cpu",
+ # Steps
+ num_inference_steps: int = 30,
+ # Progress bar
+ progress_bar_cmd = tqdm,
+ ):
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)
+
+ # Parameters
+ inputs_posi = {"prompt": prompt}
+ inputs_nega = {"negative_prompt": negative_prompt}
+ inputs_shared = {
+ "cfg_scale": cfg_scale,
+ "input_image": input_image, "denoising_strength": denoising_strength,
+ "height": height, "width": width,
+ "seed": seed, "rand_device": rand_device,
+ "num_inference_steps": num_inference_steps,
+ }
+ for unit in self.units:
+ inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
+
+ # Denoise
+ self.load_models_to_device(self.in_iteration_models)
+ models = {name: getattr(self, name) for name in self.in_iteration_models}
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
+ noise_pred = self.cfg_guided_model_fn(
+ self.model_fn, cfg_scale,
+ inputs_shared, inputs_posi, inputs_nega,
+ **models, timestep=timestep, progress_id=progress_id
+ )
+ inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
+
+ # Decode
+ self.load_models_to_device(['vae'])
+ image = self.vae.decode(inputs_shared["latents"])
+ image = self.vae_output_to_image(image)
+ self.load_models_to_device([])
+
+ return image
+
+
+class AAAUnit_PromptEmbedder(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ seperate_cfg=True,
+ input_params_posi={"prompt": "prompt"},
+ input_params_nega={"prompt": "negative_prompt"},
+ output_params=("prompt_embeds",),
+ onload_model_names=("text_encoder",)
+ )
+ self.hidden_states_layers = (-1,)
+
+ def process(self, pipe: AAAImagePipeline, prompt):
+ pipe.load_models_to_device(self.onload_model_names)
+ text = pipe.tokenizer.apply_chat_template(
+ [{"role": "user", "content": prompt}],
+ tokenize=False,
+ add_generation_prompt=True,
+ enable_thinking=False,
+ )
+ inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device)
+ output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False)
+ prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1)
+ return {"prompt_embeds": prompt_embeds}
+
+
+class AAAUnit_NoiseInitializer(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ input_params=("height", "width", "seed", "rand_device"),
+ output_params=("noise",),
+ )
+
+ def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device):
+ noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
+ return {"noise": noise}
+
+
+class AAAUnit_InputImageEmbedder(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ input_params=("input_image", "noise"),
+ output_params=("latents", "input_latents"),
+ onload_model_names=("vae",)
+ )
+
+ def process(self, pipe: AAAImagePipeline, input_image, noise):
+ if input_image is None:
+ return {"latents": noise, "input_latents": None}
+ pipe.load_models_to_device(['vae'])
+ image = pipe.preprocess_image(input_image)
+ input_latents = pipe.vae.encode(image)
+ if pipe.scheduler.training:
+ return {"latents": noise, "input_latents": input_latents}
+ else:
+ latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
+ return {"latents": latents, "input_latents": input_latents}
+
+
+def model_fn_aaa(
+ dit: AAADiT,
+ latents=None,
+ prompt_embeds=None,
+ timestep=None,
+ use_gradient_checkpointing=False,
+ use_gradient_checkpointing_offload=False,
+ **kwargs,
+):
+ model_output = dit(
+ latents,
+ prompt_embeds,
+ timestep,
+ use_gradient_checkpointing=use_gradient_checkpointing,
+ use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
+ )
+ return model_output
+```
+
+
+
+## 3. Preparing Dataset
+
+To quickly verify training effectiveness, we use the [Pokemon-First Generation](https://modelscope.cn/datasets/DiffSynth-Studio/pokemon-gen1) dataset, reproduced from the open-source project [pokemon-dataset-zh](https://github.com/42arch/pokemon-dataset-zh). It contains the 151 first-generation Pokemon, from Bulbasaur to Mew. To use other datasets, see [Preparing Datasets](/docs/en/Pipeline_Usage/Model_Training.md#preparing-datasets) and [`diffsynth.core.data`](/docs/en/API_Reference/core/data.md).
+
+```shell
+modelscope download --dataset DiffSynth-Studio/pokemon-gen1 --local_dir ./data
+```
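If you build your own dataset, note that the training module in section 4 reads `data["prompt"]` and `data["image"]`, so the metadata file needs at least those two columns. A hypothetical sketch of writing such a file (file names and captions are illustrative, not the actual dataset contents):

```python
import csv

# Write a minimal metadata CSV for UnifiedDataset: "image" holds paths
# relative to base_path, "prompt" holds the caption.
rows = [
    {"image": "0001.jpg", "prompt": "green, lizard, plant, seed on back, red eyes"},
    {"image": "0002.jpg", "prompt": "orange, lizard, Fire, flame on tail tip"},
]
with open("metadata.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["image", "prompt"])
    writer.writeheader()
    writer.writerows(rows)
```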
+
+## 4. Starting Training
+
+With the Pipeline in place, implementing the training loop is straightforward. The complete code is available at [/docs/en/Research_Tutorial/train_from_scratch.py](/docs/en/Research_Tutorial/train_from_scratch.py); start single-GPU training with `python docs/en/Research_Tutorial/train_from_scratch.py`.
+
+For multi-GPU parallel training, run `accelerate config` to set the relevant parameters, then launch training with `accelerate launch docs/en/Research_Tutorial/train_from_scratch.py`.
+
+The training script has no stopping condition; stop it manually when you are done. The model converges after roughly 60,000 training steps, which takes 10-20 hours on a single GPU.
+
+
+**Training Code**
+
+```python
+class AAATrainingModule(DiffusionTrainingModule):
+ def __init__(self, device):
+ super().__init__()
+ self.pipe = AAAImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device=device,
+ model_configs=[
+ ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
+ ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
+ )
+ self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)
+ self.pipe.freeze_except(["dit"])
+ self.pipe.scheduler.set_timesteps(1000, training=True)
+
+ def forward(self, data):
+ inputs_posi = {"prompt": data["prompt"]}
+ inputs_nega = {"negative_prompt": ""}
+ inputs_shared = {
+ "input_image": data["image"],
+ "height": data["image"].size[1],
+ "width": data["image"].size[0],
+ "cfg_scale": 1,
+ "use_gradient_checkpointing": False,
+ "use_gradient_checkpointing_offload": False,
+ }
+ for unit in self.pipe.units:
+ inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
+ loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)
+ return loss
+
+
+if __name__ == "__main__":
+ accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)
+ dataset = UnifiedDataset(
+ base_path="data/images",
+ metadata_path="data/metadata_merged.csv",
+ max_data_items=10000000,
+ data_file_keys=("image",),
+ main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256)
+ )
+ model = AAATrainingModule(device=accelerator.device)
+ model_logger = ModelLogger(
+ "models/AAA/v1",
+ remove_prefix_in_ckpt="pipe.dit.",
+ )
+ launch_training_task(
+ accelerator, dataset, model, model_logger,
+ learning_rate=2e-4,
+ num_workers=4,
+ save_steps=50000,
+ num_epochs=999999,
+ )
+```
+
+
+
+## 5. Verifying Training Results
+
+If you don't want to wait for the model training to complete, you can directly download [our pre-trained model](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel).
+
+```shell
+modelscope download --model DiffSynth-Studio/AAAMyModel step-600000.safetensors --local_dir models/DiffSynth-Studio/AAAMyModel
+```
+
+Load the model:
+
+```python
+from diffsynth import load_model
+
+pipe = AAAImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
+ ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
+)
+pipe.dit = load_model(AAADiT, "models/DiffSynth-Studio/AAAMyModel/step-600000.safetensors", torch_dtype=torch.bfloat16, device="cuda")
+```
+
+Run inference to generate the first-generation Pokemon starter trio. The generated images now closely match the training data.
+
+```python
+for seed, prompt in enumerate([
+ "green, lizard, plant, Grass, Poison, seed on back, red eyes, smiling expression, short stout limbs, sharp claws",
+ "orange, cream, lizard, Fire, flame on tail tip, large eyes, smiling expression, cream-colored belly patch, sharp claws",
+ "blue, beige, brown, turtle, water type, shell, big eyes, short limbs, curled tail",
+]):
+ image = pipe(
+ prompt=prompt,
+ negative_prompt=" ",
+ num_inference_steps=30,
+ cfg_scale=10,
+ seed=seed,
+ height=256, width=256,
+ )
+ image.save(f"image_{seed}.jpg")
+```
+
+
+Run inference to generate Pokemon with "sharp claws". Different random seeds now produce different images.
+
+```python
+for seed, prompt in enumerate([
+ "sharp claws",
+ "sharp claws",
+ "sharp claws",
+]):
+ image = pipe(
+ prompt=prompt,
+ negative_prompt=" ",
+ num_inference_steps=30,
+ cfg_scale=10,
+ seed=seed+4,
+ height=256, width=256,
+ )
+ image.save(f"image_sharp_claws_{seed}.jpg")
+```
+
+
+We now have a small 0.1B text-to-image model. It can generate the 151 first-generation Pokemon, but nothing beyond them. By scaling up the data, model size, and number of GPUs from here, you can train a far more capable text-to-image model!
\ No newline at end of file
diff --git a/docs/en/Research_Tutorial/train_from_scratch.py b/docs/en/Research_Tutorial/train_from_scratch.py
new file mode 100644
index 0000000..328c24d
--- /dev/null
+++ b/docs/en/Research_Tutorial/train_from_scratch.py
@@ -0,0 +1,341 @@
+import torch, accelerate
+from PIL import Image
+from typing import Union
+from tqdm import tqdm
+from einops import rearrange, repeat
+
+from transformers import AutoProcessor, AutoTokenizer
+from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model
+from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task
+from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit
+from diffsynth.models.general_modules import TimestepEmbeddings
+from diffsynth.models.z_image_text_encoder import ZImageTextEncoder
+from diffsynth.models.flux2_vae import Flux2VAE
+
+
+class AAAPositionalEmbedding(torch.nn.Module):
+ def __init__(self, height=16, width=16, dim=1024):
+ super().__init__()
+ self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))
+ self.text_emb = torch.nn.Parameter(torch.randn((dim,)))
+
+ def forward(self, image, text):
+ height, width = image.shape[-2:]
+ image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)
+ image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear")
+ image_emb = rearrange(image_emb, "B C H W -> B (H W) C")
+ text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)
+ text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1])
+ emb = torch.concat([image_emb, text_emb], dim=1)
+ return emb
+
+
+class AAABlock(torch.nn.Module):
+ def __init__(self, dim=1024, num_heads=32):
+ super().__init__()
+ self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)
+ self.to_q = torch.nn.Linear(dim, dim)
+ self.to_k = torch.nn.Linear(dim, dim)
+ self.to_v = torch.nn.Linear(dim, dim)
+ self.to_out = torch.nn.Linear(dim, dim)
+ self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False)
+ self.ff = torch.nn.Sequential(
+ torch.nn.Linear(dim, dim*3),
+ torch.nn.SiLU(),
+ torch.nn.Linear(dim*3, dim),
+ )
+ self.to_gate = torch.nn.Linear(dim, dim * 2)
+ self.num_heads = num_heads
+
+ def attention(self, emb, pos_emb):
+ emb = self.norm_attn(emb + pos_emb)
+ q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)
+ emb = attention_forward(
+ q, k, v,
+ q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)",
+ dims={"n": self.num_heads},
+ )
+ emb = self.to_out(emb)
+ return emb
+
+ def feed_forward(self, emb, pos_emb):
+ emb = self.norm_mlp(emb + pos_emb)
+ emb = self.ff(emb)
+ return emb
+
+ def forward(self, emb, pos_emb, t_emb):
+ gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)
+ emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)
+ emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)
+ return emb
+
+
+class AAADiT(torch.nn.Module):
+ def __init__(self, dim=1024):
+ super().__init__()
+ self.pos_embedder = AAAPositionalEmbedding(dim=dim)
+ self.timestep_embedder = TimestepEmbeddings(256, dim)
+ self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim))
+ self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim))
+ self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)])
+ self.proj_out = torch.nn.Linear(dim, 128)
+
+ def forward(
+ self,
+ latents,
+ prompt_embeds,
+ timestep,
+ use_gradient_checkpointing=False,
+ use_gradient_checkpointing_offload=False,
+ ):
+ pos_emb = self.pos_embedder(latents, prompt_embeds)
+ t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1)
+ image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C"))
+ text = self.text_embedder(prompt_embeds)
+ emb = torch.concat([image, text], dim=1)
+ for block_id, block in enumerate(self.blocks):
+ emb = gradient_checkpoint_forward(
+ block,
+ use_gradient_checkpointing=use_gradient_checkpointing,
+ use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
+ emb=emb,
+ pos_emb=pos_emb,
+ t_emb=t_emb,
+ )
+ emb = emb[:, :latents.shape[-1] * latents.shape[-2]]
+ emb = self.proj_out(emb)
+ emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1])
+ return emb
+
+
+class AAAImagePipeline(BasePipeline):
+ def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
+ super().__init__(
+ device=device, torch_dtype=torch_dtype,
+ height_division_factor=16, width_division_factor=16,
+ )
+ self.scheduler = FlowMatchScheduler("FLUX.2")
+ self.text_encoder: ZImageTextEncoder = None
+ self.dit: AAADiT = None
+ self.vae: Flux2VAE = None
+ self.tokenizer: AutoProcessor = None
+ self.in_iteration_models = ("dit",)
+ self.units = [
+ AAAUnit_PromptEmbedder(),
+ AAAUnit_NoiseInitializer(),
+ AAAUnit_InputImageEmbedder(),
+ ]
+ self.model_fn = model_fn_aaa
+
+ @staticmethod
+ def from_pretrained(
+ torch_dtype: torch.dtype = torch.bfloat16,
+ device: Union[str, torch.device] = "cuda",
+ model_configs: list[ModelConfig] = [],
+ tokenizer_config: ModelConfig = None,
+ vram_limit: float = None,
+ ):
+ # Initialize pipeline
+ pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)
+ model_pool = pipe.download_and_load_models(model_configs, vram_limit)
+
+ # Fetch models
+ pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
+ pipe.dit = model_pool.fetch_model("aaa_dit")
+ pipe.vae = model_pool.fetch_model("flux2_vae")
+ if tokenizer_config is not None:
+ tokenizer_config.download_if_necessary()
+ pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
+
+ # VRAM Management
+ pipe.vram_management_enabled = pipe.check_vram_management_state()
+ return pipe
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ # Prompt
+ prompt: str,
+ negative_prompt: str = "",
+ cfg_scale: float = 1.0,
+ # Image
+ input_image: Image.Image = None,
+ denoising_strength: float = 1.0,
+ # Shape
+ height: int = 1024,
+ width: int = 1024,
+ # Randomness
+ seed: int = None,
+ rand_device: str = "cpu",
+ # Steps
+ num_inference_steps: int = 30,
+ # Progress bar
+ progress_bar_cmd = tqdm,
+ ):
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)
+
+ # Parameters
+ inputs_posi = {"prompt": prompt}
+ inputs_nega = {"negative_prompt": negative_prompt}
+ inputs_shared = {
+ "cfg_scale": cfg_scale,
+ "input_image": input_image, "denoising_strength": denoising_strength,
+ "height": height, "width": width,
+ "seed": seed, "rand_device": rand_device,
+ "num_inference_steps": num_inference_steps,
+ }
+ for unit in self.units:
+ inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
+
+ # Denoise
+ self.load_models_to_device(self.in_iteration_models)
+ models = {name: getattr(self, name) for name in self.in_iteration_models}
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
+ noise_pred = self.cfg_guided_model_fn(
+ self.model_fn, cfg_scale,
+ inputs_shared, inputs_posi, inputs_nega,
+ **models, timestep=timestep, progress_id=progress_id
+ )
+ inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
+
+ # Decode
+ self.load_models_to_device(['vae'])
+ image = self.vae.decode(inputs_shared["latents"])
+ image = self.vae_output_to_image(image)
+ self.load_models_to_device([])
+
+ return image
+
+
+class AAAUnit_PromptEmbedder(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ seperate_cfg=True,
+ input_params_posi={"prompt": "prompt"},
+ input_params_nega={"prompt": "negative_prompt"},
+ output_params=("prompt_embeds",),
+ onload_model_names=("text_encoder",)
+ )
+ self.hidden_states_layers = (-1,)
+
+ def process(self, pipe: AAAImagePipeline, prompt):
+ pipe.load_models_to_device(self.onload_model_names)
+ text = pipe.tokenizer.apply_chat_template(
+ [{"role": "user", "content": prompt}],
+ tokenize=False,
+ add_generation_prompt=True,
+ enable_thinking=False,
+ )
+ inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device)
+ output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False)
+ prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1)
+ return {"prompt_embeds": prompt_embeds}
+
+
+class AAAUnit_NoiseInitializer(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ input_params=("height", "width", "seed", "rand_device"),
+ output_params=("noise",),
+ )
+
+ def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device):
+ noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
+ return {"noise": noise}
+
+
+class AAAUnit_InputImageEmbedder(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ input_params=("input_image", "noise"),
+ output_params=("latents", "input_latents"),
+ onload_model_names=("vae",)
+ )
+
+ def process(self, pipe: AAAImagePipeline, input_image, noise):
+ if input_image is None:
+ return {"latents": noise, "input_latents": None}
+ pipe.load_models_to_device(['vae'])
+ image = pipe.preprocess_image(input_image)
+ input_latents = pipe.vae.encode(image)
+ if pipe.scheduler.training:
+ return {"latents": noise, "input_latents": input_latents}
+ else:
+ latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
+ return {"latents": latents, "input_latents": input_latents}
+
+
+def model_fn_aaa(
+ dit: AAADiT,
+ latents=None,
+ prompt_embeds=None,
+ timestep=None,
+ use_gradient_checkpointing=False,
+ use_gradient_checkpointing_offload=False,
+ **kwargs,
+):
+ model_output = dit(
+ latents,
+ prompt_embeds,
+ timestep,
+ use_gradient_checkpointing=use_gradient_checkpointing,
+ use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
+ )
+ return model_output
+
+
+class AAATrainingModule(DiffusionTrainingModule):
+ def __init__(self, device):
+ super().__init__()
+ self.pipe = AAAImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device=device,
+ model_configs=[
+ ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
+ ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
+ )
+ self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)
+ self.pipe.freeze_except(["dit"])
+ self.pipe.scheduler.set_timesteps(1000, training=True)
+
+ def forward(self, data):
+ inputs_posi = {"prompt": data["prompt"]}
+ inputs_nega = {"negative_prompt": ""}
+ inputs_shared = {
+ "input_image": data["image"],
+ "height": data["image"].size[1],
+ "width": data["image"].size[0],
+ "cfg_scale": 1,
+ "use_gradient_checkpointing": False,
+ "use_gradient_checkpointing_offload": False,
+ }
+ for unit in self.pipe.units:
+ inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
+ loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)
+ return loss
+
+
+if __name__ == "__main__":
+ accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)
+ dataset = UnifiedDataset(
+ base_path="data/images",
+ metadata_path="data/metadata_merged.csv",
+ max_data_items=10000000,
+ data_file_keys=("image",),
+ main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256)
+ )
+ model = AAATrainingModule(device=accelerator.device)
+ model_logger = ModelLogger(
+ "models/AAA/v1",
+ remove_prefix_in_ckpt="pipe.dit.",
+ )
+ launch_training_task(
+ accelerator, dataset, model, model_logger,
+ learning_rate=2e-4,
+ num_workers=4,
+ save_steps=50000,
+ num_epochs=999999,
+ )
\ No newline at end of file
diff --git a/docs/en/Training/Understanding_Diffusion_models.md b/docs/en/Training/Understanding_Diffusion_models.md
index 5c81b6a..718df25 100644
--- a/docs/en/Training/Understanding_Diffusion_models.md
+++ b/docs/en/Training/Understanding_Diffusion_models.md
@@ -6,7 +6,7 @@ This document introduces the basic principles of Diffusion models to help you un
Diffusion models generate clear images or video content through iterative denoising. We start by explaining the generation process of a data sample $x_0$. Intuitively, in a complete round of denoising, we start from random Gaussian noise $x_T$ and iteratively obtain $x_{T-1}$, $x_{T-2}$, $x_{T-3}$, $\cdots$, gradually reducing the noise content at each step until we finally obtain the noise-free data sample $x_0$.
-(Figure)
+
This process is intuitive, but to understand the details, we need to answer several questions:
@@ -28,7 +28,7 @@ As for the intermediate values $\sigma_{T-1}$, $\sigma_{T-2}$, $\cdots$, $\sigma
At an intermediate step, we can directly synthesize noisy data samples $x_t=(1-\sigma_t)x_0+\sigma_t x_T$.
-(Figure)
+
## How is the iterative denoising computation performed?
@@ -40,8 +40,6 @@ Before understanding the iterative denoising computation, we need to clarify wha
Among these, the guidance condition $c$ is a newly introduced parameter that is input by the user. It can be text describing the image content or a sketch outlining the image structure.
-(Figure)
-
The model's output $\hat \epsilon(x_t,c,t)$ approximately equals $x_T-x_0$, which is the direction of the entire diffusion process (the reverse process of denoising).
Next, we analyze the computation occurring in one iteration. At time step $t$, after the model computes an approximation of $x_T-x_0$, we calculate the next $x_{t-1}$:
@@ -91,8 +89,6 @@ After understanding the iterative denoising process, we next consider how to tra
The training process differs from the generation process. If we retain multi-step iterations during training, the gradient would need to backpropagate through multiple steps, bringing catastrophic time and space complexity. To improve computational efficiency, we randomly select a time step $t$ for training.
-(Figure)
-
The following is pseudocode for the training process:
> Obtain data sample $x_0$ and guidance condition $c$ from the dataset
@@ -113,7 +109,7 @@ The following is pseudocode for the training process:
From theory to practice, more details need to be filled in. Modern Diffusion model architectures have matured, with mainstream architectures following the "three-stage" architecture proposed by Latent Diffusion, including data encoder-decoder, guidance condition encoder, and denoising model.
-(Figure)
+
### Data Encoder-Decoder
diff --git a/docs/zh/Model_Details/Qwen-Image.md b/docs/zh/Model_Details/Qwen-Image.md
index 697438f..18e7d02 100644
--- a/docs/zh/Model_Details/Qwen-Image.md
+++ b/docs/zh/Model_Details/Qwen-Image.md
@@ -85,6 +85,7 @@ graph LR;
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
+|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
diff --git a/docs/zh/Model_Details/Z-Image.md b/docs/zh/Model_Details/Z-Image.md
index c51083a..dd8b3fd 100644
--- a/docs/zh/Model_Details/Z-Image.md
+++ b/docs/zh/Model_Details/Z-Image.md
@@ -52,7 +52,12 @@ image.save("image.jpg")
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
+|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|
+|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
+|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
+|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
+|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|
特殊训练脚本:
@@ -75,6 +80,9 @@ image.save("image.jpg")
* `seed`: 随机种子。默认为 `None`,即完全随机。
* `rand_device`: 生成随机高斯噪声矩阵的计算设备,默认为 `"cpu"`。当设置为 `cuda` 时,在不同 GPU 上会导致不同的生成结果。
* `num_inference_steps`: 推理次数,默认值为 8。
+* `controlnet_inputs`: ControlNet 模型的输入。
+* `edit_image`: 编辑模型的待编辑图像,支持多张图像。
+* `positive_only_lora`: 仅在正向提示词中使用的 LoRA 权重。
如果显存不足,请开启[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文"模型总览"中的表格。
diff --git a/docs/zh/README.md b/docs/zh/README.md
index edcef50..c02665f 100644
--- a/docs/zh/README.md
+++ b/docs/zh/README.md
@@ -77,7 +77,7 @@ graph LR;
本节介绍如何利用 `DiffSynth-Studio` 训练新的模型,帮助科研工作者探索新的模型技术。
-* 从零开始训练模型【coming soon】
+* [从零开始训练模型](/docs/zh/Research_Tutorial/train_from_scratch.md)
* 推理改进优化技术【coming soon】
* 设计可控生成模型【coming soon】
* 创建新的训练范式【coming soon】
diff --git a/docs/zh/Research_Tutorial/train_from_scratch.md b/docs/zh/Research_Tutorial/train_from_scratch.md
new file mode 100644
index 0000000..2c620eb
--- /dev/null
+++ b/docs/zh/Research_Tutorial/train_from_scratch.md
@@ -0,0 +1,477 @@
+# Training a Model from Scratch
+
+DiffSynth-Studio's training engine supports training base models from scratch. This document shows how to train a small text-to-image model with only 0.1B parameters, starting from nothing.
+
+## 1. Building the Model Architecture
+
+### 1.1 The Diffusion Model
+
+From UNet [[1]](https://arxiv.org/abs/1505.04597) [[2]](https://arxiv.org/abs/2112.10752) to DiT [[3]](https://arxiv.org/abs/2212.09748) [[4]](https://arxiv.org/abs/2403.03206), mainstream Diffusion architectures have gone through several evolutions. A Diffusion model typically takes the following inputs:
+
+* Image tensor (`latents`): an encoding of the image, produced by the VAE and partially corrupted with noise
+* Text tensor (`prompt_embeds`): an encoding of the text, produced by the text encoder
+* Timestep (`timestep`): a scalar marking the current stage of the diffusion process
+
+The model outputs a tensor with the same shape as the image tensor, representing the predicted denoising direction. For details on the theory behind Diffusion models, see [Understanding Diffusion Models](/docs/zh/Training/Understanding_Diffusion_models.md). In this tutorial, we build a DiT model with only 0.1B parameters: `AAADiT`.
+
+
+**Model architecture code:**
+
+```python
+import torch, accelerate
+from PIL import Image
+from typing import Union
+from tqdm import tqdm
+from einops import rearrange, repeat
+
+from transformers import AutoProcessor, AutoTokenizer
+from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model
+from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task
+from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit
+from diffsynth.models.general_modules import TimestepEmbeddings
+from diffsynth.models.z_image_text_encoder import ZImageTextEncoder
+from diffsynth.models.flux2_vae import Flux2VAE
+
+
+class AAAPositionalEmbedding(torch.nn.Module):
+ def __init__(self, height=16, width=16, dim=1024):
+ super().__init__()
+ self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))
+ self.text_emb = torch.nn.Parameter(torch.randn((dim,)))
+
+ def forward(self, image, text):
+ height, width = image.shape[-2:]
+ image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)
+ image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear")
+ image_emb = rearrange(image_emb, "B C H W -> B (H W) C")
+ text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)
+ text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1])
+ emb = torch.concat([image_emb, text_emb], dim=1)
+ return emb
+
+
+class AAABlock(torch.nn.Module):
+ def __init__(self, dim=1024, num_heads=32):
+ super().__init__()
+ self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)
+ self.to_q = torch.nn.Linear(dim, dim)
+ self.to_k = torch.nn.Linear(dim, dim)
+ self.to_v = torch.nn.Linear(dim, dim)
+ self.to_out = torch.nn.Linear(dim, dim)
+ self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False)
+ self.ff = torch.nn.Sequential(
+ torch.nn.Linear(dim, dim*3),
+ torch.nn.SiLU(),
+ torch.nn.Linear(dim*3, dim),
+ )
+ self.to_gate = torch.nn.Linear(dim, dim * 2)
+ self.num_heads = num_heads
+
+ def attention(self, emb, pos_emb):
+ emb = self.norm_attn(emb + pos_emb)
+ q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)
+ emb = attention_forward(
+ q, k, v,
+ q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)",
+ dims={"n": self.num_heads},
+ )
+ emb = self.to_out(emb)
+ return emb
+
+ def feed_forward(self, emb, pos_emb):
+ emb = self.norm_mlp(emb + pos_emb)
+ emb = self.ff(emb)
+ return emb
+
+ def forward(self, emb, pos_emb, t_emb):
+ gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)
+ emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)
+ emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)
+ return emb
+
+
+class AAADiT(torch.nn.Module):
+ def __init__(self, dim=1024):
+ super().__init__()
+ self.pos_embedder = AAAPositionalEmbedding(dim=dim)
+ self.timestep_embedder = TimestepEmbeddings(256, dim)
+ self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim))
+ self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim))
+ self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)])
+ self.proj_out = torch.nn.Linear(dim, 128)
+
+ def forward(
+ self,
+ latents,
+ prompt_embeds,
+ timestep,
+ use_gradient_checkpointing=False,
+ use_gradient_checkpointing_offload=False,
+ ):
+ pos_emb = self.pos_embedder(latents, prompt_embeds)
+ t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1)
+ image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C"))
+ text = self.text_embedder(prompt_embeds)
+ emb = torch.concat([image, text], dim=1)
+ for block_id, block in enumerate(self.blocks):
+ emb = gradient_checkpoint_forward(
+ block,
+ use_gradient_checkpointing=use_gradient_checkpointing,
+ use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
+ emb=emb,
+ pos_emb=pos_emb,
+ t_emb=t_emb,
+ )
+ emb = emb[:, :latents.shape[-1] * latents.shape[-2]]
+ emb = self.proj_out(emb)
+ emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1])
+ return emb
+```
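
To make the token bookkeeping in `AAADiT.forward` concrete: the latent grid is flattened into one token per latent pixel, the padded 128-token prompt sequence is appended, and after the transformer blocks only the image tokens are kept (`emb[:, :latents.shape[-1] * latents.shape[-2]]`). Below is a pure-Python sketch of that arithmetic, assuming the 16x VAE downsampling factor and 128-token prompt padding used throughout this tutorial:

```python
# Token layout of AAADiT's sequence: [image tokens | text tokens].
# Assumes a 16x VAE downsampling factor and prompts padded to 128 tokens,
# matching AAAUnit_NoiseInitializer and AAAUnit_PromptEmbedder in this tutorial.

def token_layout(height: int, width: int, text_len: int = 128):
    """Return (image_tokens, text_tokens, total_sequence_length)."""
    lat_h, lat_w = height // 16, width // 16  # latent spatial size
    image_tokens = lat_h * lat_w              # one token per latent pixel
    return image_tokens, text_len, image_tokens + text_len

# At the 256x256 training resolution: 16*16 = 256 image tokens,
# so the transformer processes a sequence of 384 tokens in total.
assert token_layout(256, 256) == (256, 128, 384)
```

The final `proj_out` layer then maps each of the first `H*W` tokens back to a 128-channel latent pixel.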
+
+
+
+### 1.2 Encoder and Decoder Models
+
+Besides the Diffusion model that performs denoising, we need two more models:
+
+* Text encoder: encodes text into tensors. We use the [Qwen/Qwen3-0.6B](https://modelscope.cn/models/Qwen/Qwen3-0.6B) model.
+* VAE: its encoder encodes images into tensors, and its decoder decodes image tensors back into images. We use the VAE from [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B).
+
+Both architectures are already integrated into DiffSynth-Studio, at [/diffsynth/models/z_image_text_encoder.py](/diffsynth/models/z_image_text_encoder.py) and [/diffsynth/models/flux2_vae.py](/diffsynth/models/flux2_vae.py) respectively, so no code changes are needed.
+
+## 2. Building the Pipeline
+
+The document [Building a Pipeline](/docs/zh/Developer_Guide/Building_a_Pipeline.md) explains how to construct a model pipeline. For the model in this tutorial, we likewise build a Pipeline that connects the text encoder, the Diffusion model, and the VAE.
+
+
+**Pipeline code:**
+
+```python
+class AAAImagePipeline(BasePipeline):
+ def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
+ super().__init__(
+ device=device, torch_dtype=torch_dtype,
+ height_division_factor=16, width_division_factor=16,
+ )
+ self.scheduler = FlowMatchScheduler("FLUX.2")
+ self.text_encoder: ZImageTextEncoder = None
+ self.dit: AAADiT = None
+ self.vae: Flux2VAE = None
+ self.tokenizer: AutoTokenizer = None
+ self.in_iteration_models = ("dit",)
+ self.units = [
+ AAAUnit_PromptEmbedder(),
+ AAAUnit_NoiseInitializer(),
+ AAAUnit_InputImageEmbedder(),
+ ]
+ self.model_fn = model_fn_aaa
+
+ @staticmethod
+ def from_pretrained(
+ torch_dtype: torch.dtype = torch.bfloat16,
+ device: Union[str, torch.device] = "cuda",
+ model_configs: list[ModelConfig] = [],
+ tokenizer_config: ModelConfig = None,
+ vram_limit: float = None,
+ ):
+ # Initialize pipeline
+ pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)
+ model_pool = pipe.download_and_load_models(model_configs, vram_limit)
+
+ # Fetch models
+ pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
+ pipe.dit = model_pool.fetch_model("aaa_dit")
+ pipe.vae = model_pool.fetch_model("flux2_vae")
+ if tokenizer_config is not None:
+ tokenizer_config.download_if_necessary()
+ pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
+
+ # VRAM Management
+ pipe.vram_management_enabled = pipe.check_vram_management_state()
+ return pipe
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ # Prompt
+ prompt: str,
+ negative_prompt: str = "",
+ cfg_scale: float = 1.0,
+ # Image
+ input_image: Image.Image = None,
+ denoising_strength: float = 1.0,
+ # Shape
+ height: int = 1024,
+ width: int = 1024,
+ # Randomness
+ seed: int = None,
+ rand_device: str = "cpu",
+ # Steps
+ num_inference_steps: int = 30,
+ # Progress bar
+ progress_bar_cmd = tqdm,
+ ):
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)
+
+ # Parameters
+ inputs_posi = {"prompt": prompt}
+ inputs_nega = {"negative_prompt": negative_prompt}
+ inputs_shared = {
+ "cfg_scale": cfg_scale,
+ "input_image": input_image, "denoising_strength": denoising_strength,
+ "height": height, "width": width,
+ "seed": seed, "rand_device": rand_device,
+ "num_inference_steps": num_inference_steps,
+ }
+ for unit in self.units:
+ inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
+
+ # Denoise
+ self.load_models_to_device(self.in_iteration_models)
+ models = {name: getattr(self, name) for name in self.in_iteration_models}
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
+ noise_pred = self.cfg_guided_model_fn(
+ self.model_fn, cfg_scale,
+ inputs_shared, inputs_posi, inputs_nega,
+ **models, timestep=timestep, progress_id=progress_id
+ )
+ inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
+
+ # Decode
+ self.load_models_to_device(['vae'])
+ image = self.vae.decode(inputs_shared["latents"])
+ image = self.vae_output_to_image(image)
+ self.load_models_to_device([])
+
+ return image
+
+
+class AAAUnit_PromptEmbedder(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ seperate_cfg=True,
+ input_params_posi={"prompt": "prompt"},
+ input_params_nega={"prompt": "negative_prompt"},
+ output_params=("prompt_embeds",),
+ onload_model_names=("text_encoder",)
+ )
+ self.hidden_states_layers = (-1,)
+
+ def process(self, pipe: AAAImagePipeline, prompt):
+ pipe.load_models_to_device(self.onload_model_names)
+ text = pipe.tokenizer.apply_chat_template(
+ [{"role": "user", "content": prompt}],
+ tokenize=False,
+ add_generation_prompt=True,
+ enable_thinking=False,
+ )
+ inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device)
+ output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False)
+ prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1)
+ return {"prompt_embeds": prompt_embeds}
+
+
+class AAAUnit_NoiseInitializer(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ input_params=("height", "width", "seed", "rand_device"),
+ output_params=("noise",),
+ )
+
+ def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device):
+ noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
+ return {"noise": noise}
+
+
+class AAAUnit_InputImageEmbedder(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ input_params=("input_image", "noise"),
+ output_params=("latents", "input_latents"),
+ onload_model_names=("vae",)
+ )
+
+ def process(self, pipe: AAAImagePipeline, input_image, noise):
+ if input_image is None:
+ return {"latents": noise, "input_latents": None}
+ pipe.load_models_to_device(['vae'])
+ image = pipe.preprocess_image(input_image)
+ input_latents = pipe.vae.encode(image)
+ if pipe.scheduler.training:
+ return {"latents": noise, "input_latents": input_latents}
+ else:
+ latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
+ return {"latents": latents, "input_latents": input_latents}
+
+
+def model_fn_aaa(
+ dit: AAADiT,
+ latents=None,
+ prompt_embeds=None,
+ timestep=None,
+ use_gradient_checkpointing=False,
+ use_gradient_checkpointing_offload=False,
+ **kwargs,
+):
+ model_output = dit(
+ latents,
+ prompt_embeds,
+ timestep,
+ use_gradient_checkpointing=use_gradient_checkpointing,
+ use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
+ )
+ return model_output
+```
+
+
+
+## 3. Preparing the Dataset
+
+To validate training quickly, we use the [Pokémon Generation 1](https://modelscope.cn/datasets/DiffSynth-Studio/pokemon-gen1) dataset, mirrored from the open-source project [pokemon-dataset-zh](https://github.com/42arch/pokemon-dataset-zh). It contains the 151 first-generation Pokémon, from Bulbasaur to Mew. To use a different dataset, see [Preparing a Dataset](/docs/zh/Pipeline_Usage/Model_Training.md#准备数据集) and [`diffsynth.core.data`](/docs/zh/API_Reference/core/data.md).
+
+```shell
+modelscope download --dataset DiffSynth-Studio/pokemon-gen1 --local_dir ./data
+```
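
The training script later reads `data/metadata_merged.csv` and accesses `data["image"]` and `data["prompt"]`, so the metadata file presumably has an `image` column (file path) and a `prompt` column (caption). That column layout is an inference from the code, not a documented spec; a minimal sketch of writing such a metadata file:

```python
# Hypothetical metadata layout for UnifiedDataset, inferred from the
# data["image"] / data["prompt"] keys used in the training code; adjust
# the column names and paths if your dataset differs.
import csv
import io

rows = [
    {"image": "images/0001.jpg", "prompt": "green, lizard, plant, Grass, Poison, seed on back"},
    {"image": "images/0004.jpg", "prompt": "orange, lizard, Fire, flame on tail tip"},
]
buf = io.StringIO()
writer = csv.DictWriter(buf, fieldnames=["image", "prompt"])
writer.writeheader()
writer.writerows(rows)
metadata_csv = buf.getvalue()
# The first line of metadata_csv is the header row: "image,prompt"
```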
+
+## 4. Starting Training
+
+Training can be implemented quickly on top of the Pipeline. The complete script is available at [/docs/zh/Research_Tutorial/train_from_scratch.py](/docs/zh/Research_Tutorial/train_from_scratch.py); run `python docs/zh/Research_Tutorial/train_from_scratch.py` to start single-GPU training.
+
+For multi-GPU parallel training, run `accelerate config` to set the relevant parameters, then launch training with `accelerate launch docs/zh/Research_Tutorial/train_from_scratch.py`.
+
+This training script has no stopping condition; terminate it manually when needed. The model converges after roughly 60,000 training steps, which takes 10-20 hours on a single GPU.
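
Conceptually, `FlowMatchSFTLoss` supervises the DiT with the flow-matching target described in [Understanding Diffusion Models](/docs/zh/Training/Understanding_Diffusion_models.md): noisy samples are built as $x_t=(1-\sigma_t)x_0+\sigma_t x_T$ and the model learns to predict $x_T-x_0$. A scalar pure-Python sketch of that pairing (a conceptual illustration, not the actual implementation):

```python
# Flow-matching training pair (scalar sketch; the real loss operates on
# latent tensors inside FlowMatchSFTLoss).
def make_training_pair(x0: float, xT: float, sigma_t: float):
    x_t = (1 - sigma_t) * x0 + sigma_t * xT  # noisy sample fed to the model
    target = xT - x0                         # direction the model must predict
    return x_t, target

# At sigma_t = 0 the sample is clean data; at sigma_t = 1 it is pure noise.
assert make_training_pair(x0=0.0, xT=1.0, sigma_t=0.25) == (0.25, 1.0)
```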
+
+
+
+**Training code:**
+
+```python
+class AAATrainingModule(DiffusionTrainingModule):
+ def __init__(self, device):
+ super().__init__()
+ self.pipe = AAAImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device=device,
+ model_configs=[
+ ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
+ ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
+ )
+ self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)
+ self.pipe.freeze_except(["dit"])
+ self.pipe.scheduler.set_timesteps(1000, training=True)
+
+ def forward(self, data):
+ inputs_posi = {"prompt": data["prompt"]}
+ inputs_nega = {"negative_prompt": ""}
+ inputs_shared = {
+ "input_image": data["image"],
+ "height": data["image"].size[1],
+ "width": data["image"].size[0],
+ "cfg_scale": 1,
+ "use_gradient_checkpointing": False,
+ "use_gradient_checkpointing_offload": False,
+ }
+ for unit in self.pipe.units:
+ inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
+ loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)
+ return loss
+
+
+if __name__ == "__main__":
+ accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)
+ dataset = UnifiedDataset(
+ base_path="data/images",
+ metadata_path="data/metadata_merged.csv",
+ max_data_items=10000000,
+ data_file_keys=("image",),
+ main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256)
+ )
+ model = AAATrainingModule(device=accelerator.device)
+ model_logger = ModelLogger(
+ "models/AAA/v1",
+ remove_prefix_in_ckpt="pipe.dit.",
+ )
+ launch_training_task(
+ accelerator, dataset, model, model_logger,
+ learning_rate=2e-4,
+ num_workers=4,
+ save_steps=50000,
+ num_epochs=999999,
+ )
+```
+
+
+
+## 5. Validating the Trained Model
+
+If you do not want to wait for training to finish, you can directly download [our pre-trained model](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel).
+
+```shell
+modelscope download --model DiffSynth-Studio/AAAMyModel step-600000.safetensors --local_dir models/DiffSynth-Studio/AAAMyModel
+```
+
+**Load the model:**
+
+```python
+from diffsynth.core import load_model
+
+pipe = AAAImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
+ ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
+)
+pipe.dit = load_model(AAADiT, "models/DiffSynth-Studio/AAAMyModel/step-600000.safetensors", torch_dtype=torch.bfloat16, device="cuda")
+```
+
+Run inference to generate the three Generation 1 starter Pokémon. At this stage, the images the model generates closely match the training data.
+
+```python
+for seed, prompt in enumerate([
+ "green, lizard, plant, Grass, Poison, seed on back, red eyes, smiling expression, short stout limbs, sharp claws",
+ "orange, cream, lizard, Fire, flame on tail tip, large eyes, smiling expression, cream-colored belly patch, sharp claws",
+ "蓝色,米色,棕色,乌龟,水系,龟壳,大眼睛,短四肢,卷曲尾巴",
+]):
+ image = pipe(
+ prompt=prompt,
+ negative_prompt=" ",
+ num_inference_steps=30,
+ cfg_scale=10,
+ seed=seed,
+ height=256, width=256,
+ )
+ image.save(f"image_{seed}.jpg")
+```
+
+Run inference to generate Pokémon with "sharp claws". Here, different random seeds produce different image results.
+
+```python
+for seed, prompt in enumerate([
+ "sharp claws",
+ "sharp claws",
+ "sharp claws",
+]):
+ image = pipe(
+ prompt=prompt,
+ negative_prompt=" ",
+ num_inference_steps=30,
+ cfg_scale=10,
+ seed=seed+4,
+ height=256, width=256,
+ )
+ image.save(f"image_sharp_claws_{seed}.jpg")
+```
+
+We now have a small 0.1B text-to-image model. It can already generate the 151 Pokémon, but it cannot generate any other image content. If you scale up the data volume, model parameters, and number of GPUs from this starting point, you can train a much more powerful text-to-image model!
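+
+Since the training entry point already runs under Accelerate, data-parallel scaling is mainly a launch-time change. A minimal sketch (the GPU count below is an illustrative assumption, adjust it to your hardware):
+
+```shell
+# Launch the same training script with 8 data-parallel processes (illustrative)
+accelerate launch --num_processes 8 docs/zh/Research_Tutorial/train_from_scratch.py
+```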
diff --git a/docs/zh/Research_Tutorial/train_from_scratch.py b/docs/zh/Research_Tutorial/train_from_scratch.py
new file mode 100644
index 0000000..328c24d
--- /dev/null
+++ b/docs/zh/Research_Tutorial/train_from_scratch.py
@@ -0,0 +1,341 @@
+import torch, accelerate
+from PIL import Image
+from typing import Union
+from tqdm import tqdm
+from einops import rearrange, repeat
+
+from transformers import AutoProcessor, AutoTokenizer
+from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model
+from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task
+from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit
+from diffsynth.models.general_modules import TimestepEmbeddings
+from diffsynth.models.z_image_text_encoder import ZImageTextEncoder
+from diffsynth.models.flux2_vae import Flux2VAE
+
+
+class AAAPositionalEmbedding(torch.nn.Module):
+ def __init__(self, height=16, width=16, dim=1024):
+ super().__init__()
+ self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))
+ self.text_emb = torch.nn.Parameter(torch.randn((dim,)))
+
+ def forward(self, image, text):
+ height, width = image.shape[-2:]
+ image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)
+ image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear")
+ image_emb = rearrange(image_emb, "B C H W -> B (H W) C")
+ text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)
+ text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1])
+ emb = torch.concat([image_emb, text_emb], dim=1)
+ return emb
+
+
+class AAABlock(torch.nn.Module):
+ def __init__(self, dim=1024, num_heads=32):
+ super().__init__()
+ self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)
+ self.to_q = torch.nn.Linear(dim, dim)
+ self.to_k = torch.nn.Linear(dim, dim)
+ self.to_v = torch.nn.Linear(dim, dim)
+ self.to_out = torch.nn.Linear(dim, dim)
+ self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False)
+ self.ff = torch.nn.Sequential(
+ torch.nn.Linear(dim, dim*3),
+ torch.nn.SiLU(),
+ torch.nn.Linear(dim*3, dim),
+ )
+ self.to_gate = torch.nn.Linear(dim, dim * 2)
+ self.num_heads = num_heads
+
+ def attention(self, emb, pos_emb):
+ emb = self.norm_attn(emb + pos_emb)
+ q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)
+ emb = attention_forward(
+ q, k, v,
+ q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)",
+ dims={"n": self.num_heads},
+ )
+ emb = self.to_out(emb)
+ return emb
+
+ def feed_forward(self, emb, pos_emb):
+ emb = self.norm_mlp(emb + pos_emb)
+ emb = self.ff(emb)
+ return emb
+
+ def forward(self, emb, pos_emb, t_emb):
+ gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)
+ emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)
+ emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)
+ return emb
+
+
+class AAADiT(torch.nn.Module):
+ def __init__(self, dim=1024):
+ super().__init__()
+ self.pos_embedder = AAAPositionalEmbedding(dim=dim)
+ self.timestep_embedder = TimestepEmbeddings(256, dim)
+ self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim))
+ self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim))
+ self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)])
+ self.proj_out = torch.nn.Linear(dim, 128)
+
+ def forward(
+ self,
+ latents,
+ prompt_embeds,
+ timestep,
+ use_gradient_checkpointing=False,
+ use_gradient_checkpointing_offload=False,
+ ):
+ pos_emb = self.pos_embedder(latents, prompt_embeds)
+ t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1)
+ image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C"))
+ text = self.text_embedder(prompt_embeds)
+ emb = torch.concat([image, text], dim=1)
+ for block_id, block in enumerate(self.blocks):
+ emb = gradient_checkpoint_forward(
+ block,
+ use_gradient_checkpointing=use_gradient_checkpointing,
+ use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
+ emb=emb,
+ pos_emb=pos_emb,
+ t_emb=t_emb,
+ )
+ emb = emb[:, :latents.shape[-1] * latents.shape[-2]]
+ emb = self.proj_out(emb)
+ emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1])
+ return emb
+
+
+class AAAImagePipeline(BasePipeline):
+ def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
+ super().__init__(
+ device=device, torch_dtype=torch_dtype,
+ height_division_factor=16, width_division_factor=16,
+ )
+ self.scheduler = FlowMatchScheduler("FLUX.2")
+ self.text_encoder: ZImageTextEncoder = None
+ self.dit: AAADiT = None
+ self.vae: Flux2VAE = None
+        self.tokenizer: AutoTokenizer = None
+ self.in_iteration_models = ("dit",)
+ self.units = [
+ AAAUnit_PromptEmbedder(),
+ AAAUnit_NoiseInitializer(),
+ AAAUnit_InputImageEmbedder(),
+ ]
+ self.model_fn = model_fn_aaa
+
+ @staticmethod
+ def from_pretrained(
+ torch_dtype: torch.dtype = torch.bfloat16,
+ device: Union[str, torch.device] = "cuda",
+ model_configs: list[ModelConfig] = [],
+ tokenizer_config: ModelConfig = None,
+ vram_limit: float = None,
+ ):
+ # Initialize pipeline
+ pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)
+ model_pool = pipe.download_and_load_models(model_configs, vram_limit)
+
+ # Fetch models
+ pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
+ pipe.dit = model_pool.fetch_model("aaa_dit")
+ pipe.vae = model_pool.fetch_model("flux2_vae")
+ if tokenizer_config is not None:
+ tokenizer_config.download_if_necessary()
+ pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
+
+ # VRAM Management
+ pipe.vram_management_enabled = pipe.check_vram_management_state()
+ return pipe
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ # Prompt
+ prompt: str,
+ negative_prompt: str = "",
+ cfg_scale: float = 1.0,
+ # Image
+ input_image: Image.Image = None,
+ denoising_strength: float = 1.0,
+ # Shape
+ height: int = 1024,
+ width: int = 1024,
+ # Randomness
+ seed: int = None,
+ rand_device: str = "cpu",
+ # Steps
+ num_inference_steps: int = 30,
+ # Progress bar
+ progress_bar_cmd = tqdm,
+ ):
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)
+
+ # Parameters
+ inputs_posi = {"prompt": prompt}
+ inputs_nega = {"negative_prompt": negative_prompt}
+ inputs_shared = {
+ "cfg_scale": cfg_scale,
+ "input_image": input_image, "denoising_strength": denoising_strength,
+ "height": height, "width": width,
+ "seed": seed, "rand_device": rand_device,
+ "num_inference_steps": num_inference_steps,
+ }
+ for unit in self.units:
+ inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
+
+ # Denoise
+ self.load_models_to_device(self.in_iteration_models)
+ models = {name: getattr(self, name) for name in self.in_iteration_models}
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
+ noise_pred = self.cfg_guided_model_fn(
+ self.model_fn, cfg_scale,
+ inputs_shared, inputs_posi, inputs_nega,
+ **models, timestep=timestep, progress_id=progress_id
+ )
+ inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
+
+ # Decode
+ self.load_models_to_device(['vae'])
+ image = self.vae.decode(inputs_shared["latents"])
+ image = self.vae_output_to_image(image)
+ self.load_models_to_device([])
+
+ return image
+
+
+class AAAUnit_PromptEmbedder(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ seperate_cfg=True,
+ input_params_posi={"prompt": "prompt"},
+ input_params_nega={"prompt": "negative_prompt"},
+ output_params=("prompt_embeds",),
+ onload_model_names=("text_encoder",)
+ )
+ self.hidden_states_layers = (-1,)
+
+ def process(self, pipe: AAAImagePipeline, prompt):
+ pipe.load_models_to_device(self.onload_model_names)
+ text = pipe.tokenizer.apply_chat_template(
+ [{"role": "user", "content": prompt}],
+ tokenize=False,
+ add_generation_prompt=True,
+ enable_thinking=False,
+ )
+ inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device)
+ output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False)
+ prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1)
+ return {"prompt_embeds": prompt_embeds}
+
+
+class AAAUnit_NoiseInitializer(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ input_params=("height", "width", "seed", "rand_device"),
+ output_params=("noise",),
+ )
+
+ def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device):
+ noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
+ return {"noise": noise}
+
+
+class AAAUnit_InputImageEmbedder(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ input_params=("input_image", "noise"),
+ output_params=("latents", "input_latents"),
+ onload_model_names=("vae",)
+ )
+
+ def process(self, pipe: AAAImagePipeline, input_image, noise):
+ if input_image is None:
+ return {"latents": noise, "input_latents": None}
+ pipe.load_models_to_device(['vae'])
+ image = pipe.preprocess_image(input_image)
+ input_latents = pipe.vae.encode(image)
+ if pipe.scheduler.training:
+ return {"latents": noise, "input_latents": input_latents}
+ else:
+ latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
+ return {"latents": latents, "input_latents": input_latents}
+
+
+def model_fn_aaa(
+ dit: AAADiT,
+ latents=None,
+ prompt_embeds=None,
+ timestep=None,
+ use_gradient_checkpointing=False,
+ use_gradient_checkpointing_offload=False,
+ **kwargs,
+):
+ model_output = dit(
+ latents,
+ prompt_embeds,
+ timestep,
+ use_gradient_checkpointing=use_gradient_checkpointing,
+ use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
+ )
+ return model_output
+
+
+class AAATrainingModule(DiffusionTrainingModule):
+ def __init__(self, device):
+ super().__init__()
+ self.pipe = AAAImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device=device,
+ model_configs=[
+ ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
+ ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
+ )
+ self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)
+ self.pipe.freeze_except(["dit"])
+ self.pipe.scheduler.set_timesteps(1000, training=True)
+
+ def forward(self, data):
+ inputs_posi = {"prompt": data["prompt"]}
+ inputs_nega = {"negative_prompt": ""}
+ inputs_shared = {
+ "input_image": data["image"],
+ "height": data["image"].size[1],
+ "width": data["image"].size[0],
+ "cfg_scale": 1,
+ "use_gradient_checkpointing": False,
+ "use_gradient_checkpointing_offload": False,
+ }
+ for unit in self.pipe.units:
+ inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
+ loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)
+ return loss
+
+
+if __name__ == "__main__":
+ accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)
+ dataset = UnifiedDataset(
+ base_path="data/images",
+ metadata_path="data/metadata_merged.csv",
+ max_data_items=10000000,
+ data_file_keys=("image",),
+ main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256)
+ )
+ model = AAATrainingModule(device=accelerator.device)
+ model_logger = ModelLogger(
+ "models/AAA/v1",
+ remove_prefix_in_ckpt="pipe.dit.",
+ )
+ launch_training_task(
+ accelerator, dataset, model, model_logger,
+ learning_rate=2e-4,
+ num_workers=4,
+ save_steps=50000,
+ num_epochs=999999,
+ )
\ No newline at end of file
diff --git a/docs/zh/Training/Understanding_Diffusion_models.md b/docs/zh/Training/Understanding_Diffusion_models.md
index 576edc9..7613dc8 100644
--- a/docs/zh/Training/Understanding_Diffusion_models.md
+++ b/docs/zh/Training/Understanding_Diffusion_models.md
@@ -6,7 +6,7 @@
Diffusion 模型通过多步迭代式地去噪(denoise)生成清晰的图像或视频内容,我们从一个数据样本 $x_0$ 的生成过程开始讲起。直观地,在完整的一轮 denoise 过程中,我们从随机高斯噪声 $x_T$ 开始,通过迭代依次得到 $x_{T-1}$、$x_{T-2}$、$x_{T-3}$、$\cdots$,在每一步中逐渐减少噪声含量,最终得到不含噪声的数据样本 $x_0$。
-(图)
+
这个过程是很直观的,但如果要理解其中的细节,我们就需要回答这几个问题:
@@ -28,7 +28,7 @@ Diffusion 模型通过多步迭代式地去噪(denoise)生成清晰的图像
那么在中间的某一步,我们可以直接合成含噪声的数据样本 $x_t=(1-\sigma_t)x_0+\sigma_t x_T$。
-(图)
+
## 迭代去噪的计算是如何进行的?
@@ -40,8 +40,6 @@ Diffusion 模型通过多步迭代式地去噪(denoise)生成清晰的图像
其中,引导条件 $c$ 是新引入的参数,它是由用户输入的,可以是用于描述图像内容的文本,也可以是用于勾勒图像结构的线稿图。
-(图)
-
而模型的输出 $\hat \epsilon(x_t,c,t)$,则近似地等于 $x_T-x_0$,也就是整个扩散过程(去噪过程的反向过程)的方向。
接下来我们分析一步迭代中发生的计算,在时间步 $t$,模型通过计算得到近似的 $x_T-x_0$ 后,我们计算下一步的 $x_{t-1}$:
@@ -89,8 +87,6 @@ $$
训练过程不同于生成过程,如果我们在训练过程中保留多步迭代,那么梯度需经过多步回传,带来的时间和空间复杂度是灾难性的。为了提高计算效率,我们在训练中随机选择某一时间步 $t$ 进行训练。
-(图)
-
以下是训练过程的伪代码
> 从数据集获取数据样本 $x_0$ 和引导条件 $c$
@@ -111,7 +107,7 @@ $$
从理论到实践,还需要填充更多细节。现代 Diffusion 模型架构已经发展成熟,主流的架构沿用了 Latent Diffusion 所提出的“三段式”架构,包括数据编解码器、引导条件编码器、去噪模型三部分。
-(图)
+
### 数据编解码器
diff --git a/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py b/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py
new file mode 100644
index 0000000..098a77c
--- /dev/null
+++ b/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py
@@ -0,0 +1,53 @@
+from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, FlowMatchScheduler
+from modelscope import dataset_snapshot_download
+from PIL import Image
+import torch
+
+pipe = QwenImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="Qwen/Qwen-Image-Edit-2511", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
+)
+
+lora = ModelConfig(
+ model_id="lightx2v/Qwen-Image-Edit-2511-Lightning",
+ origin_file_pattern="Qwen-Image-Edit-2511-Lightning-4steps-V1.0-bf16.safetensors"
+)
+pipe.load_lora(pipe.dit, lora, alpha=8/64)
+pipe.scheduler = FlowMatchScheduler("Qwen-Image-Lightning")
+
+
+dataset_snapshot_download(
+ "DiffSynth-Studio/example_image_dataset",
+ allow_file_pattern="qwen_image_edit/*",
+ local_dir="data/example_image_dataset",
+)
+
+prompt = "生成这两个人的合影"
+edit_image = [
+ Image.open("data/example_image_dataset/qwen_image_edit/image1.jpg"),
+ Image.open("data/example_image_dataset/qwen_image_edit/image2.jpg"),
+]
+image = pipe(
+ prompt,
+ edit_image=edit_image,
+ seed=1,
+ num_inference_steps=4,
+ height=1152,
+ width=896,
+ edit_image_auto_resize=True,
+ zero_cond_t=True, # This is a special parameter introduced by Qwen-Image-Edit-2511
+ cfg_scale=1.0,
+)
+image.save("image.jpg")
+
+# Qwen-Image-Edit-2511 is a multi-image editing model.
+# Always pass `edit_image` as a list, even when it contains only one image:
+# edit_image = [Image.open("image.jpg")]
+# Do not pass a bare image:
+# edit_image = Image.open("image.jpg")
diff --git a/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py
new file mode 100644
index 0000000..cbe43a2
--- /dev/null
+++ b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py
@@ -0,0 +1,63 @@
+from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, FlowMatchScheduler
+from modelscope import dataset_snapshot_download
+from PIL import Image
+import torch
+
+vram_config = {
+ "offload_dtype": "disk",
+ "offload_device": "disk",
+ "onload_dtype": torch.float8_e4m3fn,
+ "onload_device": "cpu",
+ "preparing_dtype": torch.float8_e4m3fn,
+ "preparing_device": "cuda",
+ "computation_dtype": torch.bfloat16,
+ "computation_device": "cuda",
+}
+pipe = QwenImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="Qwen/Qwen-Image-Edit-2511", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
+ ],
+ processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
+)
+
+lora = ModelConfig(
+ model_id="lightx2v/Qwen-Image-Edit-2511-Lightning",
+ origin_file_pattern="Qwen-Image-Edit-2511-Lightning-4steps-V1.0-bf16.safetensors"
+)
+pipe.load_lora(pipe.dit, lora, alpha=8/64)
+pipe.scheduler = FlowMatchScheduler("Qwen-Image-Lightning")
+
+
+dataset_snapshot_download(
+ "DiffSynth-Studio/example_image_dataset",
+ allow_file_pattern="qwen_image_edit/*",
+ local_dir="data/example_image_dataset",
+)
+
+prompt = "生成这两个人的合影"
+edit_image = [
+ Image.open("data/example_image_dataset/qwen_image_edit/image1.jpg"),
+ Image.open("data/example_image_dataset/qwen_image_edit/image2.jpg"),
+]
+image = pipe(
+ prompt,
+ edit_image=edit_image,
+ seed=1,
+ num_inference_steps=4,
+ height=1152,
+ width=896,
+ edit_image_auto_resize=True,
+ zero_cond_t=True, # This is a special parameter introduced by Qwen-Image-Edit-2511
+ cfg_scale=1.0,
+)
+image.save("image.jpg")
+
+# Qwen-Image-Edit-2511 is a multi-image editing model.
+# Always pass `edit_image` as a list, even when it contains only one image:
+# edit_image = [Image.open("image.jpg")]
+# Do not pass a bare image:
+# edit_image = Image.open("image.jpg")
diff --git a/examples/z_image/model_inference/Z-Image-i2L.py b/examples/z_image/model_inference/Z-Image-i2L.py
new file mode 100644
index 0000000..82b7ace
--- /dev/null
+++ b/examples/z_image/model_inference/Z-Image-i2L.py
@@ -0,0 +1,61 @@
+from diffsynth.pipelines.z_image import (
+ ZImagePipeline, ModelConfig,
+ ZImageUnit_Image2LoRAEncode, ZImageUnit_Image2LoRADecode
+)
+from modelscope import snapshot_download
+from safetensors.torch import save_file
+import torch
+from PIL import Image
+
+# Use `vram_config` to enable LoRA hot-loading
+vram_config = {
+ "offload_dtype": torch.bfloat16,
+ "offload_device": "cuda",
+ "onload_dtype": torch.bfloat16,
+ "onload_device": "cuda",
+ "preparing_dtype": torch.bfloat16,
+ "preparing_device": "cuda",
+ "computation_dtype": torch.bfloat16,
+ "computation_device": "cuda",
+}
+
+# Load models
+pipe = ZImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="Tongyi-MAI/Z-Image", origin_file_pattern="transformer/*.safetensors", **vram_config),
+ ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
+ ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors"),
+ ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors"),
+ ModelConfig(model_id="DiffSynth-Studio/Z-Image-i2L", origin_file_pattern="model.safetensors"),
+ ],
+ tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
+)
+
+# Load images
+snapshot_download(
+ model_id="DiffSynth-Studio/Z-Image-i2L",
+ allow_file_pattern="assets/style/*",
+ local_dir="data/Z-Image-i2L_style_input"
+)
+images = [Image.open(f"data/Z-Image-i2L_style_input/assets/style/1/{i}.jpg") for i in range(4)]
+
+# Image to LoRA
+with torch.no_grad():
+ embs = ZImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images)
+ lora = ZImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"]
+save_file(lora, "lora.safetensors")
+
+# Generate images
+prompt = "a cat"
+negative_prompt = "泛黄,发绿,模糊,低分辨率,低质量图像,扭曲的肢体,诡异的外观,丑陋,AI感,噪点,网格感,JPEG压缩条纹,异常的肢体,水印,乱码,意义不明的字符"
+image = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ seed=0, cfg_scale=4, num_inference_steps=50,
+ positive_only_lora=lora,
+ sigma_shift=8
+)
+image.save("image.jpg")
diff --git a/examples/z_image/model_inference/Z-Image.py b/examples/z_image/model_inference/Z-Image.py
new file mode 100644
index 0000000..6dca342
--- /dev/null
+++ b/examples/z_image/model_inference/Z-Image.py
@@ -0,0 +1,17 @@
+from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
+import torch
+
+
+pipe = ZImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="Tongyi-MAI/Z-Image", origin_file_pattern="transformer/*.safetensors"),
+ ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
+ ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
+)
+prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
+image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
+image.save("image_Z-Image.jpg")
diff --git a/examples/z_image/model_inference_low_vram/Z-Image-i2L.py b/examples/z_image/model_inference_low_vram/Z-Image-i2L.py
new file mode 100644
index 0000000..98b3ba3
--- /dev/null
+++ b/examples/z_image/model_inference_low_vram/Z-Image-i2L.py
@@ -0,0 +1,61 @@
+from diffsynth.pipelines.z_image import (
+ ZImagePipeline, ModelConfig,
+ ZImageUnit_Image2LoRAEncode, ZImageUnit_Image2LoRADecode
+)
+from modelscope import snapshot_download
+from safetensors.torch import save_file
+import torch
+from PIL import Image
+
+# Use `vram_config` to enable LoRA hot-loading
+vram_config = {
+ "offload_dtype": torch.bfloat16,
+ "offload_device": "cpu",
+ "onload_dtype": torch.bfloat16,
+ "onload_device": "cpu",
+ "preparing_dtype": torch.bfloat16,
+ "preparing_device": "cuda",
+ "computation_dtype": torch.bfloat16,
+ "computation_device": "cuda",
+}
+
+# Load models
+pipe = ZImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="Tongyi-MAI/Z-Image", origin_file_pattern="transformer/*.safetensors", **vram_config),
+ ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
+ ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
+ ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors", **vram_config),
+ ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors", **vram_config),
+ ModelConfig(model_id="DiffSynth-Studio/Z-Image-i2L", origin_file_pattern="model.safetensors", **vram_config),
+ ],
+ tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
+)
+
+# Load images
+snapshot_download(
+ model_id="DiffSynth-Studio/Z-Image-i2L",
+ allow_file_pattern="assets/style/*",
+ local_dir="data/Z-Image-i2L_style_input"
+)
+images = [Image.open(f"data/Z-Image-i2L_style_input/assets/style/1/{i}.jpg") for i in range(4)]
+
+# Image to LoRA
+with torch.no_grad():
+ embs = ZImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images)
+ lora = ZImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"]
+save_file(lora, "lora.safetensors")
+
+# Generate images
+prompt = "a cat"
+negative_prompt = "泛黄,发绿,模糊,低分辨率,低质量图像,扭曲的肢体,诡异的外观,丑陋,AI感,噪点,网格感,JPEG压缩条纹,异常的肢体,水印,乱码,意义不明的字符"
+image = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ seed=0, cfg_scale=4, num_inference_steps=50,
+ positive_only_lora=lora,
+ sigma_shift=8
+)
+image.save("image.jpg")
diff --git a/examples/z_image/model_inference_low_vram/Z-Image.py b/examples/z_image/model_inference_low_vram/Z-Image.py
new file mode 100644
index 0000000..344ae50
--- /dev/null
+++ b/examples/z_image/model_inference_low_vram/Z-Image.py
@@ -0,0 +1,26 @@
+from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
+import torch
+
+vram_config = {
+ "offload_dtype": torch.bfloat16,
+ "offload_device": "cpu",
+ "onload_dtype": torch.bfloat16,
+ "onload_device": "cpu",
+ "preparing_dtype": torch.bfloat16,
+ "preparing_device": "cuda",
+ "computation_dtype": torch.bfloat16,
+ "computation_device": "cuda",
+}
+pipe = ZImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="Tongyi-MAI/Z-Image", origin_file_pattern="transformer/*.safetensors", **vram_config),
+ ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
+ ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
+ ],
+ tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
+)
+prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
+image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
+image.save("image_Z-Image.jpg")
diff --git a/examples/z_image/model_training/full/Z-Image.sh b/examples/z_image/model_training/full/Z-Image.sh
new file mode 100644
index 0000000..2136324
--- /dev/null
+++ b/examples/z_image/model_training/full/Z-Image.sh
@@ -0,0 +1,14 @@
+# This example is tested on 8*A100
+accelerate launch --config_file examples/z_image/model_training/full/accelerate_config.yaml examples/z_image/model_training/train.py \
+ --dataset_base_path data/example_image_dataset \
+ --dataset_metadata_path data/example_image_dataset/metadata.csv \
+ --max_pixels 1048576 \
+ --dataset_repeat 400 \
+ --model_id_with_origin_paths "Tongyi-MAI/Z-Image:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \
+ --learning_rate 1e-5 \
+ --num_epochs 2 \
+ --remove_prefix_in_ckpt "pipe.dit." \
+ --output_path "./models/train/Z-Image_full" \
+ --trainable_models "dit" \
+ --use_gradient_checkpointing \
+ --dataset_num_workers 8
diff --git a/examples/z_image/model_training/lora/Z-Image.sh b/examples/z_image/model_training/lora/Z-Image.sh
new file mode 100644
index 0000000..b660eef
--- /dev/null
+++ b/examples/z_image/model_training/lora/Z-Image.sh
@@ -0,0 +1,15 @@
+accelerate launch examples/z_image/model_training/train.py \
+ --dataset_base_path data/example_image_dataset \
+ --dataset_metadata_path data/example_image_dataset/metadata.csv \
+ --max_pixels 1048576 \
+ --dataset_repeat 50 \
+ --model_id_with_origin_paths "Tongyi-MAI/Z-Image:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \
+ --learning_rate 1e-4 \
+ --num_epochs 5 \
+ --remove_prefix_in_ckpt "pipe.dit." \
+ --output_path "./models/train/Z-Image_lora" \
+ --lora_base_model "dit" \
+ --lora_target_modules "to_q,to_k,to_v,to_out.0,w1,w2,w3" \
+ --lora_rank 32 \
+ --use_gradient_checkpointing \
+ --dataset_num_workers 8
diff --git a/examples/z_image/model_training/validate_full/Z-Image.py b/examples/z_image/model_training/validate_full/Z-Image.py
new file mode 100644
index 0000000..b2a1d8e
--- /dev/null
+++ b/examples/z_image/model_training/validate_full/Z-Image.py
@@ -0,0 +1,20 @@
+from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
+from diffsynth.core import load_state_dict
+import torch
+
+
+pipe = ZImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="Tongyi-MAI/Z-Image", origin_file_pattern="transformer/*.safetensors"),
+ ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
+ ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
+)
+state_dict = load_state_dict("./models/train/Z-Image_full/epoch-1.safetensors", torch_dtype=torch.bfloat16)
+pipe.dit.load_state_dict(state_dict)
+prompt = "a dog"
+image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
+image.save("image.jpg")
diff --git a/examples/z_image/model_training/validate_lora/Z-Image.py b/examples/z_image/model_training/validate_lora/Z-Image.py
new file mode 100644
index 0000000..d12356f
--- /dev/null
+++ b/examples/z_image/model_training/validate_lora/Z-Image.py
@@ -0,0 +1,18 @@
+from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
+import torch
+
+
+pipe = ZImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="Tongyi-MAI/Z-Image", origin_file_pattern="transformer/*.safetensors"),
+ ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
+ ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
+)
+pipe.load_lora(pipe.dit, "./models/train/Z-Image_lora/epoch-4.safetensors")
+prompt = "a dog"
+image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
+image.save("image.jpg")
diff --git a/pyproject.toml b/pyproject.toml
index de82279..9a5075b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "diffsynth"
-version = "2.0.3"
+version = "2.0.4"
description = "Enjoy the magic of Diffusion models!"
authors = [{name = "ModelScope Team"}]
license = {text = "Apache-2.0"}