From 544c391936b6b9c301b99b070996f97a57217871 Mon Sep 17 00:00:00 2001 From: feng0w0 Date: Mon, 12 Jan 2026 11:24:11 +0800 Subject: [PATCH 01/16] [model][NPU]:Wan model rope use torch.complex64 in NPU --- docs/en/Pipeline_Usage/GPU_support.md | 2 +- docs/zh/Pipeline_Usage/GPU_support.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/en/Pipeline_Usage/GPU_support.md b/docs/en/Pipeline_Usage/GPU_support.md index 6c27de7..aba5706 100644 --- a/docs/en/Pipeline_Usage/GPU_support.md +++ b/docs/en/Pipeline_Usage/GPU_support.md @@ -59,7 +59,7 @@ save_video(video, "video.mp4", fps=15, quality=5) ``` ### Training -NPU startup script samples have been added for each type of model,the scripts are stored in the `examples/xxx/special/npu_scripts`, for example `examples/wanvideo/model_training/special/npu_scripts/Wan2.2-T2V-A14B-NPU.sh`. +NPU startup script samples have been added for each type of model,the scripts are stored in the `examples/xxx/special/npu_training`, for example `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`. In the NPU training scripts, NPU specific environment variables that can optimize performance have been added, and relevant parameters have been enabled for specific models. 
diff --git a/docs/zh/Pipeline_Usage/GPU_support.md b/docs/zh/Pipeline_Usage/GPU_support.md index b955f56..8124147 100644 --- a/docs/zh/Pipeline_Usage/GPU_support.md +++ b/docs/zh/Pipeline_Usage/GPU_support.md @@ -59,7 +59,7 @@ save_video(video, "video.mp4", fps=15, quality=5) ``` ### 训练 -当前已为每类模型添加NPU的启动脚本样例,脚本存放在`examples/xxx/special/npu_scripts`目录下,例如 `examples/wanvideo/model_training/special/npu_scripts/Wan2.2-T2V-A14B-NPU.sh`。 +当前已为每类模型添加NPU的启动脚本样例,脚本存放在`examples/xxx/special/npu_training`目录下,例如 `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`。 在NPU训练脚本中,添加了可以优化性能的NPU特有环境变量,并针对特定模型开启了相关参数。 From 6be244233a5706d0cf7e0fc8f019566f8f0dca8f Mon Sep 17 00:00:00 2001 From: feng0w0 Date: Mon, 12 Jan 2026 11:34:41 +0800 Subject: [PATCH 02/16] [model][NPU]:Wan model rope use torch.complex64 in NPU --- diffsynth/core/device/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffsynth/core/device/__init__.py b/diffsynth/core/device/__init__.py index 8373471..889d682 100644 --- a/diffsynth/core/device/__init__.py +++ b/diffsynth/core/device/__init__.py @@ -1,2 +1,2 @@ from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type, get_device_name -from .npu_compatible_device import IS_NPU_AVAILABLE +from .npu_compatible_device import IS_NPU_AVAILABLE, IS_CUDA_AVAILABLE From 03e530dc39ff47f05ddfc51a9ac3b613f08568c0 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 12 Jan 2026 17:20:01 +0800 Subject: [PATCH 03/16] support qwen-image-layered-control --- README.md | 5 +++ README_zh.md | 5 +++ docs/en/Model_Details/Qwen-Image.md | 1 + docs/zh/Model_Details/Qwen-Image.md | 1 + .../Qwen-Image-Layered-Control.py | 34 ++++++++++++++ .../Qwen-Image-Layered-Control.py | 44 +++++++++++++++++++ .../full/Qwen-Image-Layered-Control.sh | 18 ++++++++ .../lora/Qwen-Image-Layered-Control.sh | 20 +++++++++ .../Qwen-Image-Layered-Control.py | 26 +++++++++++ .../Qwen-Image-Layered-Control.py | 25 
+++++++++++ 10 files changed, 179 insertions(+) create mode 100644 examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py create mode 100644 examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py create mode 100644 examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh create mode 100644 examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh create mode 100644 examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py create mode 100644 examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py diff --git a/README.md b/README.md index 2e4a41a..69f70e1 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,10 @@ We believe that a well-developed open-source code framework can lower the thresh > Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand. +- **January 12, 2026**: We trained and open-sourced a text-guided image layer separation model ([Model Link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)). Given an input image and a textual description, the model isolates the image layer corresponding to the described content. + +- **December 24, 2025**: Based on Qwen-Image-Edit-2511, we trained an In-Context Editing LoRA model ([Model Link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)). This model takes three images as input (Image A, Image B, and Image C), and automatically analyzes the transformation from Image A to Image B, then applies the same transformation to Image C to generate Image D. 
For more details, please refer to our blog post ([Chinese version](https://mp.weixin.qq.com/s/41aEiN3lXKGCJs1-we4Q2g), [English version](https://huggingface.co/blog/kelseye/qwen-image-edit-2511-icedit-lora)). + - **December 9, 2025** We release a wild model based on DiffSynth-Studio 2.0: [Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L) (Image-to-LoRA). This model takes an image as input and outputs a LoRA. Although this version still has significant room for improvement in terms of generalization, detail preservation, and other aspects, we are open-sourcing these models to inspire more innovative research. For more details, please refer to our [blog](https://huggingface.co/blog/kelseye/qwen-image-i2l). - **December 4, 2025** DiffSynth-Studio 2.0 released! Many new features online @@ -401,6 +405,7 @@ Example code for Qwen-Image is available at: [/examples/qwen_image/](/examples/q |[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)| 
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)| |[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)| +|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)| 
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| |[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| |[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)| diff --git a/README_zh.md b/README_zh.md index 8deb8bf..250c48e 100644 --- a/README_zh.md +++ b/README_zh.md @@ -33,6 +33,10 @@ DiffSynth 目前包括两个开源项目: > 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 负责,因此新功能的开发进展会比较缓慢,issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。 +- **2026年1月12日** 我们训练并开源了一个文本引导的图层拆分模型([模型链接](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)),这一模型输入一张图与一段文本描述,模型会将图像中与文本描述相关的图层拆分出来。 + +- **2025年12月24日** 我们基于 Qwen-Image-Edit-2511 训练了一个 In-Context Editing LoRA 模型([模型链接](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)),这个模型可以输入三张图:图A、图B、图C,模型会自行分析图A到图B的变化,并将这样的变化应用到图C,生成图D。更多细节请阅读我们的 
blog([中文版](https://mp.weixin.qq.com/s/41aEiN3lXKGCJs1-we4Q2g)、[英文版](https://huggingface.co/blog/kelseye/qwen-image-edit-2511-icedit-lora))。 + - **2025年12月9日** 我们基于 DiffSynth-Studio 2.0 训练了一个疯狂的模型:[Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)(Image to LoRA)。这一模型以图像为输入,以 LoRA 为输出。尽管这个版本的模型在泛化能力、细节保持能力等方面还有很大改进空间,我们将这些模型开源,以启发更多创新性的研究工作。更多细节,请参考我们的 [blog](https://huggingface.co/blog/kelseye/qwen-image-i2l)。 - **2025年12月4日** DiffSynth-Studio 2.0 发布!众多新功能上线 @@ -401,6 +405,7 @@ Qwen-Image 的示例代码位于:[/examples/qwen_image/](/examples/qwen_image/ |[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)| |[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)| 
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)| +|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)| |[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| |[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| 
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)| diff --git a/docs/en/Model_Details/Qwen-Image.md b/docs/en/Model_Details/Qwen-Image.md index 3a7c1e6..08b8a35 100644 --- a/docs/en/Model_Details/Qwen-Image.md +++ b/docs/en/Model_Details/Qwen-Image.md @@ -86,6 +86,7 @@ graph LR; | [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) | |[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)| 
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)| +|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)| | [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) | | [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) | | 
[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py) | diff --git a/docs/zh/Model_Details/Qwen-Image.md b/docs/zh/Model_Details/Qwen-Image.md index 8fe9de4..697438f 100644 --- a/docs/zh/Model_Details/Qwen-Image.md +++ b/docs/zh/Model_Details/Qwen-Image.md @@ -86,6 +86,7 @@ graph LR; |[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)| |[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)| 
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)| +|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)| |[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| |[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| 
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)| diff --git a/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py b/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py new file mode 100644 index 0000000..5ce82c7 --- /dev/null +++ b/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py @@ -0,0 +1,34 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from modelscope import snapshot_download +from PIL import Image +import torch + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Layered-Control", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"), +) + +snapshot_download( + model_id="DiffSynth-Studio/Qwen-Image-Layered-Control", + allow_file_pattern="assets/image_1_input.png", + local_dir="data/layered_input" +) + +prompt = "A cartoon skeleton character wearing a purple hat and holding a gift box" +input_image = Image.open("data/layered_input/assets/image_1_input.png").convert("RGBA").resize((1024, 1024)) +images = pipe( + prompt, + seed=0, + num_inference_steps=30, cfg_scale=4, + height=1024, width=1024, + layer_input_image=input_image, + layer_num=0, +) 
+images[0].save("image.png") diff --git a/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py new file mode 100644 index 0000000..eb5c77d --- /dev/null +++ b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py @@ -0,0 +1,44 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from modelscope import snapshot_download +from PIL import Image +import torch + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Layered-Control", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"), +) + +snapshot_download( + model_id="DiffSynth-Studio/Qwen-Image-Layered-Control", + allow_file_pattern="assets/image_1_input.png", + local_dir="data/layered_input" +) + +prompt = "A cartoon skeleton character wearing a purple hat and holding a gift box" +input_image = Image.open("data/layered_input/assets/image_1_input.png").convert("RGBA").resize((1024, 1024)) +images = pipe( + prompt, + seed=0, + num_inference_steps=30, cfg_scale=4, + height=1024, width=1024, + layer_input_image=input_image, + layer_num=0, +) +images[0].save("image.png") diff --git 
a/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh b/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh new file mode 100644 index 0000000..14a3cb4 --- /dev/null +++ b/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh @@ -0,0 +1,18 @@ +# Example Dataset: https://modelscope.cn/datasets/DiffSynth-Studio/example_image_dataset/tree/master/layer + +accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset/layer \ + --dataset_metadata_path data/example_image_dataset/layer/metadata_layered_control.json \ + --data_file_keys "image,layer_input_image" \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "DiffSynth-Studio/Qwen-Image-Layered-Control:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image-Layered:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Qwen-Image-Layered-Control_full" \ + --trainable_models "dit" \ + --extra_inputs "layer_num,layer_input_image" \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --find_unused_parameters diff --git a/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh b/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh new file mode 100644 index 0000000..397c975 --- /dev/null +++ b/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh @@ -0,0 +1,20 @@ +# Example Dataset: https://modelscope.cn/datasets/DiffSynth-Studio/example_image_dataset/tree/master/layer + +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset/layer \ + --dataset_metadata_path data/example_image_dataset/layer/metadata_layered_control.json \ + --data_file_keys "image,layer_input_image" \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "DiffSynth-Studio/Qwen-Image-Layered-Control:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image-Layered:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Qwen-Image-Layered-Control_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 32 \ + --extra_inputs "layer_num,layer_input_image" \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --find_unused_parameters diff --git a/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py b/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py new file mode 100644 index 0000000..961904f --- /dev/null +++ b/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py @@ -0,0 +1,26 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from diffsynth import load_state_dict +from PIL import Image +import torch + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Layered-Control", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +state_dict = load_state_dict("models/train/Qwen-Image-Layered-Control_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +prompt = "Text 'HELLO' and 'Have a great day'" +input_image = Image.open("data/example_image_dataset/layer/image.png").convert("RGBA").resize((864, 480)) +images = pipe( + prompt, seed=0, + height=480, width=864, + layer_input_image=input_image, layer_num=0, +) +images[0].save("image.png") diff --git a/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py 
b/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py new file mode 100644 index 0000000..1a96e8b --- /dev/null +++ b/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py @@ -0,0 +1,25 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from diffsynth import load_state_dict +from PIL import Image +import torch + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Layered-Control", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.dit, "models/train/Qwen-Image-Layered-Control_lora/epoch-4.safetensors") +prompt = "Text 'HELLO' and 'Have a great day'" +input_image = Image.open("data/example_image_dataset/layer/image.png").convert("RGBA").resize((864, 480)) +images = pipe( + prompt, seed=0, + height=480, width=864, + layer_input_image=input_image, layer_num=0, +) +images[0].save("image.png") From e99cdcf3b8fdccd8c5b9697cef041a9b92569261 Mon Sep 17 00:00:00 2001 From: lzws <63908509+lzws@users.noreply.github.com> Date: Mon, 12 Jan 2026 22:08:48 +0800 Subject: [PATCH 04/16] wan usp bug fix --- diffsynth/pipelines/wan_video.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index ca59d2a..45ea43c 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -122,11 +122,15 @@ class WanVideoPipeline(BasePipeline): model_config.model_id = redirect_dict[model_config.origin_file_pattern][0] 
model_config.origin_file_pattern = redirect_dict[model_config.origin_file_pattern][1] - # Initialize pipeline - pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype) if use_usp: from ..utils.xfuser import initialize_usp initialize_usp(device) + import torch.distributed as dist + from ..core.device.npu_compatible_device import get_device_name, IS_NPU_AVAILABLE, IS_CUDA_AVAILABLE + if dist.is_available() and dist.is_initialized() and (IS_CUDA_AVAILABLE or IS_NPU_AVAILABLE): + device = get_device_name() + # Initialize pipeline + pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype) model_pool = pipe.download_and_load_models(model_configs, vram_limit) # Fetch models From d16877e69548523f2ea23c4fff530bdd81b31cfa Mon Sep 17 00:00:00 2001 From: feng0w0 Date: Tue, 13 Jan 2026 11:17:51 +0800 Subject: [PATCH 05/16] [model][NPU]:Wan model rope use torch.complex64 in NPU --- diffsynth/models/wan_video_dit.py | 3 +-- diffsynth/utils/xfuser/xdit_context_parallel.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 43cd601..7386223 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -5,7 +5,6 @@ import math from typing import Tuple, Optional from einops import rearrange from .wan_video_camera_controller import SimpleAdapter -from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE try: import flash_attn_interface @@ -94,7 +93,7 @@ def rope_apply(x, freqs, num_heads): x = rearrange(x, "b s (n d) -> b s n d", n=num_heads) x_out = torch.view_as_complex(x.to(torch.float64).reshape( x.shape[0], x.shape[1], x.shape[2], -1, 2)) - freqs = freqs.to(torch.complex64) if IS_NPU_AVAILABLE else freqs + freqs = freqs.to(torch.complex64) if freqs.device == "npu" else freqs x_out = torch.view_as_real(x_out * freqs).flatten(2) return x_out.to(x.dtype) diff --git a/diffsynth/utils/xfuser/xdit_context_parallel.py 
b/diffsynth/utils/xfuser/xdit_context_parallel.py index d365cfe..21dc3b3 100644 --- a/diffsynth/utils/xfuser/xdit_context_parallel.py +++ b/diffsynth/utils/xfuser/xdit_context_parallel.py @@ -5,7 +5,7 @@ from xfuser.core.distributed import (get_sequence_parallel_rank, get_sequence_parallel_world_size, get_sp_group) from xfuser.core.long_ctx_attention import xFuserLongContextAttention -from ...core.device import parse_nccl_backend, parse_device_type, IS_NPU_AVAILABLE +from ...core.device import parse_nccl_backend, parse_device_type def initialize_usp(device_type): @@ -50,7 +50,7 @@ def rope_apply(x, freqs, num_heads): sp_rank = get_sequence_parallel_rank() freqs = pad_freqs(freqs, s_per_rank * sp_size) freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :] - freqs_rank = freqs_rank.to(torch.complex64) if IS_NPU_AVAILABLE else freqs_rank + freqs_rank = freqs_rank.to(torch.complex64) if freqs_rank.device == "npu" else freqs_rank x_out = torch.view_as_real(x_out * freqs_rank).flatten(2) return x_out.to(x.dtype) From acba342a630f6053d890359aa91b7b54d40a38dc Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Wed, 14 Jan 2026 16:29:43 +0800 Subject: [PATCH 06/16] fix RMSNorm precision --- diffsynth/models/z_image_dit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffsynth/models/z_image_dit.py b/diffsynth/models/z_image_dit.py index f157f38..6e8866a 100644 --- a/diffsynth/models/z_image_dit.py +++ b/diffsynth/models/z_image_dit.py @@ -6,7 +6,7 @@ import torch.nn as nn import torch.nn.functional as F from torch.nn.utils.rnn import pad_sequence -from torch.nn import RMSNorm +from .general_modules import RMSNorm from ..core.attention import attention_forward from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE from ..core.gradient import gradient_checkpoint_forward From fd87b727541ffaff186ab3ad326f1787af9350c2 Mon Sep 17 00:00:00 2001 From: lzws <63908509+lzws@users.noreply.github.com> Date: Wed, 14 Jan 2026 16:33:02 
+0800 Subject: [PATCH 07/16] wan usp bug fix --- diffsynth/pipelines/wan_video.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index 45ea43c..5b4c0b4 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -126,8 +126,8 @@ class WanVideoPipeline(BasePipeline): from ..utils.xfuser import initialize_usp initialize_usp(device) import torch.distributed as dist - from ..core.device.npu_compatible_device import get_device_name, IS_NPU_AVAILABLE, IS_CUDA_AVAILABLE - if dist.is_available() and dist.is_initialized() and (IS_CUDA_AVAILABLE or IS_NPU_AVAILABLE): + from ..core.device.npu_compatible_device import get_device_name + if dist.is_available() and dist.is_initialized(): device = get_device_name() # Initialize pipeline pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype) From c90aaa2798885aa6ae1720bc9075066d50054963 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Wed, 14 Jan 2026 20:49:36 +0800 Subject: [PATCH 08/16] fix flux compatibility issues --- .github/workflows/publish.yaml | 2 +- diffsynth/configs/model_configs.py | 7 +++++++ diffsynth/utils/state_dict_converters/flux_dit.py | 2 ++ pyproject.toml | 2 +- 4 files changed, 11 insertions(+), 2 deletions(-) diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index f31e6bb..31e8947 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -22,7 +22,7 @@ jobs: - name: Install wheel run: pip install wheel==0.44.0 && pip install -r requirements.txt - name: Build DiffSynth - run: python setup.py sdist bdist_wheel + run: python -m build - name: Publish package to PyPI run: | pip install twine diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index 7da7a9d..eed58f8 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -481,6 +481,13 @@ flux_series = [ 
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter", "extra_kwargs": {"disable_guidance_embedder": True}, }, + { + # Example: ModelConfig(model_id="MAILAND/majicflus_v1", origin_file_pattern="majicflus_v134.safetensors") + "model_hash": "3394f306c4cbf04334b712bf5aaed95f", + "model_name": "flux_dit", + "model_class": "diffsynth.models.flux_dit.FluxDiT", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter", + }, ] flux2_series = [ diff --git a/diffsynth/utils/state_dict_converters/flux_dit.py b/diffsynth/utils/state_dict_converters/flux_dit.py index 8469c87..f808b60 100644 --- a/diffsynth/utils/state_dict_converters/flux_dit.py +++ b/diffsynth/utils/state_dict_converters/flux_dit.py @@ -143,6 +143,8 @@ def FluxDiTStateDictConverterFromDiffusers(state_dict): suffix = ".weight" if name.endswith(".weight") else ".bias" prefix = name[:-len(suffix)] if prefix in global_rename_dict: + if global_rename_dict[prefix] == "final_norm_out.linear": + param = torch.concat([param[3072:], param[:3072]], dim=0) state_dict_[global_rename_dict[prefix] + suffix] = param elif prefix.startswith("transformer_blocks."): names = prefix.split(".") diff --git a/pyproject.toml b/pyproject.toml index 04b9f70..059e21d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "diffsynth" -version = "2.0.1" +version = "2.0.2" description = "Enjoy the magic of Diffusion models!" 
authors = [{name = "ModelScope Team"}] license = {text = "Apache-2.0"} From 55e8346da3fd725e12ca9f3251eb79dd75469a25 Mon Sep 17 00:00:00 2001 From: Zhongjie Duan <35051019+Artiprocher@users.noreply.github.com> Date: Thu, 15 Jan 2026 12:31:55 +0800 Subject: [PATCH 09/16] Blog link (#1202) * update README --- README.md | 2 +- README_zh.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 69f70e1..2c7ef27 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ We believe that a well-developed open-source code framework can lower the thresh > Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand. -- **January 12, 2026**: We trained and open-sourced a text-guided image layer separation model ([Model Link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)). Given an input image and a textual description, the model isolates the image layer corresponding to the described content. +- **January 12, 2026**: We trained and open-sourced a text-guided image layer separation model ([Model Link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)). Given an input image and a textual description, the model isolates the image layer corresponding to the described content. For more details, please refer to our blog post ([Chinese version](https://modelscope.cn/learn/4938), [English version](https://huggingface.co/blog/kelseye/qwen-image-layered-control)). - **December 24, 2025**: Based on Qwen-Image-Edit-2511, we trained an In-Context Editing LoRA model ([Model Link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)). 
This model takes three images as input (Image A, Image B, and Image C), and automatically analyzes the transformation from Image A to Image B, then applies the same transformation to Image C to generate Image D. For more details, please refer to our blog post ([Chinese version](https://mp.weixin.qq.com/s/41aEiN3lXKGCJs1-we4Q2g), [English version](https://huggingface.co/blog/kelseye/qwen-image-edit-2511-icedit-lora)). diff --git a/README_zh.md b/README_zh.md index 250c48e..81a33e3 100644 --- a/README_zh.md +++ b/README_zh.md @@ -33,7 +33,7 @@ DiffSynth 目前包括两个开源项目: > 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 负责,因此新功能的开发进展会比较缓慢,issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。 -- **2026年1月12日** 我们训练并开源了一个文本引导的图层拆分模型([模型链接](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)),这一模型输入一张图与一段文本描述,模型会将图像中与文本描述相关的图层拆分出来。 +- **2026年1月12日** 我们训练并开源了一个文本引导的图层拆分模型([模型链接](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)),这一模型输入一张图与一段文本描述,模型会将图像中与文本描述相关的图层拆分出来。更多细节请阅读我们的 blog([中文版](https://modelscope.cn/learn/4938)、[英文版](https://huggingface.co/blog/kelseye/qwen-image-layered-control))。 - **2025年12月24日** 我们基于 Qwen-Image-Edit-2511 训练了一个 In-Context Editing LoRA 模型([模型链接](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)),这个模型可以输入三张图:图A、图B、图C,模型会自行分析图A到图B的变化,并将这样的变化应用到图C,生成图D。更多细节请阅读我们的 blog([中文版](https://mp.weixin.qq.com/s/41aEiN3lXKGCJs1-we4Q2g)、[英文版](https://huggingface.co/blog/kelseye/qwen-image-edit-2511-icedit-lora))。 From ae52d9369468889a5f536254a9b4131321cff720 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Fri, 16 Jan 2026 13:09:41 +0800 Subject: [PATCH 10/16] support klein 4b models --- diffsynth/configs/model_configs.py | 7 ++++ diffsynth/models/flux2_dit.py | 67 +++++++++++++----------------- diffsynth/pipelines/flux2_image.py | 5 ++- 3 files changed, 39 insertions(+), 40 deletions(-) diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index 
eed58f8..cc23fb9 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -510,6 +510,13 @@ flux2_series = [ "model_name": "flux2_vae", "model_class": "diffsynth.models.flux2_vae.Flux2VAE", }, + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors") + "model_hash": "3bde7b817fec8143028b6825a63180df", + "model_name": "flux2_dit", + "model_class": "diffsynth.models.flux2_dit.Flux2DiT", + "extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 7680, "num_attention_heads": 24, "num_layers": 5, "num_single_layers": 20} + }, ] z_image_series = [ diff --git a/diffsynth/models/flux2_dit.py b/diffsynth/models/flux2_dit.py index a08c579..316cf08 100644 --- a/diffsynth/models/flux2_dit.py +++ b/diffsynth/models/flux2_dit.py @@ -823,7 +823,13 @@ class Flux2PosEmbed(nn.Module): class Flux2TimestepGuidanceEmbeddings(nn.Module): - def __init__(self, in_channels: int = 256, embedding_dim: int = 6144, bias: bool = False): + def __init__( + self, + in_channels: int = 256, + embedding_dim: int = 6144, + bias: bool = False, + guidance_embeds: bool = True, + ): super().__init__() self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0) @@ -831,20 +837,24 @@ class Flux2TimestepGuidanceEmbeddings(nn.Module): in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias ) - self.guidance_embedder = TimestepEmbedding( - in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias - ) + if guidance_embeds: + self.guidance_embedder = TimestepEmbedding( + in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias + ) + else: + self.guidance_embedder = None def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor: timesteps_proj = self.time_proj(timestep) timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) # (N, D) - guidance_proj = 
self.time_proj(guidance) - guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D) - - time_guidance_emb = timesteps_emb + guidance_emb - - return time_guidance_emb + if guidance is not None and self.guidance_embedder is not None: + guidance_proj = self.time_proj(guidance) + guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D) + time_guidance_emb = timesteps_emb + guidance_emb + return time_guidance_emb + else: + return timesteps_emb class Flux2Modulation(nn.Module): @@ -882,6 +892,7 @@ class Flux2DiT(torch.nn.Module): axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32), rope_theta: int = 2000, eps: float = 1e-6, + guidance_embeds: bool = True, ): super().__init__() self.out_channels = out_channels or in_channels @@ -892,7 +903,10 @@ class Flux2DiT(torch.nn.Module): # 2. Combined timestep + guidance embedding self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings( - in_channels=timestep_guidance_channels, embedding_dim=self.inner_dim, bias=False + in_channels=timestep_guidance_channels, + embedding_dim=self.inner_dim, + bias=False, + guidance_embeds=guidance_embeds, ) # 3. Modulation (double stream and single stream blocks share modulation parameters, resp.) @@ -953,34 +967,9 @@ class Flux2DiT(torch.nn.Module): txt_ids: torch.Tensor = None, guidance: torch.Tensor = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, - return_dict: bool = True, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False, - ) -> Union[torch.Tensor]: - """ - The [`FluxTransformer2DModel`] forward method. - - Args: - hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): - Input `hidden_states`. - encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): - Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. - timestep ( `torch.LongTensor`): - Used to indicate denoising step. 
- block_controlnet_hidden_states: (`list` of `torch.Tensor`): - A list of tensors that if specified are added to the residuals of transformer blocks. - joint_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain - tuple. - - Returns: - If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. - """ + ): # 0. Handle input arguments if joint_attention_kwargs is not None: joint_attention_kwargs = joint_attention_kwargs.copy() @@ -992,7 +981,9 @@ class Flux2DiT(torch.nn.Module): # 1. Calculate timestep embedding and modulation parameters timestep = timestep.to(hidden_states.dtype) * 1000 - guidance = guidance.to(hidden_states.dtype) * 1000 + + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 temb = self.time_guidance_embed(timestep, guidance) diff --git a/diffsynth/pipelines/flux2_image.py b/diffsynth/pipelines/flux2_image.py index 8b00469..e94d2c3 100644 --- a/diffsynth/pipelines/flux2_image.py +++ b/diffsynth/pipelines/flux2_image.py @@ -10,7 +10,7 @@ from ..diffusion import FlowMatchScheduler from ..core import ModelConfig, gradient_checkpoint_forward from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput -from transformers import AutoProcessor +from transformers import AutoProcessor, AutoTokenizer from ..models.flux2_text_encoder import Flux2TextEncoder from ..models.flux2_dit import Flux2DiT from ..models.flux2_vae import Flux2VAE @@ -53,11 +53,12 @@ class Flux2ImagePipeline(BasePipeline): # Fetch models 
pipe.text_encoder = model_pool.fetch_model("flux2_text_encoder") + pipe.text_encoder_qwen3 = model_pool.fetch_model("z_image_text_encoder") pipe.dit = model_pool.fetch_model("flux2_dit") pipe.vae = model_pool.fetch_model("flux2_vae") if tokenizer_config is not None: tokenizer_config.download_if_necessary() - pipe.tokenizer = AutoProcessor.from_pretrained(tokenizer_config.path) + pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path) # VRAM Management pipe.vram_management_enabled = pipe.check_vram_management_state() From b6ccb362b9e9ee3a14303b82494ae1d6e14e989f Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 19 Jan 2026 16:56:14 +0800 Subject: [PATCH 11/16] support flux.2 klein --- diffsynth/configs/model_configs.py | 15 ++ diffsynth/diffusion/training_module.py | 20 ++- diffsynth/models/z_image_text_encoder.py | 95 ++++++++---- diffsynth/pipelines/flux2_image.py | 137 ++++++++++++++++++ .../z_image_text_encoder.py | 6 + .../flux2/model_inference/FLUX.2-klein-4B.py | 17 +++ .../flux2/model_inference/FLUX.2-klein-9B.py | 17 +++ .../FLUX.2-klein-4B.py | 27 ++++ .../FLUX.2-klein-9B.py | 27 ++++ .../model_training/full/FLUX.2-klein-4B.sh | 13 ++ .../model_training/full/FLUX.2-klein-9B.sh | 13 ++ .../model_training/lora/FLUX.2-klein-4B.sh | 15 ++ .../model_training/lora/FLUX.2-klein-9B.sh | 15 ++ examples/flux2/model_training/train.py | 2 +- .../validate_full/FLUX.2-klein-4B.py | 20 +++ .../validate_full/FLUX.2-klein-9B.py | 20 +++ .../validate_lora/FLUX.2-klein-4B.py | 18 +++ .../validate_lora/FLUX.2-klein-9B.py | 18 +++ 18 files changed, 460 insertions(+), 35 deletions(-) create mode 100644 diffsynth/utils/state_dict_converters/z_image_text_encoder.py create mode 100644 examples/flux2/model_inference/FLUX.2-klein-4B.py create mode 100644 examples/flux2/model_inference/FLUX.2-klein-9B.py create mode 100644 examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py create mode 100644 examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py create 
mode 100644 examples/flux2/model_training/full/FLUX.2-klein-4B.sh create mode 100644 examples/flux2/model_training/full/FLUX.2-klein-9B.sh create mode 100644 examples/flux2/model_training/lora/FLUX.2-klein-4B.sh create mode 100644 examples/flux2/model_training/lora/FLUX.2-klein-9B.sh create mode 100644 examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py create mode 100644 examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py create mode 100644 examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py create mode 100644 examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index cc23fb9..c93f5e9 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -517,6 +517,21 @@ flux2_series = [ "model_class": "diffsynth.models.flux2_dit.Flux2DiT", "extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 7680, "num_attention_heads": 24, "num_layers": 5, "num_single_layers": 20} }, + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors") + "model_hash": "9195f3ea256fcd0ae6d929c203470754", + "model_name": "z_image_text_encoder", + "model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder", + "extra_kwargs": {"model_size": "8B"}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors") + "model_hash": "39c6fc48f07bebecedbbaa971ff466c8", + "model_name": "flux2_dit", + "model_class": "diffsynth.models.flux2_dit.Flux2DiT", + "extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 12288, "num_attention_heads": 32, "num_layers": 8, "num_single_layers": 24} + }, ] z_image_series = [ diff --git a/diffsynth/diffusion/training_module.py 
b/diffsynth/diffusion/training_module.py index e3b3329..b658866 100644 --- a/diffsynth/diffusion/training_module.py +++ b/diffsynth/diffusion/training_module.py @@ -1,4 +1,4 @@ -import torch, json +import torch, json, os from ..core import ModelConfig, load_state_dict from ..utils.controlnet import ControlNetInput from peft import LoraConfig, inject_adapter_in_model @@ -127,15 +127,29 @@ class DiffusionTrainingModule(torch.nn.Module): if model_id_with_origin_paths is not None: model_id_with_origin_paths = model_id_with_origin_paths.split(",") for model_id_with_origin_path in model_id_with_origin_paths: - model_id, origin_file_pattern = model_id_with_origin_path.split(":") vram_config = self.parse_vram_config( fp8=model_id_with_origin_path in fp8_models, offload=model_id_with_origin_path in offload_models, device=device ) - model_configs.append(ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern, **vram_config)) + config = self.parse_path_or_model_id(model_id_with_origin_path) + model_configs.append(ModelConfig(model_id=config.model_id, origin_file_pattern=config.origin_file_pattern, **vram_config)) return model_configs + + def parse_path_or_model_id(self, model_id_with_origin_path, default_value=None): + if model_id_with_origin_path is None: + return default_value + elif os.path.exists(model_id_with_origin_path): + return ModelConfig(path=model_id_with_origin_path) + else: + if ":" not in model_id_with_origin_path: + raise ValueError(f"Failed to parse model config: {model_id_with_origin_path}. 
This is neither a valid path nor in the format of `model_id/origin_file_pattern`.") + split_id = model_id_with_origin_path.rfind(":") + model_id = model_id_with_origin_path[:split_id] + origin_file_pattern = model_id_with_origin_path[split_id + 1:] + return ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern) + def switch_pipe_to_training_mode( self, diff --git a/diffsynth/models/z_image_text_encoder.py b/diffsynth/models/z_image_text_encoder.py index 4eba636..4d6271d 100644 --- a/diffsynth/models/z_image_text_encoder.py +++ b/diffsynth/models/z_image_text_encoder.py @@ -3,38 +3,71 @@ import torch class ZImageTextEncoder(torch.nn.Module): - def __init__(self): + def __init__(self, model_size="4B"): super().__init__() - config = Qwen3Config(**{ - "architectures": [ - "Qwen3ForCausalLM" - ], - "attention_bias": False, - "attention_dropout": 0.0, - "bos_token_id": 151643, - "eos_token_id": 151645, - "head_dim": 128, - "hidden_act": "silu", - "hidden_size": 2560, - "initializer_range": 0.02, - "intermediate_size": 9728, - "max_position_embeddings": 40960, - "max_window_layers": 36, - "model_type": "qwen3", - "num_attention_heads": 32, - "num_hidden_layers": 36, - "num_key_value_heads": 8, - "rms_norm_eps": 1e-06, - "rope_scaling": None, - "rope_theta": 1000000, - "sliding_window": None, - "tie_word_embeddings": True, - "torch_dtype": "bfloat16", - "transformers_version": "4.51.0", - "use_cache": True, - "use_sliding_window": False, - "vocab_size": 151936 - }) + config_dict = { + "4B": Qwen3Config(**{ + "architectures": [ + "Qwen3ForCausalLM" + ], + "attention_bias": False, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 2560, + "initializer_range": 0.02, + "intermediate_size": 9728, + "max_position_embeddings": 40960, + "max_window_layers": 36, + "model_type": "qwen3", + "num_attention_heads": 32, + "num_hidden_layers": 36, + "num_key_value_heads": 8, + 
"rms_norm_eps": 1e-06, + "rope_scaling": None, + "rope_theta": 1000000, + "sliding_window": None, + "tie_word_embeddings": True, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.0", + "use_cache": True, + "use_sliding_window": False, + "vocab_size": 151936 + }), + "8B": Qwen3Config(**{ + "architectures": [ + "Qwen3ForCausalLM" + ], + "attention_bias": False, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "dtype": "bfloat16", + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 12288, + "max_position_embeddings": 40960, + "max_window_layers": 36, + "model_type": "qwen3", + "num_attention_heads": 32, + "num_hidden_layers": 36, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-06, + "rope_scaling": None, + "rope_theta": 1000000, + "sliding_window": None, + "tie_word_embeddings": False, + "transformers_version": "4.56.1", + "use_cache": True, + "use_sliding_window": False, + "vocab_size": 151936 + }) + } + config = config_dict[model_size] self.model = Qwen3Model(config) def forward(self, *args, **kwargs): diff --git a/diffsynth/pipelines/flux2_image.py b/diffsynth/pipelines/flux2_image.py index e94d2c3..b736625 100644 --- a/diffsynth/pipelines/flux2_image.py +++ b/diffsynth/pipelines/flux2_image.py @@ -14,6 +14,7 @@ from transformers import AutoProcessor, AutoTokenizer from ..models.flux2_text_encoder import Flux2TextEncoder from ..models.flux2_dit import Flux2DiT from ..models.flux2_vae import Flux2VAE +from ..models.z_image_text_encoder import ZImageTextEncoder class Flux2ImagePipeline(BasePipeline): @@ -25,6 +26,7 @@ class Flux2ImagePipeline(BasePipeline): ) self.scheduler = FlowMatchScheduler("FLUX.2") self.text_encoder: Flux2TextEncoder = None + self.text_encoder_qwen3: ZImageTextEncoder = None self.dit: Flux2DiT = None self.vae: Flux2VAE = None self.tokenizer: AutoProcessor = None @@ -32,6 +34,7 @@ class Flux2ImagePipeline(BasePipeline): self.units = [ 
Flux2Unit_ShapeChecker(), Flux2Unit_PromptEmbedder(), + Flux2Unit_Qwen3PromptEmbedder(), Flux2Unit_NoiseInitializer(), Flux2Unit_InputImageEmbedder(), Flux2Unit_ImageIDs(), @@ -276,6 +279,10 @@ class Flux2Unit_PromptEmbedder(PipelineUnit): return prompt_embeds, text_ids def process(self, pipe: Flux2ImagePipeline, prompt): + # Skip if Qwen3 text encoder is available (handled by Qwen3PromptEmbedder) + if pipe.text_encoder_qwen3 is not None: + return {} + pipe.load_models_to_device(self.onload_model_names) prompt_embeds, text_ids = self.encode_prompt( pipe.text_encoder, pipe.tokenizer, prompt, @@ -284,6 +291,136 @@ class Flux2Unit_PromptEmbedder(PipelineUnit): return {"prompt_embeds": prompt_embeds, "text_ids": text_ids} +class Flux2Unit_Qwen3PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt"}, + input_params_nega={"prompt": "negative_prompt"}, + output_params=("prompt_emb", "prompt_emb_mask"), + onload_model_names=("text_encoder_qwen3",) + ) + self.hidden_states_layers = (9, 18, 27) # Qwen3 layers + + def get_qwen3_prompt_embeds( + self, + text_encoder: ZImageTextEncoder, + tokenizer: AutoTokenizer, + prompt: Union[str, List[str]], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + max_sequence_length: int = 512, + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + all_input_ids = [] + all_attention_masks = [] + + for single_prompt in prompt: + messages = [{"role": "user", "content": single_prompt}] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + inputs = tokenizer( + text, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + all_input_ids.append(inputs["input_ids"]) + 
all_attention_masks.append(inputs["attention_mask"]) + + input_ids = torch.cat(all_input_ids, dim=0).to(device) + attention_mask = torch.cat(all_attention_masks, dim=0).to(device) + + # Forward pass through the model + with torch.inference_mode(): + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in self.hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + return prompt_embeds + + def prepare_text_ids( + self, + x: torch.Tensor, # (B, L, D) or (L, D) + t_coord: Optional[torch.Tensor] = None, + ): + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + l = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, l) + out_ids.append(coords) + + return torch.stack(out_ids) + + def encode_prompt( + self, + text_encoder: ZImageTextEncoder, + tokenizer: AutoTokenizer, + prompt: Union[str, List[str]], + dtype = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds = self.get_qwen3_prompt_embeds( + text_encoder=text_encoder, + tokenizer=tokenizer, + prompt=prompt, + dtype=dtype, + device=device, + max_sequence_length=max_sequence_length, + ) + + batch_size, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + text_ids = 
self.prepare_text_ids(prompt_embeds) + text_ids = text_ids.to(device) + return prompt_embeds, text_ids + + def process(self, pipe: Flux2ImagePipeline, prompt): + # Check if Qwen3 text encoder is available + if pipe.text_encoder_qwen3 is None: + return {} + + pipe.load_models_to_device(self.onload_model_names) + prompt_embeds, text_ids = self.encode_prompt( + pipe.text_encoder_qwen3, pipe.tokenizer, prompt, + dtype=pipe.torch_dtype, device=pipe.device, + ) + return {"prompt_embeds": prompt_embeds, "text_ids": text_ids} + + class Flux2Unit_NoiseInitializer(PipelineUnit): def __init__(self): super().__init__( diff --git a/diffsynth/utils/state_dict_converters/z_image_text_encoder.py b/diffsynth/utils/state_dict_converters/z_image_text_encoder.py new file mode 100644 index 0000000..b114613 --- /dev/null +++ b/diffsynth/utils/state_dict_converters/z_image_text_encoder.py @@ -0,0 +1,6 @@ +def ZImageTextEncoderStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name != "lm_head.weight": + state_dict_[name] = state_dict[name] + return state_dict_ \ No newline at end of file diff --git a/examples/flux2/model_inference/FLUX.2-klein-4B.py b/examples/flux2/model_inference/FLUX.2-klein-4B.py new file mode 100644 index 0000000..fbfe33d --- /dev/null +++ b/examples/flux2/model_inference/FLUX.2-klein-4B.py @@ -0,0 +1,17 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +import torch + + +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", 
origin_file_pattern="tokenizer/"), +) +prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles." +image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4) +image.save("image_FLUX.2-klein-4B.jpg") diff --git a/examples/flux2/model_inference/FLUX.2-klein-9B.py b/examples/flux2/model_inference/FLUX.2-klein-9B.py new file mode 100644 index 0000000..2abf0e7 --- /dev/null +++ b/examples/flux2/model_inference/FLUX.2-klein-9B.py @@ -0,0 +1,17 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +import torch + + +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"), +) +prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles." 
+image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4) +image.save("image_FLUX.2-klein-9B.jpg") diff --git a/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py b/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py new file mode 100644 index 0000000..019f58e --- /dev/null +++ b/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py @@ -0,0 +1,27 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +import torch + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"), +) +prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles." 
+image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4) +image.save("image_FLUX.2-klein-4B.jpg") diff --git a/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py b/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py new file mode 100644 index 0000000..b629c94 --- /dev/null +++ b/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py @@ -0,0 +1,27 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +import torch + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"), +) +prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles." 
+image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4) +image.save("image_FLUX.2-klein-9B.jpg") diff --git a/examples/flux2/model_training/full/FLUX.2-klein-4B.sh b/examples/flux2/model_training/full/FLUX.2-klein-4B.sh new file mode 100644 index 0000000..4fa46da --- /dev/null +++ b/examples/flux2/model_training/full/FLUX.2-klein-4B.sh @@ -0,0 +1,13 @@ +accelerate launch examples/flux2/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \ + --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/FLUX.2-klein-4B_full" \ + --trainable_models "dit" \ + --use_gradient_checkpointing diff --git a/examples/flux2/model_training/full/FLUX.2-klein-9B.sh b/examples/flux2/model_training/full/FLUX.2-klein-9B.sh new file mode 100644 index 0000000..c89e8f0 --- /dev/null +++ b/examples/flux2/model_training/full/FLUX.2-klein-9B.sh @@ -0,0 +1,13 @@ +accelerate launch examples/flux2/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \ + --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/FLUX.2-klein-9B_full" \ + --trainable_models "dit" \ + --use_gradient_checkpointing diff --git a/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh b/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh new file mode 100644 index 0000000..8f897cc --- /dev/null +++ b/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh @@ -0,0 +1,15 @@ +accelerate launch examples/flux2/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \ + --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/FLUX.2-klein-4B_lora" \ + --lora_base_model "dit" \ + --lora_target_modules 
"to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out" \ + --lora_rank 32 \ + --use_gradient_checkpointing diff --git a/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh b/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh new file mode 100644 index 0000000..258c5fe --- /dev/null +++ b/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh @@ -0,0 +1,15 @@ +accelerate launch examples/flux2/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \ + --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/FLUX.2-klein-9B_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out" \ + --lora_rank 32 \ + --use_gradient_checkpointing diff --git a/examples/flux2/model_training/train.py b/examples/flux2/model_training/train.py index 30408a1..ea727b8 100644 --- a/examples/flux2/model_training/train.py +++ b/examples/flux2/model_training/train.py @@ -24,7 +24,7 @@ class Flux2ImageTrainingModule(DiffusionTrainingModule): super().__init__() # Load models model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device) - tokenizer_config = ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path) + tokenizer_config = self.parse_path_or_model_id(tokenizer_path, 
default_value=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/")) self.pipe = Flux2ImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config) self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model) diff --git a/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py b/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py new file mode 100644 index 0000000..c5473ab --- /dev/null +++ b/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py @@ -0,0 +1,20 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +from diffsynth.core import load_state_dict +import torch + + +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"), +) +state_dict = load_state_dict("./models/train/FLUX.2-klein-4B_full/epoch-1.safetensors", torch_dtype=torch.bfloat16) +pipe.dit.load_state_dict(state_dict) +prompt = "a dog" +image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768) +image.save("image.jpg") diff --git a/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py b/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py new file mode 100644 index 0000000..09ac4bc --- /dev/null +++ b/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py @@ -0,0 +1,20 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, 
ModelConfig +from diffsynth.core import load_state_dict +import torch + + +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"), +) +state_dict = load_state_dict("./models/train/FLUX.2-klein-9B_full/epoch-1.safetensors", torch_dtype=torch.bfloat16) +pipe.dit.load_state_dict(state_dict) +prompt = "a dog" +image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768) +image.save("image.jpg") diff --git a/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py b/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py new file mode 100644 index 0000000..93fe2fa --- /dev/null +++ b/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py @@ -0,0 +1,18 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +import torch + + +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.dit, "./models/train/FLUX.2-klein-4B_lora/epoch-4.safetensors") +prompt = "a dog" 
+image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768) +image.save("image.jpg") diff --git a/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py b/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py new file mode 100644 index 0000000..75470bc --- /dev/null +++ b/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py @@ -0,0 +1,18 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +import torch + + +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.dit, "./models/train/FLUX.2-klein-9B_lora/epoch-4.safetensors") +prompt = "a dog" +image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768) +image.save("image.jpg") From 2336d5f6b340120aac9c14f0b6a00c6d76273531 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 19 Jan 2026 17:27:32 +0800 Subject: [PATCH 12/16] update doc --- README.md | 7 +++++-- README_zh.md | 6 +++++- docs/en/Model_Details/FLUX2.md | 27 +++++++++++++++++++-------- docs/zh/Model_Details/FLUX2.md | 27 +++++++++++++++++++-------- 4 files changed, 48 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index 2c7ef27..a590402 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,8 @@ We believe that a well-developed open-source code framework can lower the thresh > Currently, the development personnel of this project are limited, with most of the work handled by 
[Artiprocher](https://github.com/Artiprocher). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand. +- **January 19, 2026**: Added support for [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) and [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/FLUX2.md) and [example code](/examples/flux2/) are now available. + - **January 12, 2026**: We trained and open-sourced a text-guided image layer separation model ([Model Link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)). Given an input image and a textual description, the model isolates the image layer corresponding to the described content. For more details, please refer to our blog post ([Chinese version](https://modelscope.cn/learn/4938), [English version](https://huggingface.co/blog/kelseye/qwen-image-layered-control)). - **December 24, 2025**: Based on Qwen-Image-Edit-2511, we trained an In-Context Editing LoRA model ([Model Link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)). This model takes three images as input (Image A, Image B, and Image C), and automatically analyzes the transformation from Image A to Image B, then applies the same transformation to Image C to generate Image D. For more details, please refer to our blog post ([Chinese version](https://mp.weixin.qq.com/s/41aEiN3lXKGCJs1-we4Q2g), [English version](https://huggingface.co/blog/kelseye/qwen-image-edit-2511-icedit-lora)). 
@@ -321,7 +323,9 @@ Example code for FLUX.2 is available at: [/examples/flux2/](/examples/flux2/) -| Model ID | Inference | Low-VRAM Inference | LoRA Training | LoRA Training Validation | +| Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation | -|-|-|-|-|-| +|-|-|-|-|-|-|-| -|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)| +|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)| +|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)| +|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)| @@ -774,4 +778,3 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-47 
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea - diff --git a/README_zh.md b/README_zh.md index 81a33e3..6a0e0ef 100644 --- a/README_zh.md +++ b/README_zh.md @@ -33,6 +33,8 @@ DiffSynth 目前包括两个开源项目: > 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 负责,因此新功能的开发进展会比较缓慢,issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。 +- **2026年1月19日** 新增对 [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) 和 [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) 模型的支持,包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/FLUX2.md)和[示例代码](/examples/flux2/)现已可用。 + - **2026年1月12日** 我们训练并开源了一个文本引导的图层拆分模型([模型链接](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)),这一模型输入一张图与一段文本描述,模型会将图像中与文本描述相关的图层拆分出来。更多细节请阅读我们的 blog([中文版](https://modelscope.cn/learn/4938)、[英文版](https://huggingface.co/blog/kelseye/qwen-image-layered-control))。 - **2025年12月24日** 我们基于 Qwen-Image-Edit-2511 训练了一个 In-Context Editing LoRA 模型([模型链接](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)),这个模型可以输入三张图:图A、图B、图C,模型会自行分析图A到图B的变化,并将这样的变化应用到图C,生成图D。更多细节请阅读我们的 blog([中文版](https://mp.weixin.qq.com/s/41aEiN3lXKGCJs1-we4Q2g)、[英文版](https://huggingface.co/blog/kelseye/qwen-image-edit-2511-icedit-lora))。 @@ -321,7 +323,9 @@ FLUX.2 的示例代码位于:[/examples/flux2/](/examples/flux2/) -|模型 ID|推理|低显存推理|LoRA 训练|LoRA 训练后验证| +|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| -|-|-|-|-|-| +|-|-|-|-|-|-|-| -|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)| 
+|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)| +|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)| +|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)| diff --git a/docs/en/Model_Details/FLUX2.md b/docs/en/Model_Details/FLUX2.md index fd5e56d..70ccf21 100644 --- a/docs/en/Model_Details/FLUX2.md +++ b/docs/en/Model_Details/FLUX2.md @@ -2,6 +2,15 @@ FLUX.2 is an image generation model trained and open-sourced by Black Forest Labs. +## Model Lineage + +```mermaid +graph LR; + FLUX.2-Series-->black-forest-labs/FLUX.2-dev; + FLUX.2-Series-->black-forest-labs/FLUX.2-klein-4B; + FLUX.2-Series-->black-forest-labs/FLUX.2-klein-9B; +``` + ## Installation Before using this project for model inference and training, please install DiffSynth-Studio first. 
@@ -50,16 +59,18 @@ image.save("image.jpg") ## Model Overview -| Model ID | Inference | Low VRAM Inference | LoRA Training | Validation After LoRA Training | -| - | - | - | - | - | -| [black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev) | [code](/examples/flux2/model_inference/FLUX.2-dev.py) | [code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py) | [code](/examples/flux2/model_training/lora/FLUX.2-dev.sh) | [code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py) | +| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training | +| - | - | - | - | - | - | - | +|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)| +|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)| 
+|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)| Special Training Scripts: -* Differential LoRA Training: [doc](/docs/en/Training/Differential_LoRA.md), [code](/examples/flux/model_training/special/differential_training/) -* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md), [code](/examples/flux/model_training/special/fp8_training/) -* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md), [code](/examples/flux/model_training/special/split_training/) -* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md), [code](/examples/flux/model_training/lora/FLUX.1-dev-Distill-LoRA.sh) +* Differential LoRA Training: [doc](/docs/en/Training/Differential_LoRA.md) +* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md) +* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md) +* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md) ## Model Inference @@ -135,4 +146,4 @@ We have built a sample image dataset for your testing. You can download this dat modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset ``` -We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/). 
\ No newline at end of file +We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/). diff --git a/docs/zh/Model_Details/FLUX2.md b/docs/zh/Model_Details/FLUX2.md index ad4df27..cf91054 100644 --- a/docs/zh/Model_Details/FLUX2.md +++ b/docs/zh/Model_Details/FLUX2.md @@ -2,6 +2,15 @@ FLUX.2 是由 Black Forest Labs 训练并开源的图像生成模型。 +## 模型血缘 + +```mermaid +graph LR; + FLUX.2-Series-->black-forest-labs/FLUX.2-dev; + FLUX.2-Series-->black-forest-labs/FLUX.2-klein-4B; + FLUX.2-Series-->black-forest-labs/FLUX.2-klein-9B; +``` + ## 安装 在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。 @@ -50,16 +59,18 @@ image.save("image.jpg") ## 模型总览 -|模型 ID|推理|低显存推理|LoRA 训练|LoRA 训练后验证| -|-|-|-|-|-| -|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)| +|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| +|-|-|-|-|-|-|-| +|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)| 
+|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)| +|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)| 特殊训练脚本: -* 差分 LoRA 训练:[doc](/docs/zh/Training/Differential_LoRA.md)、[code](/examples/flux/model_training/special/differential_training/) -* FP8 精度训练:[doc](/docs/zh/Training/FP8_Precision.md)、[code](/examples/flux/model_training/special/fp8_training/) -* 两阶段拆分训练:[doc](/docs/zh/Training/Split_Training.md)、[code](/examples/flux/model_training/special/split_training/) -* 端到端直接蒸馏:[doc](/docs/zh/Training/Direct_Distill.md)、[code](/examples/flux/model_training/lora/FLUX.1-dev-Distill-LoRA.sh) +* 差分 LoRA 训练:[doc](/docs/zh/Training/Differential_LoRA.md) +* FP8 精度训练:[doc](/docs/zh/Training/FP8_Precision.md) +* 两阶段拆分训练:[doc](/docs/zh/Training/Split_Training.md) +* 端到端直接蒸馏:[doc](/docs/zh/Training/Direct_Distill.md) ## 模型推理 @@ -135,4 +146,4 @@ FLUX.2 系列模型统一通过 [`examples/flux2/model_training/train.py`](/exam modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset ``` 
-我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](/docs/zh/Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](/docs/Training/)。 \ No newline at end of file +我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](/docs/zh/Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](/docs/Training/)。 From a18e6233b5236ebcd91333e90f5f1d16cc5b9381 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 19 Jan 2026 17:35:08 +0800 Subject: [PATCH 13/16] updata wan-vace training scripts --- .../wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh | 5 +++-- examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh | 5 +++-- examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh | 5 +++-- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh b/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh index 2bcb55b..1f25eef 100644 --- a/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh +++ b/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh @@ -6,7 +6,7 @@ accelerate launch examples/wanvideo/model_training/train.py \ --width 832 \ --dataset_repeat 100 \ --model_id_with_origin_paths "iic/VACE-Wan2.1-1.3B-Preview:diffusion_pytorch_model*.safetensors,iic/VACE-Wan2.1-1.3B-Preview:models_t5_umt5-xxl-enc-bf16.pth,iic/VACE-Wan2.1-1.3B-Preview:Wan2.1_VAE.pth" \ - --learning_rate 1e-4 \ + --learning_rate 5e-5 \ --num_epochs 5 \ --remove_prefix_in_ckpt "pipe.vace." 
\ --output_path "./models/train/Wan2.1-VACE-1.3B-Preview_lora" \ @@ -14,4 +14,5 @@ accelerate launch examples/wanvideo/model_training/train.py \ --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ --lora_rank 32 \ --extra_inputs "vace_video,vace_reference_image" \ - --use_gradient_checkpointing_offload \ No newline at end of file + --use_gradient_checkpointing_offload +# The learning rate is kept consistent with the settings in the original paper \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh b/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh index b565078..c8b77cc 100644 --- a/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh +++ b/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh @@ -6,7 +6,7 @@ accelerate launch examples/wanvideo/model_training/train.py \ --width 832 \ --dataset_repeat 100 \ --model_id_with_origin_paths "Wan-AI/Wan2.1-VACE-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-VACE-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-VACE-1.3B:Wan2.1_VAE.pth" \ - --learning_rate 1e-4 \ + --learning_rate 5e-5 \ --num_epochs 5 \ --remove_prefix_in_ckpt "pipe.vace." 
\ --output_path "./models/train/Wan2.1-VACE-1.3B_lora" \ @@ -14,4 +14,5 @@ accelerate launch examples/wanvideo/model_training/train.py \ --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ --lora_rank 32 \ --extra_inputs "vace_video,vace_reference_image" \ - --use_gradient_checkpointing_offload \ No newline at end of file + --use_gradient_checkpointing_offload +# The learning rate is kept consistent with the settings in the original paper \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh b/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh index 633ea0e..28bd05c 100644 --- a/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh +++ b/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh @@ -7,7 +7,7 @@ accelerate launch examples/wanvideo/model_training/train.py \ --num_frames 17 \ --dataset_repeat 100 \ --model_id_with_origin_paths "Wan-AI/Wan2.1-VACE-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-VACE-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-VACE-14B:Wan2.1_VAE.pth" \ - --learning_rate 1e-4 \ + --learning_rate 5e-5 \ --num_epochs 5 \ --remove_prefix_in_ckpt "pipe.vace." 
\ --output_path "./models/train/Wan2.1-VACE-14B_lora" \ @@ -15,4 +15,5 @@ accelerate launch examples/wanvideo/model_training/train.py \ --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ --lora_rank 32 \ --extra_inputs "vace_video,vace_reference_image" \ - --use_gradient_checkpointing_offload \ No newline at end of file + --use_gradient_checkpointing_offload +# The learning rate is kept consistent with the settings in the original paper \ No newline at end of file From 70f531b724b6c5588a71334c27304dec7337f7be Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 19 Jan 2026 17:37:30 +0800 Subject: [PATCH 14/16] update wan-vace training scripts --- .../wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh b/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh index 93b38cf..916752b 100644 --- a/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh +++ b/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh @@ -7,7 +7,7 @@ accelerate launch examples/wanvideo/model_training/train.py \ --num_frames 17 \ --dataset_repeat 100 \ --model_id_with_origin_paths "PAI/Wan2.2-VACE-Fun-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-VACE-Fun-A14B:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-VACE-Fun-A14B:Wan2.1_VAE.pth" \ - --learning_rate 1e-4 \ + --learning_rate 5e-5 \ --num_epochs 5 \ --remove_prefix_in_ckpt "pipe.vace." 
\ --output_path "./models/train/Wan2.2-VACE-Fun-A14B_high_noise_lora" \ @@ -19,6 +19,7 @@ accelerate launch examples/wanvideo/model_training/train.py \ --max_timestep_boundary 0.358 \ --min_timestep_boundary 0 # boundary corresponds to timesteps [900, 1000] +# The learning rate is kept consistent with the settings in the original paper accelerate launch examples/wanvideo/model_training/train.py \ --dataset_base_path data/example_video_dataset \ @@ -29,7 +30,7 @@ accelerate launch examples/wanvideo/model_training/train.py \ --num_frames 17 \ --dataset_repeat 100 \ --model_id_with_origin_paths "PAI/Wan2.2-VACE-Fun-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-VACE-Fun-A14B:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-VACE-Fun-A14B:Wan2.1_VAE.pth" \ - --learning_rate 1e-4 \ + --learning_rate 5e-5 \ --num_epochs 5 \ --remove_prefix_in_ckpt "pipe.vace." \ --output_path "./models/train/Wan2.2-VACE-Fun-A14B_low_noise_lora" \ @@ -40,4 +41,5 @@ accelerate launch examples/wanvideo/model_training/train.py \ --use_gradient_checkpointing_offload \ --max_timestep_boundary 1 \ --min_timestep_boundary 0.358 -# boundary corresponds to timesteps [0, 900] \ No newline at end of file +# boundary corresponds to timesteps [0, 900] +# The learning rate is kept consistent with the settings in the original paper \ No newline at end of file From 8ad2d9884bbaf450fc88dba917162d14bdbbe1ad Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 19 Jan 2026 17:43:07 +0800 Subject: [PATCH 15/16] update lr in wan-vace training scripts --- .../model_training/full/Wan2.1-VACE-1.3B-Preview.sh | 5 +++-- examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh | 5 +++-- examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh | 5 +++-- .../wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh | 8 +++++--- .../model_training/lora/Wan2.1-VACE-1.3B-Preview.sh | 5 ++--- examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh | 5 ++--- 
examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh | 5 ++--- .../wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh | 8 +++----- 8 files changed, 23 insertions(+), 23 deletions(-) diff --git a/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh b/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh index b348874..19b6ecb 100644 --- a/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh +++ b/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh @@ -7,10 +7,11 @@ accelerate launch examples/wanvideo/model_training/train.py \ --num_frames 49 \ --dataset_repeat 100 \ --model_id_with_origin_paths "iic/VACE-Wan2.1-1.3B-Preview:diffusion_pytorch_model*.safetensors,iic/VACE-Wan2.1-1.3B-Preview:models_t5_umt5-xxl-enc-bf16.pth,iic/VACE-Wan2.1-1.3B-Preview:Wan2.1_VAE.pth" \ - --learning_rate 1e-4 \ + --learning_rate 5e-5 \ --num_epochs 2 \ --remove_prefix_in_ckpt "pipe.vace." \ --output_path "./models/train/Wan2.1-VACE-1.3B-Preview_full" \ --trainable_models "vace" \ --extra_inputs "vace_video,vace_reference_image" \ - --use_gradient_checkpointing_offload \ No newline at end of file + --use_gradient_checkpointing_offload +# The learning rate is kept consistent with the settings in the original paper \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh b/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh index 763252e..f9768c6 100644 --- a/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh +++ b/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh @@ -7,10 +7,11 @@ accelerate launch examples/wanvideo/model_training/train.py \ --num_frames 49 \ --dataset_repeat 100 \ --model_id_with_origin_paths "Wan-AI/Wan2.1-VACE-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-VACE-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-VACE-1.3B:Wan2.1_VAE.pth" \ - --learning_rate 1e-4 \ + --learning_rate 5e-5 \ --num_epochs 2 \ --remove_prefix_in_ckpt "pipe.vace." 
\ --output_path "./models/train/Wan2.1-VACE-1.3B_full" \ --trainable_models "vace" \ --extra_inputs "vace_video,vace_reference_image" \ - --use_gradient_checkpointing_offload \ No newline at end of file + --use_gradient_checkpointing_offload +# The learning rate is kept consistent with the settings in the original paper \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh b/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh index c549263..401a647 100644 --- a/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh +++ b/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh @@ -7,10 +7,11 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate --num_frames 17 \ --dataset_repeat 100 \ --model_id_with_origin_paths "Wan-AI/Wan2.1-VACE-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-VACE-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-VACE-14B:Wan2.1_VAE.pth" \ - --learning_rate 1e-4 \ + --learning_rate 5e-5 \ --num_epochs 2 \ --remove_prefix_in_ckpt "pipe.vace." 
\ --output_path "./models/train/Wan2.1-VACE-14B_full" \ --trainable_models "vace" \ --extra_inputs "vace_video,vace_reference_image" \ - --use_gradient_checkpointing_offload \ No newline at end of file + --use_gradient_checkpointing_offload +# The learning rate is kept consistent with the settings in the original paper \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh b/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh index ecfef32..ba3e875 100644 --- a/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh +++ b/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh @@ -7,7 +7,7 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate --num_frames 17 \ --dataset_repeat 100 \ --model_id_with_origin_paths "PAI/Wan2.2-VACE-Fun-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-VACE-Fun-A14B:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-VACE-Fun-A14B:Wan2.1_VAE.pth" \ - --learning_rate 1e-4 \ + --learning_rate 5e-5 \ --num_epochs 2 \ --remove_prefix_in_ckpt "pipe.vace." 
\ --output_path "./models/train/Wan2.2-VACE-Fun-A14B_high_noise_full" \ @@ -18,6 +18,7 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate --min_timestep_boundary 0 \ --initialize_model_on_cpu # boundary corresponds to timesteps [900, 1000] +# The learning rate is kept consistent with the settings in the original paper accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ @@ -29,7 +30,7 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate --num_frames 17 \ --dataset_repeat 100 \ --model_id_with_origin_paths "PAI/Wan2.2-VACE-Fun-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-VACE-Fun-A14B:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-VACE-Fun-A14B:Wan2.1_VAE.pth" \ - --learning_rate 1e-4 \ + --learning_rate 5e-5 \ --num_epochs 2 \ --remove_prefix_in_ckpt "pipe.vace." \ --output_path "./models/train/Wan2.2-VACE-Fun-A14B_low_noise_full" \ @@ -39,4 +40,5 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate --max_timestep_boundary 1 \ --min_timestep_boundary 0.358 \ --initialize_model_on_cpu -# boundary corresponds to timesteps [0, 900] \ No newline at end of file +# boundary corresponds to timesteps [0, 900] +# The learning rate is kept consistent with the settings in the original paper \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh b/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh index 1f25eef..2bcb55b 100644 --- a/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh +++ b/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh @@ -6,7 +6,7 @@ accelerate launch examples/wanvideo/model_training/train.py \ --width 832 \ --dataset_repeat 100 \ --model_id_with_origin_paths 
"iic/VACE-Wan2.1-1.3B-Preview:diffusion_pytorch_model*.safetensors,iic/VACE-Wan2.1-1.3B-Preview:models_t5_umt5-xxl-enc-bf16.pth,iic/VACE-Wan2.1-1.3B-Preview:Wan2.1_VAE.pth" \ - --learning_rate 5e-5 \ + --learning_rate 1e-4 \ --num_epochs 5 \ --remove_prefix_in_ckpt "pipe.vace." \ --output_path "./models/train/Wan2.1-VACE-1.3B-Preview_lora" \ @@ -14,5 +14,4 @@ accelerate launch examples/wanvideo/model_training/train.py \ --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ --lora_rank 32 \ --extra_inputs "vace_video,vace_reference_image" \ - --use_gradient_checkpointing_offload -# The learning rate is kept consistent with the settings in the original paper \ No newline at end of file + --use_gradient_checkpointing_offload \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh b/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh index c8b77cc..b565078 100644 --- a/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh +++ b/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh @@ -6,7 +6,7 @@ accelerate launch examples/wanvideo/model_training/train.py \ --width 832 \ --dataset_repeat 100 \ --model_id_with_origin_paths "Wan-AI/Wan2.1-VACE-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-VACE-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-VACE-1.3B:Wan2.1_VAE.pth" \ - --learning_rate 5e-5 \ + --learning_rate 1e-4 \ --num_epochs 5 \ --remove_prefix_in_ckpt "pipe.vace." 
\ --output_path "./models/train/Wan2.1-VACE-1.3B_lora" \ @@ -14,5 +14,4 @@ accelerate launch examples/wanvideo/model_training/train.py \ --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ --lora_rank 32 \ --extra_inputs "vace_video,vace_reference_image" \ - --use_gradient_checkpointing_offload -# The learning rate is kept consistent with the settings in the original paper \ No newline at end of file + --use_gradient_checkpointing_offload \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh b/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh index 28bd05c..633ea0e 100644 --- a/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh +++ b/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh @@ -7,7 +7,7 @@ accelerate launch examples/wanvideo/model_training/train.py \ --num_frames 17 \ --dataset_repeat 100 \ --model_id_with_origin_paths "Wan-AI/Wan2.1-VACE-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-VACE-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-VACE-14B:Wan2.1_VAE.pth" \ - --learning_rate 5e-5 \ + --learning_rate 1e-4 \ --num_epochs 5 \ --remove_prefix_in_ckpt "pipe.vace." 
\ --output_path "./models/train/Wan2.1-VACE-14B_lora" \ @@ -15,5 +15,4 @@ accelerate launch examples/wanvideo/model_training/train.py \ --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ --lora_rank 32 \ --extra_inputs "vace_video,vace_reference_image" \ - --use_gradient_checkpointing_offload -# The learning rate is kept consistent with the settings in the original paper \ No newline at end of file + --use_gradient_checkpointing_offload \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh b/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh index 916752b..93b38cf 100644 --- a/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh +++ b/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh @@ -7,7 +7,7 @@ accelerate launch examples/wanvideo/model_training/train.py \ --num_frames 17 \ --dataset_repeat 100 \ --model_id_with_origin_paths "PAI/Wan2.2-VACE-Fun-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-VACE-Fun-A14B:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-VACE-Fun-A14B:Wan2.1_VAE.pth" \ - --learning_rate 5e-5 \ + --learning_rate 1e-4 \ --num_epochs 5 \ --remove_prefix_in_ckpt "pipe.vace." 
\ --output_path "./models/train/Wan2.2-VACE-Fun-A14B_high_noise_lora" \ @@ -19,7 +19,6 @@ accelerate launch examples/wanvideo/model_training/train.py \ --max_timestep_boundary 0.358 \ --min_timestep_boundary 0 # boundary corresponds to timesteps [900, 1000] -# The learning rate is kept consistent with the settings in the original paper accelerate launch examples/wanvideo/model_training/train.py \ --dataset_base_path data/example_video_dataset \ @@ -30,7 +29,7 @@ accelerate launch examples/wanvideo/model_training/train.py \ --num_frames 17 \ --dataset_repeat 100 \ --model_id_with_origin_paths "PAI/Wan2.2-VACE-Fun-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-VACE-Fun-A14B:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-VACE-Fun-A14B:Wan2.1_VAE.pth" \ - --learning_rate 5e-5 \ + --learning_rate 1e-4 \ --num_epochs 5 \ --remove_prefix_in_ckpt "pipe.vace." \ --output_path "./models/train/Wan2.2-VACE-Fun-A14B_low_noise_lora" \ @@ -41,5 +40,4 @@ accelerate launch examples/wanvideo/model_training/train.py \ --use_gradient_checkpointing_offload \ --max_timestep_boundary 1 \ --min_timestep_boundary 0.358 -# boundary corresponds to timesteps [0, 900] -# The learning rate is kept consistent with the settings in the original paper \ No newline at end of file +# boundary corresponds to timesteps [0, 900] \ No newline at end of file From 1e90c72d945aeb49cc533f67a19023b11d0fd0bf Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 19 Jan 2026 21:11:58 +0800 Subject: [PATCH 16/16] support klein base models --- README.md | 6 +++-- README_zh.md | 6 +++-- docs/en/Model_Details/FLUX2.md | 2 ++ docs/zh/Model_Details/FLUX2.md | 2 ++ .../model_inference/FLUX.2-klein-base-4B.py | 17 ++++++++++++ .../model_inference/FLUX.2-klein-base-9B.py | 17 ++++++++++++ .../FLUX.2-klein-base-4B.py | 27 +++++++++++++++++++ .../FLUX.2-klein-base-9B.py | 27 +++++++++++++++++++ .../full/FLUX.2-klein-base-4B.sh | 13 +++++++++ .../full/FLUX.2-klein-base-9B.sh | 13 +++++++++ 
.../lora/FLUX.2-klein-base-4B.sh | 15 +++++++++++ .../lora/FLUX.2-klein-base-9B.sh | 15 +++++++++++ .../validate_full/FLUX.2-klein-base-4B.py | 20 ++++++++++++++ .../validate_full/FLUX.2-klein-base-9B.py | 20 ++++++++++++++ .../validate_lora/FLUX.2-klein-base-4B.py | 18 +++++++++++++ .../validate_lora/FLUX.2-klein-base-9B.py | 18 +++++++++++++ 16 files changed, 232 insertions(+), 4 deletions(-) create mode 100644 examples/flux2/model_inference/FLUX.2-klein-base-4B.py create mode 100644 examples/flux2/model_inference/FLUX.2-klein-base-9B.py create mode 100644 examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py create mode 100644 examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py create mode 100644 examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh create mode 100644 examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh create mode 100644 examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh create mode 100644 examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh create mode 100644 examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py create mode 100644 examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py create mode 100644 examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py create mode 100644 examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py diff --git a/README.md b/README.md index a590402..4f3bb97 100644 --- a/README.md +++ b/README.md @@ -321,11 +321,13 @@ image.save("image.jpg") Example code for FLUX.2 is available at: [/examples/flux2/](/examples/flux2/) -| Model ID | Inference | Low-VRAM Inference | LoRA Training | LoRA Training Validation | -|-|-|-|-|-| +| Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation | +|-|-|-|-|-|-|-| 
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)| |[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)| |[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)| +|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)| 
+|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)| diff --git a/README_zh.md b/README_zh.md index 6a0e0ef..a464dab 100644 --- a/README_zh.md +++ b/README_zh.md @@ -321,11 +321,13 @@ image.save("image.jpg") FLUX.2 的示例代码位于:[/examples/flux2/](/examples/flux2/) -|模型 ID|推理|低显存推理|LoRA 训练|LoRA 训练后验证| -|-|-|-|-|-| +|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| +|-|-|-|-|-|-|-| |[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)| |[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)| 
|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)| +|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)| +|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)| diff --git a/docs/en/Model_Details/FLUX2.md b/docs/en/Model_Details/FLUX2.md index 70ccf21..89e3c92 100644 --- a/docs/en/Model_Details/FLUX2.md +++ b/docs/en/Model_Details/FLUX2.md @@ -64,6 +64,8 @@ image.save("image.jpg") 
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)| |[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)| |[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)| +|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)| 
+|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)| Special Training Scripts: diff --git a/docs/zh/Model_Details/FLUX2.md b/docs/zh/Model_Details/FLUX2.md index cf91054..896ad9f 100644 --- a/docs/zh/Model_Details/FLUX2.md +++ b/docs/zh/Model_Details/FLUX2.md @@ -64,6 +64,8 @@ image.save("image.jpg") |[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)| |[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)| 
|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)| +|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)| +|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)| 特殊训练脚本: diff --git a/examples/flux2/model_inference/FLUX.2-klein-base-4B.py b/examples/flux2/model_inference/FLUX.2-klein-base-4B.py new file mode 100644 index 0000000..8ce4521 --- /dev/null +++ b/examples/flux2/model_inference/FLUX.2-klein-base-4B.py @@ -0,0 +1,17 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +import torch + + +pipe = Flux2ImagePipeline.from_pretrained( 
+ torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"), +) +prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles." +image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4) +image.save("image_FLUX.2-klein-base-4B.jpg") diff --git a/examples/flux2/model_inference/FLUX.2-klein-base-9B.py b/examples/flux2/model_inference/FLUX.2-klein-base-9B.py new file mode 100644 index 0000000..aa7193f --- /dev/null +++ b/examples/flux2/model_inference/FLUX.2-klein-base-9B.py @@ -0,0 +1,17 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +import torch + + +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-9B", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"), +) +prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles." 
+image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4) +image.save("image_FLUX.2-klein-base-9B.jpg") diff --git a/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py b/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py new file mode 100644 index 0000000..733a006 --- /dev/null +++ b/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py @@ -0,0 +1,27 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +import torch + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"), +) +prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles." 
+image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4) +image.save("image_FLUX.2-klein-base-4B.jpg") diff --git a/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py b/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py new file mode 100644 index 0000000..d5f5f80 --- /dev/null +++ b/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py @@ -0,0 +1,27 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +import torch + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-9B", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"), +) +prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles." 
+image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4) +image.save("image_FLUX.2-klein-base-9B.jpg") diff --git a/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh b/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh new file mode 100644 index 0000000..0862391 --- /dev/null +++ b/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh @@ -0,0 +1,13 @@ +accelerate launch examples/flux2/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \ + --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/FLUX.2-klein-base-4B_full" \ + --trainable_models "dit" \ + --use_gradient_checkpointing diff --git a/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh b/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh new file mode 100644 index 0000000..d33a21f --- /dev/null +++ b/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh @@ -0,0 +1,13 @@ +accelerate launch examples/flux2/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \ + --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/FLUX.2-klein-base-9B_full" \ + --trainable_models "dit" \ + --use_gradient_checkpointing diff --git a/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh b/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh new file mode 100644 index 0000000..e7f636e --- /dev/null +++ b/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh @@ -0,0 +1,15 @@ +accelerate launch examples/flux2/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \ + --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/FLUX.2-klein-base-4B_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out" \ + --lora_rank 32 \ + --use_gradient_checkpointing diff --git a/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh b/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh new file mode 100644 index 0000000..d4f65df --- /dev/null +++ b/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh @@ -0,0 +1,15 @@ +accelerate launch examples/flux2/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \ + --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/FLUX.2-klein-base-9B_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out" \ + --lora_rank 32 \ + --use_gradient_checkpointing diff --git a/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py b/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py new file mode 100644 index 0000000..95dcf9d --- /dev/null +++ b/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py @@ -0,0 +1,20 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +from diffsynth.core import load_state_dict +import torch + + +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", 
origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"), +) +state_dict = load_state_dict("./models/train/FLUX.2-klein-base-4B_full/epoch-1.safetensors", torch_dtype=torch.bfloat16) +pipe.dit.load_state_dict(state_dict) +prompt = "a dog" +image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768) +image.save("image.jpg") diff --git a/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py b/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py new file mode 100644 index 0000000..c2a192d --- /dev/null +++ b/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py @@ -0,0 +1,20 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +from diffsynth.core import load_state_dict +import torch + + +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-9B", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"), +) +state_dict = load_state_dict("./models/train/FLUX.2-klein-base-9B_full/epoch-1.safetensors", torch_dtype=torch.bfloat16) +pipe.dit.load_state_dict(state_dict) +prompt = "a dog" +image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768) +image.save("image.jpg") diff --git a/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py 
b/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py new file mode 100644 index 0000000..6694305 --- /dev/null +++ b/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py @@ -0,0 +1,18 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +import torch + + +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.dit, "./models/train/FLUX.2-klein-base-4B_lora/epoch-4.safetensors") +prompt = "a dog" +image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768) +image.save("image.jpg") diff --git a/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py b/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py new file mode 100644 index 0000000..3551291 --- /dev/null +++ b/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py @@ -0,0 +1,18 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +import torch + + +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-9B", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + 
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.dit, "./models/train/FLUX.2-klein-base-9B_lora/epoch-4.safetensors") +prompt = "a dog" +image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768) +image.save("image.jpg")