mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
Merge branch 'main' of https://github.com/modelscope/DiffSynth-Studio into usp_npu
This commit is contained in:
2
.github/workflows/publish.yaml
vendored
2
.github/workflows/publish.yaml
vendored
@@ -22,7 +22,7 @@ jobs:
|
||||
- name: Install wheel
|
||||
run: pip install wheel==0.44.0 && pip install -r requirements.txt
|
||||
- name: Build DiffSynth
|
||||
run: python setup.py sdist bdist_wheel
|
||||
run: python -m build
|
||||
- name: Publish package to PyPI
|
||||
run: |
|
||||
pip install twine
|
||||
|
||||
18
README.md
18
README.md
@@ -33,6 +33,12 @@ We believe that a well-developed open-source code framework can lower the thresh
|
||||
|
||||
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
|
||||
|
||||
- **January 19, 2026**: Added support for [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) and [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/FLUX2.md) and [example code](/examples/flux2/) are now available.
|
||||
|
||||
- **January 12, 2026**: We trained and open-sourced a text-guided image layer separation model ([Model Link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)). Given an input image and a textual description, the model isolates the image layer corresponding to the described content. For more details, please refer to our blog post ([Chinese version](https://modelscope.cn/learn/4938), [English version](https://huggingface.co/blog/kelseye/qwen-image-layered-control)).
|
||||
|
||||
- **December 24, 2025**: Based on Qwen-Image-Edit-2511, we trained an In-Context Editing LoRA model ([Model Link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)). This model takes three images as input (Image A, Image B, and Image C), and automatically analyzes the transformation from Image A to Image B, then applies the same transformation to Image C to generate Image D. For more details, please refer to our blog post ([Chinese version](https://mp.weixin.qq.com/s/41aEiN3lXKGCJs1-we4Q2g), [English version](https://huggingface.co/blog/kelseye/qwen-image-edit-2511-icedit-lora)).
|
||||
|
||||
- **December 9, 2025** We release a wild model based on DiffSynth-Studio 2.0: [Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L) (Image-to-LoRA). This model takes an image as input and outputs a LoRA. Although this version still has significant room for improvement in terms of generalization, detail preservation, and other aspects, we are open-sourcing these models to inspire more innovative research. For more details, please refer to our [blog](https://huggingface.co/blog/kelseye/qwen-image-i2l).
|
||||
|
||||
- **December 4, 2025** DiffSynth-Studio 2.0 released! Many new features online
|
||||
@@ -315,9 +321,13 @@ image.save("image.jpg")
|
||||
|
||||
Example code for FLUX.2 is available at: [/examples/flux2/](/examples/flux2/)
|
||||
|
||||
| Model ID | Inference | Low-VRAM Inference | LoRA Training | LoRA Training Validation |
|
||||
|-|-|-|-|-|
|
||||
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|
||||
| Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|
||||
|-|-|-|-|-|-|-|
|
||||
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|
||||
|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|
|
||||
|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
|
||||
|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
|
||||
|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|
|
||||
|
||||
</details>
|
||||
|
||||
@@ -401,6 +411,7 @@ Example code for Qwen-Image is available at: [/examples/qwen_image/](/examples/q
|
||||
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|
||||
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|
||||
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
|
||||
@@ -769,4 +780,3 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-47
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
17
README_zh.md
17
README_zh.md
@@ -33,6 +33,12 @@ DiffSynth 目前包括两个开源项目:
|
||||
|
||||
> 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 负责,因此新功能的开发进展会比较缓慢,issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。
|
||||
|
||||
- **2026年1月19日** 新增对 [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) 和 [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) 模型的支持,包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/FLUX2.md)和[示例代码](/examples/flux2/)现已可用。
|
||||
|
||||
- **2026年1月12日** 我们训练并开源了一个文本引导的图层拆分模型([模型链接](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)),这一模型输入一张图与一段文本描述,模型会将图像中与文本描述相关的图层拆分出来。更多细节请阅读我们的 blog([中文版](https://modelscope.cn/learn/4938)、[英文版](https://huggingface.co/blog/kelseye/qwen-image-layered-control))。
|
||||
|
||||
- **2025年12月24日** 我们基于 Qwen-Image-Edit-2511 训练了一个 In-Context Editing LoRA 模型([模型链接](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)),这个模型可以输入三张图:图A、图B、图C,模型会自行分析图A到图B的变化,并将这样的变化应用到图C,生成图D。更多细节请阅读我们的 blog([中文版](https://mp.weixin.qq.com/s/41aEiN3lXKGCJs1-we4Q2g)、[英文版](https://huggingface.co/blog/kelseye/qwen-image-edit-2511-icedit-lora))。
|
||||
|
||||
- **2025年12月9日** 我们基于 DiffSynth-Studio 2.0 训练了一个疯狂的模型:[Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)(Image to LoRA)。这一模型以图像为输入,以 LoRA 为输出。尽管这个版本的模型在泛化能力、细节保持能力等方面还有很大改进空间,我们将这些模型开源,以启发更多创新性的研究工作。更多细节,请参考我们的 [blog](https://huggingface.co/blog/kelseye/qwen-image-i2l)。
|
||||
|
||||
- **2025年12月4日** DiffSynth-Studio 2.0 发布!众多新功能上线
|
||||
@@ -315,9 +321,13 @@ image.save("image.jpg")
|
||||
|
||||
FLUX.2 的示例代码位于:[/examples/flux2/](/examples/flux2/)
|
||||
|
||||
|模型 ID|推理|低显存推理|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|
|
||||
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|
||||
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|
||||
|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|
|
||||
|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
|
||||
|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
|
||||
|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|
|
||||
|
||||
</details>
|
||||
|
||||
@@ -401,6 +411,7 @@ Qwen-Image 的示例代码位于:[/examples/qwen_image/](/examples/qwen_image/
|
||||
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|
||||
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|
||||
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
|
||||
|
||||
@@ -481,6 +481,13 @@ flux_series = [
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
|
||||
"extra_kwargs": {"disable_guidance_embedder": True},
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="MAILAND/majicflus_v1", origin_file_pattern="majicflus_v134.safetensors")
|
||||
"model_hash": "3394f306c4cbf04334b712bf5aaed95f",
|
||||
"model_name": "flux_dit",
|
||||
"model_class": "diffsynth.models.flux_dit.FluxDiT",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
|
||||
},
|
||||
]
|
||||
|
||||
flux2_series = [
|
||||
@@ -503,6 +510,28 @@ flux2_series = [
|
||||
"model_name": "flux2_vae",
|
||||
"model_class": "diffsynth.models.flux2_vae.Flux2VAE",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors")
|
||||
"model_hash": "3bde7b817fec8143028b6825a63180df",
|
||||
"model_name": "flux2_dit",
|
||||
"model_class": "diffsynth.models.flux2_dit.Flux2DiT",
|
||||
"extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 7680, "num_attention_heads": 24, "num_layers": 5, "num_single_layers": 20}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors")
|
||||
"model_hash": "9195f3ea256fcd0ae6d929c203470754",
|
||||
"model_name": "z_image_text_encoder",
|
||||
"model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
|
||||
"extra_kwargs": {"model_size": "8B"},
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors")
|
||||
"model_hash": "39c6fc48f07bebecedbbaa971ff466c8",
|
||||
"model_name": "flux2_dit",
|
||||
"model_class": "diffsynth.models.flux2_dit.Flux2DiT",
|
||||
"extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 12288, "num_attention_heads": 32, "num_layers": 8, "num_single_layers": 24}
|
||||
},
|
||||
]
|
||||
|
||||
z_image_series = [
|
||||
|
||||
@@ -4,4 +4,3 @@ from .gradient import *
|
||||
from .loader import *
|
||||
from .vram import *
|
||||
from .device import *
|
||||
from .npu_patch import *
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type, get_device_name
|
||||
from .npu_compatible_device import IS_NPU_AVAILABLE
|
||||
from .npu_compatible_device import IS_NPU_AVAILABLE, IS_CUDA_AVAILABLE
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
from diffsynth.core.device.npu_compatible_device import IS_NPU_AVAILABLE
|
||||
from .npu_autocast_patch import npu_autocast_patch
|
||||
|
||||
if IS_NPU_AVAILABLE:
|
||||
npu_autocast_patch()
|
||||
@@ -1,21 +0,0 @@
|
||||
import torch
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
def npu_autocast_patch_wrapper(func):
|
||||
@contextmanager
|
||||
def wrapper(*args, **kwargs):
|
||||
flag = False
|
||||
if "npu" in args or ("device_type" in kwargs and kwargs["device_type"] == "npu"):
|
||||
if torch.float32 in args or ("dtype" in kwargs and kwargs["dtype"] == torch.float32):
|
||||
flag = True
|
||||
with func(*args, **kwargs) as ctx:
|
||||
if flag:
|
||||
torch.npu.set_autocast_enabled(True)
|
||||
yield ctx
|
||||
return wrapper
|
||||
|
||||
|
||||
def npu_autocast_patch():
|
||||
torch.amp.autocast = npu_autocast_patch_wrapper(torch.amp.autocast)
|
||||
torch.autocast = npu_autocast_patch_wrapper(torch.autocast)
|
||||
@@ -1,4 +1,4 @@
|
||||
import torch, json
|
||||
import torch, json, os
|
||||
from ..core import ModelConfig, load_state_dict
|
||||
from ..utils.controlnet import ControlNetInput
|
||||
from peft import LoraConfig, inject_adapter_in_model
|
||||
@@ -127,15 +127,29 @@ class DiffusionTrainingModule(torch.nn.Module):
|
||||
if model_id_with_origin_paths is not None:
|
||||
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
|
||||
for model_id_with_origin_path in model_id_with_origin_paths:
|
||||
model_id, origin_file_pattern = model_id_with_origin_path.split(":")
|
||||
vram_config = self.parse_vram_config(
|
||||
fp8=model_id_with_origin_path in fp8_models,
|
||||
offload=model_id_with_origin_path in offload_models,
|
||||
device=device
|
||||
)
|
||||
model_configs.append(ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern, **vram_config))
|
||||
config = self.parse_path_or_model_id(model_id_with_origin_path)
|
||||
model_configs.append(ModelConfig(model_id=config.model_id, origin_file_pattern=config.origin_file_pattern, **vram_config))
|
||||
return model_configs
|
||||
|
||||
|
||||
def parse_path_or_model_id(self, model_id_with_origin_path, default_value=None):
|
||||
if model_id_with_origin_path is None:
|
||||
return default_value
|
||||
elif os.path.exists(model_id_with_origin_path):
|
||||
return ModelConfig(path=model_id_with_origin_path)
|
||||
else:
|
||||
if ":" not in model_id_with_origin_path:
|
||||
raise ValueError(f"Failed to parse model config: {model_id_with_origin_path}. This is neither a valid path nor in the format of `model_id/origin_file_pattern`.")
|
||||
split_id = model_id_with_origin_path.rfind(":")
|
||||
model_id = model_id_with_origin_path[:split_id]
|
||||
origin_file_pattern = model_id_with_origin_path[split_id + 1:]
|
||||
return ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern)
|
||||
|
||||
|
||||
def switch_pipe_to_training_mode(
|
||||
self,
|
||||
|
||||
@@ -823,7 +823,13 @@ class Flux2PosEmbed(nn.Module):
|
||||
|
||||
|
||||
class Flux2TimestepGuidanceEmbeddings(nn.Module):
|
||||
def __init__(self, in_channels: int = 256, embedding_dim: int = 6144, bias: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 256,
|
||||
embedding_dim: int = 6144,
|
||||
bias: bool = False,
|
||||
guidance_embeds: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
@@ -831,20 +837,24 @@ class Flux2TimestepGuidanceEmbeddings(nn.Module):
|
||||
in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
|
||||
)
|
||||
|
||||
self.guidance_embedder = TimestepEmbedding(
|
||||
in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
|
||||
)
|
||||
if guidance_embeds:
|
||||
self.guidance_embedder = TimestepEmbedding(
|
||||
in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
|
||||
)
|
||||
else:
|
||||
self.guidance_embedder = None
|
||||
|
||||
def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor:
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) # (N, D)
|
||||
|
||||
guidance_proj = self.time_proj(guidance)
|
||||
guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D)
|
||||
|
||||
time_guidance_emb = timesteps_emb + guidance_emb
|
||||
|
||||
return time_guidance_emb
|
||||
if guidance is not None and self.guidance_embedder is not None:
|
||||
guidance_proj = self.time_proj(guidance)
|
||||
guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D)
|
||||
time_guidance_emb = timesteps_emb + guidance_emb
|
||||
return time_guidance_emb
|
||||
else:
|
||||
return timesteps_emb
|
||||
|
||||
|
||||
class Flux2Modulation(nn.Module):
|
||||
@@ -882,6 +892,7 @@ class Flux2DiT(torch.nn.Module):
|
||||
axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
|
||||
rope_theta: int = 2000,
|
||||
eps: float = 1e-6,
|
||||
guidance_embeds: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels or in_channels
|
||||
@@ -892,7 +903,10 @@ class Flux2DiT(torch.nn.Module):
|
||||
|
||||
# 2. Combined timestep + guidance embedding
|
||||
self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
|
||||
in_channels=timestep_guidance_channels, embedding_dim=self.inner_dim, bias=False
|
||||
in_channels=timestep_guidance_channels,
|
||||
embedding_dim=self.inner_dim,
|
||||
bias=False,
|
||||
guidance_embeds=guidance_embeds,
|
||||
)
|
||||
|
||||
# 3. Modulation (double stream and single stream blocks share modulation parameters, resp.)
|
||||
@@ -953,34 +967,9 @@ class Flux2DiT(torch.nn.Module):
|
||||
txt_ids: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
) -> Union[torch.Tensor]:
|
||||
"""
|
||||
The [`FluxTransformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
timestep ( `torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
||||
A list of tensors that if specified are added to the residuals of transformer blocks.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
):
|
||||
# 0. Handle input arguments
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
@@ -992,7 +981,9 @@ class Flux2DiT(torch.nn.Module):
|
||||
|
||||
# 1. Calculate timestep embedding and modulation parameters
|
||||
timestep = timestep.to(hidden_states.dtype) * 1000
|
||||
guidance = guidance.to(hidden_states.dtype) * 1000
|
||||
|
||||
if guidance is not None:
|
||||
guidance = guidance.to(hidden_states.dtype) * 1000
|
||||
|
||||
temb = self.time_guidance_embed(timestep, guidance)
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import numpy as np
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from .wan_video_dit import flash_attention
|
||||
from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type
|
||||
from ..core.device.npu_compatible_device import get_device_type
|
||||
from ..core.gradient import gradient_checkpoint_forward
|
||||
|
||||
|
||||
@@ -375,8 +375,6 @@ class FinalLayer_FP32(nn.Module):
|
||||
T, _, _ = latent_shape
|
||||
|
||||
with amp.autocast(get_device_type(), dtype=torch.float32):
|
||||
if IS_NPU_AVAILABLE:
|
||||
torch.npu.set_autocast_enabled(True)
|
||||
shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C]
|
||||
x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C)
|
||||
x = self.linear(x)
|
||||
@@ -587,8 +585,6 @@ class LongCatSingleStreamBlock(nn.Module):
|
||||
|
||||
# compute modulation params in fp32
|
||||
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
|
||||
if IS_NPU_AVAILABLE:
|
||||
torch.npu.set_autocast_enabled(True)
|
||||
shift_msa, scale_msa, gate_msa, \
|
||||
shift_mlp, scale_mlp, gate_mlp = \
|
||||
self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C]
|
||||
@@ -608,8 +604,6 @@ class LongCatSingleStreamBlock(nn.Module):
|
||||
x_s = attn_outputs
|
||||
|
||||
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
|
||||
if IS_NPU_AVAILABLE:
|
||||
torch.npu.set_autocast_enabled(True)
|
||||
x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
|
||||
x = x.to(x_dtype)
|
||||
|
||||
@@ -623,8 +617,6 @@ class LongCatSingleStreamBlock(nn.Module):
|
||||
x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C)
|
||||
x_s = self.ffn(x_m)
|
||||
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
|
||||
if IS_NPU_AVAILABLE:
|
||||
torch.npu.set_autocast_enabled(True)
|
||||
x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
|
||||
x = x.to(x_dtype)
|
||||
|
||||
@@ -807,8 +799,6 @@ class LongCatVideoTransformer3DModel(torch.nn.Module):
|
||||
hidden_states = self.x_embedder(hidden_states) # [B, N, C]
|
||||
|
||||
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
|
||||
if IS_NPU_AVAILABLE:
|
||||
torch.npu.set_autocast_enabled(True)
|
||||
t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t]
|
||||
|
||||
encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C]
|
||||
|
||||
@@ -5,7 +5,6 @@ import math
|
||||
from typing import Tuple, Optional
|
||||
from einops import rearrange
|
||||
from .wan_video_camera_controller import SimpleAdapter
|
||||
from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE
|
||||
|
||||
try:
|
||||
import flash_attn_interface
|
||||
@@ -94,6 +93,7 @@ def rope_apply(x, freqs, num_heads):
|
||||
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
|
||||
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
|
||||
x.shape[0], x.shape[1], x.shape[2], -1, 2))
|
||||
freqs = freqs.to(torch.complex64) if freqs.device == "npu" else freqs
|
||||
x_out = torch.view_as_real(x_out * freqs).flatten(2)
|
||||
return x_out.to(x.dtype)
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from torch.nn import RMSNorm
|
||||
from .general_modules import RMSNorm
|
||||
from ..core.attention import attention_forward
|
||||
from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type
|
||||
from ..core.gradient import gradient_checkpoint_forward
|
||||
|
||||
@@ -3,38 +3,71 @@ import torch
|
||||
|
||||
|
||||
class ZImageTextEncoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self, model_size="4B"):
|
||||
super().__init__()
|
||||
config = Qwen3Config(**{
|
||||
"architectures": [
|
||||
"Qwen3ForCausalLM"
|
||||
],
|
||||
"attention_bias": False,
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 151643,
|
||||
"eos_token_id": 151645,
|
||||
"head_dim": 128,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 2560,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 9728,
|
||||
"max_position_embeddings": 40960,
|
||||
"max_window_layers": 36,
|
||||
"model_type": "qwen3",
|
||||
"num_attention_heads": 32,
|
||||
"num_hidden_layers": 36,
|
||||
"num_key_value_heads": 8,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_scaling": None,
|
||||
"rope_theta": 1000000,
|
||||
"sliding_window": None,
|
||||
"tie_word_embeddings": True,
|
||||
"torch_dtype": "bfloat16",
|
||||
"transformers_version": "4.51.0",
|
||||
"use_cache": True,
|
||||
"use_sliding_window": False,
|
||||
"vocab_size": 151936
|
||||
})
|
||||
config_dict = {
|
||||
"4B": Qwen3Config(**{
|
||||
"architectures": [
|
||||
"Qwen3ForCausalLM"
|
||||
],
|
||||
"attention_bias": False,
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 151643,
|
||||
"eos_token_id": 151645,
|
||||
"head_dim": 128,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 2560,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 9728,
|
||||
"max_position_embeddings": 40960,
|
||||
"max_window_layers": 36,
|
||||
"model_type": "qwen3",
|
||||
"num_attention_heads": 32,
|
||||
"num_hidden_layers": 36,
|
||||
"num_key_value_heads": 8,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_scaling": None,
|
||||
"rope_theta": 1000000,
|
||||
"sliding_window": None,
|
||||
"tie_word_embeddings": True,
|
||||
"torch_dtype": "bfloat16",
|
||||
"transformers_version": "4.51.0",
|
||||
"use_cache": True,
|
||||
"use_sliding_window": False,
|
||||
"vocab_size": 151936
|
||||
}),
|
||||
"8B": Qwen3Config(**{
|
||||
"architectures": [
|
||||
"Qwen3ForCausalLM"
|
||||
],
|
||||
"attention_bias": False,
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 151643,
|
||||
"dtype": "bfloat16",
|
||||
"eos_token_id": 151645,
|
||||
"head_dim": 128,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 4096,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 12288,
|
||||
"max_position_embeddings": 40960,
|
||||
"max_window_layers": 36,
|
||||
"model_type": "qwen3",
|
||||
"num_attention_heads": 32,
|
||||
"num_hidden_layers": 36,
|
||||
"num_key_value_heads": 8,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_scaling": None,
|
||||
"rope_theta": 1000000,
|
||||
"sliding_window": None,
|
||||
"tie_word_embeddings": False,
|
||||
"transformers_version": "4.56.1",
|
||||
"use_cache": True,
|
||||
"use_sliding_window": False,
|
||||
"vocab_size": 151936
|
||||
})
|
||||
}
|
||||
config = config_dict[model_size]
|
||||
self.model = Qwen3Model(config)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
|
||||
@@ -11,10 +11,11 @@ from ..diffusion import FlowMatchScheduler
|
||||
from ..core import ModelConfig, gradient_checkpoint_forward
|
||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
|
||||
|
||||
from transformers import AutoProcessor
|
||||
from transformers import AutoProcessor, AutoTokenizer
|
||||
from ..models.flux2_text_encoder import Flux2TextEncoder
|
||||
from ..models.flux2_dit import Flux2DiT
|
||||
from ..models.flux2_vae import Flux2VAE
|
||||
from ..models.z_image_text_encoder import ZImageTextEncoder
|
||||
|
||||
|
||||
class Flux2ImagePipeline(BasePipeline):
|
||||
@@ -26,6 +27,7 @@ class Flux2ImagePipeline(BasePipeline):
|
||||
)
|
||||
self.scheduler = FlowMatchScheduler("FLUX.2")
|
||||
self.text_encoder: Flux2TextEncoder = None
|
||||
self.text_encoder_qwen3: ZImageTextEncoder = None
|
||||
self.dit: Flux2DiT = None
|
||||
self.vae: Flux2VAE = None
|
||||
self.tokenizer: AutoProcessor = None
|
||||
@@ -33,8 +35,10 @@ class Flux2ImagePipeline(BasePipeline):
|
||||
self.units = [
|
||||
Flux2Unit_ShapeChecker(),
|
||||
Flux2Unit_PromptEmbedder(),
|
||||
Flux2Unit_Qwen3PromptEmbedder(),
|
||||
Flux2Unit_NoiseInitializer(),
|
||||
Flux2Unit_InputImageEmbedder(),
|
||||
Flux2Unit_EditImageEmbedder(),
|
||||
Flux2Unit_ImageIDs(),
|
||||
]
|
||||
self.model_fn = model_fn_flux2
|
||||
@@ -54,11 +58,12 @@ class Flux2ImagePipeline(BasePipeline):
|
||||
|
||||
# Fetch models
|
||||
pipe.text_encoder = model_pool.fetch_model("flux2_text_encoder")
|
||||
pipe.text_encoder_qwen3 = model_pool.fetch_model("z_image_text_encoder")
|
||||
pipe.dit = model_pool.fetch_model("flux2_dit")
|
||||
pipe.vae = model_pool.fetch_model("flux2_vae")
|
||||
if tokenizer_config is not None:
|
||||
tokenizer_config.download_if_necessary()
|
||||
pipe.tokenizer = AutoProcessor.from_pretrained(tokenizer_config.path)
|
||||
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
|
||||
|
||||
# VRAM Management
|
||||
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||
@@ -76,6 +81,9 @@ class Flux2ImagePipeline(BasePipeline):
|
||||
# Image
|
||||
input_image: Image.Image = None,
|
||||
denoising_strength: float = 1.0,
|
||||
# Edit
|
||||
edit_image: Union[Image.Image, List[Image.Image]] = None,
|
||||
edit_image_auto_resize: bool = True,
|
||||
# Shape
|
||||
height: int = 1024,
|
||||
width: int = 1024,
|
||||
@@ -99,6 +107,7 @@ class Flux2ImagePipeline(BasePipeline):
|
||||
inputs_shared = {
|
||||
"cfg_scale": cfg_scale, "embedded_guidance": embedded_guidance,
|
||||
"input_image": input_image, "denoising_strength": denoising_strength,
|
||||
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize,
|
||||
"height": height, "width": width,
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
@@ -276,6 +285,10 @@ class Flux2Unit_PromptEmbedder(PipelineUnit):
|
||||
return prompt_embeds, text_ids
|
||||
|
||||
def process(self, pipe: Flux2ImagePipeline, prompt):
|
||||
# Skip if Qwen3 text encoder is available (handled by Qwen3PromptEmbedder)
|
||||
if pipe.text_encoder_qwen3 is not None:
|
||||
return {}
|
||||
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
prompt_embeds, text_ids = self.encode_prompt(
|
||||
pipe.text_encoder, pipe.tokenizer, prompt,
|
||||
@@ -284,6 +297,136 @@ class Flux2Unit_PromptEmbedder(PipelineUnit):
|
||||
return {"prompt_embeds": prompt_embeds, "text_ids": text_ids}
|
||||
|
||||
|
||||
class Flux2Unit_Qwen3PromptEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params_posi={"prompt": "prompt"},
|
||||
input_params_nega={"prompt": "negative_prompt"},
|
||||
output_params=("prompt_emb", "prompt_emb_mask"),
|
||||
onload_model_names=("text_encoder_qwen3",)
|
||||
)
|
||||
self.hidden_states_layers = (9, 18, 27) # Qwen3 layers
|
||||
|
||||
def get_qwen3_prompt_embeds(
|
||||
self,
|
||||
text_encoder: ZImageTextEncoder,
|
||||
tokenizer: AutoTokenizer,
|
||||
prompt: Union[str, List[str]],
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
dtype = text_encoder.dtype if dtype is None else dtype
|
||||
device = text_encoder.device if device is None else device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
all_input_ids = []
|
||||
all_attention_masks = []
|
||||
|
||||
for single_prompt in prompt:
|
||||
messages = [{"role": "user", "content": single_prompt}]
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False,
|
||||
)
|
||||
inputs = tokenizer(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=max_sequence_length,
|
||||
)
|
||||
|
||||
all_input_ids.append(inputs["input_ids"])
|
||||
all_attention_masks.append(inputs["attention_mask"])
|
||||
|
||||
input_ids = torch.cat(all_input_ids, dim=0).to(device)
|
||||
attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
|
||||
|
||||
# Forward pass through the model
|
||||
with torch.inference_mode():
|
||||
output = text_encoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
# Only use outputs from intermediate layers and stack them
|
||||
out = torch.stack([output.hidden_states[k] for k in self.hidden_states_layers], dim=1)
|
||||
out = out.to(dtype=dtype, device=device)
|
||||
|
||||
batch_size, num_channels, seq_len, hidden_dim = out.shape
|
||||
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
|
||||
return prompt_embeds
|
||||
|
||||
def prepare_text_ids(
|
||||
self,
|
||||
x: torch.Tensor, # (B, L, D) or (L, D)
|
||||
t_coord: Optional[torch.Tensor] = None,
|
||||
):
|
||||
B, L, _ = x.shape
|
||||
out_ids = []
|
||||
|
||||
for i in range(B):
|
||||
t = torch.arange(1) if t_coord is None else t_coord[i]
|
||||
h = torch.arange(1)
|
||||
w = torch.arange(1)
|
||||
l = torch.arange(L)
|
||||
|
||||
coords = torch.cartesian_prod(t, h, w, l)
|
||||
out_ids.append(coords)
|
||||
|
||||
return torch.stack(out_ids)
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
text_encoder: ZImageTextEncoder,
|
||||
tokenizer: AutoTokenizer,
|
||||
prompt: Union[str, List[str]],
|
||||
dtype = None,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self.get_qwen3_prompt_embeds(
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
batch_size, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
text_ids = self.prepare_text_ids(prompt_embeds)
|
||||
text_ids = text_ids.to(device)
|
||||
return prompt_embeds, text_ids
|
||||
|
||||
def process(self, pipe: Flux2ImagePipeline, prompt):
|
||||
# Check if Qwen3 text encoder is available
|
||||
if pipe.text_encoder_qwen3 is None:
|
||||
return {}
|
||||
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
prompt_embeds, text_ids = self.encode_prompt(
|
||||
pipe.text_encoder_qwen3, pipe.tokenizer, prompt,
|
||||
dtype=pipe.torch_dtype, device=pipe.device,
|
||||
)
|
||||
return {"prompt_embeds": prompt_embeds, "text_ids": text_ids}
|
||||
|
||||
|
||||
class Flux2Unit_NoiseInitializer(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -319,6 +462,64 @@ class Flux2Unit_InputImageEmbedder(PipelineUnit):
|
||||
return {"latents": latents, "input_latents": input_latents}
|
||||
|
||||
|
||||
class Flux2Unit_EditImageEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("edit_image", "edit_image_auto_resize"),
|
||||
output_params=("edit_latents", "edit_image_ids"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def calculate_dimensions(self, target_area, ratio):
|
||||
import math
|
||||
width = math.sqrt(target_area * ratio)
|
||||
height = width / ratio
|
||||
width = round(width / 32) * 32
|
||||
height = round(height / 32) * 32
|
||||
return width, height
|
||||
|
||||
def edit_image_auto_resize(self, edit_image):
|
||||
calculated_width, calculated_height = self.calculate_dimensions(1024 * 1024, edit_image.size[0] / edit_image.size[1])
|
||||
return edit_image.resize((calculated_width, calculated_height))
|
||||
|
||||
def process_image_ids(self, image_latents, scale=10):
|
||||
t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
|
||||
t_coords = [t.view(-1) for t in t_coords]
|
||||
|
||||
image_latent_ids = []
|
||||
for x, t in zip(image_latents, t_coords):
|
||||
x = x.squeeze(0)
|
||||
_, height, width = x.shape
|
||||
|
||||
x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
|
||||
image_latent_ids.append(x_ids)
|
||||
|
||||
image_latent_ids = torch.cat(image_latent_ids, dim=0)
|
||||
image_latent_ids = image_latent_ids.unsqueeze(0)
|
||||
|
||||
return image_latent_ids
|
||||
|
||||
def process(self, pipe: Flux2ImagePipeline, edit_image, edit_image_auto_resize):
|
||||
if edit_image is None:
|
||||
return {}
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
if isinstance(edit_image, Image.Image):
|
||||
edit_image = [edit_image]
|
||||
resized_edit_image, edit_latents = [], []
|
||||
for image in edit_image:
|
||||
# Preprocess
|
||||
if edit_image_auto_resize is None or edit_image_auto_resize:
|
||||
image = self.edit_image_auto_resize(image)
|
||||
resized_edit_image.append(image)
|
||||
# Encode
|
||||
image = pipe.preprocess_image(image)
|
||||
latents = pipe.vae.encode(image)
|
||||
edit_latents.append(latents)
|
||||
edit_image_ids = self.process_image_ids(edit_latents).to(pipe.device)
|
||||
edit_latents = torch.concat([rearrange(latents, "B C H W -> B (H W) C") for latents in edit_latents], dim=1)
|
||||
return {"edit_latents": edit_latents, "edit_image_ids": edit_image_ids}
|
||||
|
||||
|
||||
class Flux2Unit_ImageIDs(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -353,10 +554,17 @@ def model_fn_flux2(
|
||||
prompt_embeds=None,
|
||||
text_ids=None,
|
||||
image_ids=None,
|
||||
edit_latents=None,
|
||||
edit_image_ids=None,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
**kwargs,
|
||||
):
|
||||
image_seq_len = latents.shape[1]
|
||||
if edit_latents is not None:
|
||||
image_seq_len = latents.shape[1]
|
||||
latents = torch.concat([latents, edit_latents], dim=1)
|
||||
image_ids = torch.concat([image_ids, edit_image_ids], dim=1)
|
||||
embedded_guidance = torch.tensor([embedded_guidance], device=latents.device)
|
||||
model_output = dit(
|
||||
hidden_states=latents,
|
||||
@@ -368,4 +576,5 @@ def model_fn_flux2(
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)
|
||||
model_output = model_output[:, :image_seq_len]
|
||||
return model_output
|
||||
|
||||
@@ -123,11 +123,15 @@ class WanVideoPipeline(BasePipeline):
|
||||
model_config.model_id = redirect_dict[model_config.origin_file_pattern][0]
|
||||
model_config.origin_file_pattern = redirect_dict[model_config.origin_file_pattern][1]
|
||||
|
||||
# Initialize pipeline
|
||||
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
|
||||
if use_usp:
|
||||
from ..utils.xfuser import initialize_usp
|
||||
initialize_usp(device)
|
||||
import torch.distributed as dist
|
||||
from ..core.device.npu_compatible_device import get_device_name
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
device = get_device_name()
|
||||
# Initialize pipeline
|
||||
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
|
||||
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
||||
|
||||
# Fetch models
|
||||
|
||||
@@ -143,6 +143,8 @@ def FluxDiTStateDictConverterFromDiffusers(state_dict):
|
||||
suffix = ".weight" if name.endswith(".weight") else ".bias"
|
||||
prefix = name[:-len(suffix)]
|
||||
if prefix in global_rename_dict:
|
||||
if global_rename_dict[prefix] == "final_norm_out.linear":
|
||||
param = torch.concat([param[3072:], param[:3072]], dim=0)
|
||||
state_dict_[global_rename_dict[prefix] + suffix] = param
|
||||
elif prefix.startswith("transformer_blocks."):
|
||||
names = prefix.split(".")
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
def ZImageTextEncoderStateDictConverter(state_dict):
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name != "lm_head.weight":
|
||||
state_dict_[name] = state_dict[name]
|
||||
return state_dict_
|
||||
@@ -5,7 +5,7 @@ from xfuser.core.distributed import (get_sequence_parallel_rank,
|
||||
get_sequence_parallel_world_size,
|
||||
get_sp_group)
|
||||
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
|
||||
from ...core.device import parse_nccl_backend, parse_device_type, IS_NPU_AVAILABLE
|
||||
from ...core.device import parse_nccl_backend, parse_device_type
|
||||
|
||||
|
||||
def initialize_usp(device_type):
|
||||
@@ -50,6 +50,7 @@ def rope_apply(x, freqs, num_heads):
|
||||
sp_rank = get_sequence_parallel_rank()
|
||||
freqs = pad_freqs(freqs, s_per_rank * sp_size)
|
||||
freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
|
||||
freqs_rank = freqs_rank.to(torch.complex64) if freqs_rank.device == "npu" else freqs_rank
|
||||
x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
|
||||
return x_out.to(x.dtype)
|
||||
|
||||
|
||||
@@ -2,6 +2,15 @@
|
||||
|
||||
FLUX.2 is an image generation model trained and open-sourced by Black Forest Labs.
|
||||
|
||||
## Model Lineage
|
||||
|
||||
```mermaid
|
||||
graph LR;
|
||||
FLUX.2-Series-->black-forest-labs/FLUX.2-dev;
|
||||
FLUX.2-Series-->black-forest-labs/FLUX.2-klein-4B;
|
||||
FLUX.2-Series-->black-forest-labs/FLUX.2-klein-9B;
|
||||
```
|
||||
|
||||
## Installation
|
||||
|
||||
Before using this project for model inference and training, please install DiffSynth-Studio first.
|
||||
@@ -50,16 +59,20 @@ image.save("image.jpg")
|
||||
|
||||
## Model Overview
|
||||
|
||||
| Model ID | Inference | Low VRAM Inference | LoRA Training | Validation After LoRA Training |
|
||||
| - | - | - | - | - |
|
||||
| [black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev) | [code](/examples/flux2/model_inference/FLUX.2-dev.py) | [code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py) | [code](/examples/flux2/model_training/lora/FLUX.2-dev.sh) | [code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py) |
|
||||
| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
|
||||
| - | - | - | - | - | - | - |
|
||||
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|
||||
|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|
|
||||
|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
|
||||
|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
|
||||
|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|
|
||||
|
||||
Special Training Scripts:
|
||||
|
||||
* Differential LoRA Training: [doc](/docs/en/Training/Differential_LoRA.md), [code](/examples/flux/model_training/special/differential_training/)
|
||||
* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md), [code](/examples/flux/model_training/special/fp8_training/)
|
||||
* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md), [code](/examples/flux/model_training/special/split_training/)
|
||||
* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md), [code](/examples/flux/model_training/lora/FLUX.1-dev-Distill-LoRA.sh)
|
||||
* Differential LoRA Training: [doc](/docs/en/Training/Differential_LoRA.md)
|
||||
* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md)
|
||||
* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md)
|
||||
* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md)
|
||||
|
||||
## Model Inference
|
||||
|
||||
@@ -135,4 +148,4 @@ We have built a sample image dataset for your testing. You can download this dat
|
||||
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
|
||||
```
|
||||
|
||||
We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/).
|
||||
We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/).
|
||||
|
||||
@@ -86,6 +86,7 @@ graph LR;
|
||||
| [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) |
|
||||
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|
||||
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|
||||
| [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
|
||||
| [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
|
||||
| [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py) |
|
||||
|
||||
@@ -59,7 +59,7 @@ save_video(video, "video.mp4", fps=15, quality=5)
|
||||
```
|
||||
|
||||
### Training
|
||||
NPU startup script samples have been added for each type of model,the scripts are stored in the `examples/xxx/special/npu_scripts`, for example `examples/wanvideo/model_training/special/npu_scripts/Wan2.2-T2V-A14B-NPU.sh`.
|
||||
NPU startup script samples have been added for each type of model,the scripts are stored in the `examples/xxx/special/npu_training`, for example `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`.
|
||||
|
||||
In the NPU training scripts, NPU specific environment variables that can optimize performance have been added, and relevant parameters have been enabled for specific models.
|
||||
|
||||
|
||||
@@ -2,6 +2,15 @@
|
||||
|
||||
FLUX.2 是由 Black Forest Labs 训练并开源的图像生成模型。
|
||||
|
||||
## 模型血缘
|
||||
|
||||
```mermaid
|
||||
graph LR;
|
||||
FLUX.2-Series-->black-forest-labs/FLUX.2-dev;
|
||||
FLUX.2-Series-->black-forest-labs/FLUX.2-klein-4B;
|
||||
FLUX.2-Series-->black-forest-labs/FLUX.2-klein-9B;
|
||||
```
|
||||
|
||||
## 安装
|
||||
|
||||
在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。
|
||||
@@ -50,16 +59,20 @@ image.save("image.jpg")
|
||||
|
||||
## 模型总览
|
||||
|
||||
|模型 ID|推理|低显存推理|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|
|
||||
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|
||||
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|
||||
|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|
|
||||
|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
|
||||
|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
|
||||
|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|
|
||||
|
||||
特殊训练脚本:
|
||||
|
||||
* 差分 LoRA 训练:[doc](/docs/zh/Training/Differential_LoRA.md)、[code](/examples/flux/model_training/special/differential_training/)
|
||||
* FP8 精度训练:[doc](/docs/zh/Training/FP8_Precision.md)、[code](/examples/flux/model_training/special/fp8_training/)
|
||||
* 两阶段拆分训练:[doc](/docs/zh/Training/Split_Training.md)、[code](/examples/flux/model_training/special/split_training/)
|
||||
* 端到端直接蒸馏:[doc](/docs/zh/Training/Direct_Distill.md)、[code](/examples/flux/model_training/lora/FLUX.1-dev-Distill-LoRA.sh)
|
||||
* 差分 LoRA 训练:[doc](/docs/zh/Training/Differential_LoRA.md)
|
||||
* FP8 精度训练:[doc](/docs/zh/Training/FP8_Precision.md)
|
||||
* 两阶段拆分训练:[doc](/docs/zh/Training/Split_Training.md)
|
||||
* 端到端直接蒸馏:[doc](/docs/zh/Training/Direct_Distill.md)
|
||||
|
||||
## 模型推理
|
||||
|
||||
@@ -135,4 +148,4 @@ FLUX.2 系列模型统一通过 [`examples/flux2/model_training/train.py`](/exam
|
||||
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
|
||||
```
|
||||
|
||||
我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](/docs/zh/Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](/docs/Training/)。
|
||||
我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](/docs/zh/Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](/docs/Training/)。
|
||||
|
||||
@@ -86,6 +86,7 @@ graph LR;
|
||||
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|
||||
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|
||||
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
|
||||
|
||||
@@ -59,7 +59,7 @@ save_video(video, "video.mp4", fps=15, quality=5)
|
||||
```
|
||||
|
||||
### 训练
|
||||
当前已为每类模型添加NPU的启动脚本样例,脚本存放在`examples/xxx/special/npu_scripts`目录下,例如 `examples/wanvideo/model_training/special/npu_scripts/Wan2.2-T2V-A14B-NPU.sh`。
|
||||
当前已为每类模型添加NPU的启动脚本样例,脚本存放在`examples/xxx/special/npu_training`目录下,例如 `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`。
|
||||
|
||||
在NPU训练脚本中,添加了可以优化性能的NPU特有环境变量,并针对特定模型开启了相关参数。
|
||||
|
||||
|
||||
21
examples/flux2/model_inference/FLUX.2-klein-4B.py
Normal file
21
examples/flux2/model_inference/FLUX.2-klein-4B.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
|
||||
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4)
|
||||
image.save("image_FLUX.2-klein-4B.jpg")
|
||||
|
||||
prompt = "change the color of the clothes to red"
|
||||
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=4)
|
||||
image.save("image_edit_FLUX.2-klein-4B.jpg")
|
||||
21
examples/flux2/model_inference/FLUX.2-klein-9B.py
Normal file
21
examples/flux2/model_inference/FLUX.2-klein-9B.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
|
||||
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4)
|
||||
image.save("image_FLUX.2-klein-9B.jpg")
|
||||
|
||||
prompt = "change the color of the clothes to red"
|
||||
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=4)
|
||||
image.save("image_edit_FLUX.2-klein-9B.jpg")
|
||||
21
examples/flux2/model_inference/FLUX.2-klein-base-4B.py
Normal file
21
examples/flux2/model_inference/FLUX.2-klein-base-4B.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
|
||||
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
|
||||
image.save("image_FLUX.2-klein-base-4B.jpg")
|
||||
|
||||
prompt = "change the color of the clothes to red"
|
||||
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
|
||||
image.save("image_edit_FLUX.2-klein-base-4B.jpg")
|
||||
21
examples/flux2/model_inference/FLUX.2-klein-base-9B.py
Normal file
21
examples/flux2/model_inference/FLUX.2-klein-base-9B.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-9B", origin_file_pattern="transformer/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
|
||||
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
|
||||
image.save("image_FLUX.2-klein-base-9B.jpg")
|
||||
|
||||
prompt = "change the color of the clothes to red"
|
||||
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
|
||||
image.save("image_edit_FLUX.2-klein-base-9B.jpg")
|
||||
31
examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py
Normal file
31
examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": "disk",
|
||||
"offload_device": "disk",
|
||||
"onload_dtype": torch.float8_e4m3fn,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e4m3fn,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
|
||||
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4)
|
||||
image.save("image_FLUX.2-klein-4B.jpg")
|
||||
|
||||
prompt = "change the color of the clothes to red"
|
||||
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=4)
|
||||
image.save("image_edit_FLUX.2-klein-4B.jpg")
|
||||
31
examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py
Normal file
31
examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": "disk",
|
||||
"offload_device": "disk",
|
||||
"onload_dtype": torch.float8_e4m3fn,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e4m3fn,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
|
||||
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4)
|
||||
image.save("image_FLUX.2-klein-9B.jpg")
|
||||
|
||||
prompt = "change the color of the clothes to red"
|
||||
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=4)
|
||||
image.save("image_edit_FLUX.2-klein-9B.jpg")
|
||||
@@ -0,0 +1,31 @@
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": "disk",
|
||||
"offload_device": "disk",
|
||||
"onload_dtype": torch.float8_e4m3fn,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e4m3fn,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
|
||||
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
|
||||
image.save("image_FLUX.2-klein-base-4B.jpg")
|
||||
|
||||
prompt = "change the color of the clothes to red"
|
||||
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
|
||||
image.save("image_edit_FLUX.2-klein-base-4B.jpg")
|
||||
@@ -0,0 +1,31 @@
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": "disk",
|
||||
"offload_device": "disk",
|
||||
"onload_dtype": torch.float8_e4m3fn,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e4m3fn,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-9B", origin_file_pattern="transformer/*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
|
||||
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
|
||||
image.save("image_FLUX.2-klein-base-9B.jpg")
|
||||
|
||||
prompt = "change the color of the clothes to red"
|
||||
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
|
||||
image.save("image_edit_FLUX.2-klein-base-9B.jpg")
|
||||
30
examples/flux2/model_training/full/FLUX.2-klein-4B.sh
Normal file
30
examples/flux2/model_training/full/FLUX.2-klein-4B.sh
Normal file
@@ -0,0 +1,30 @@
|
||||
accelerate launch examples/flux2/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.2-klein-4B_full" \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing
|
||||
|
||||
# Edit
|
||||
# accelerate launch examples/flux2/model_training/train.py \
|
||||
# --dataset_base_path data/example_image_dataset \
|
||||
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
|
||||
# --data_file_keys "image,edit_image" \
|
||||
# --extra_inputs "edit_image" \
|
||||
# --max_pixels 1048576 \
|
||||
# --dataset_repeat 50 \
|
||||
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
|
||||
# --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
|
||||
# --learning_rate 1e-5 \
|
||||
# --num_epochs 2 \
|
||||
# --remove_prefix_in_ckpt "pipe.dit." \
|
||||
# --output_path "./models/train/FLUX.2-klein-4B_full" \
|
||||
# --trainable_models "dit" \
|
||||
# --use_gradient_checkpointing
|
||||
31
examples/flux2/model_training/full/FLUX.2-klein-9B.sh
Normal file
31
examples/flux2/model_training/full/FLUX.2-klein-9B.sh
Normal file
@@ -0,0 +1,31 @@
|
||||
# This script is tested on 8*A100
|
||||
accelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.2-klein-9B_full" \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing
|
||||
|
||||
# Edit
|
||||
# accelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \
|
||||
# --dataset_base_path data/example_image_dataset \
|
||||
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
|
||||
# --data_file_keys "image,edit_image" \
|
||||
# --extra_inputs "edit_image" \
|
||||
# --max_pixels 1048576 \
|
||||
# --dataset_repeat 50 \
|
||||
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
|
||||
# --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
|
||||
# --learning_rate 1e-5 \
|
||||
# --num_epochs 2 \
|
||||
# --remove_prefix_in_ckpt "pipe.dit." \
|
||||
# --output_path "./models/train/FLUX.2-klein-9B_full" \
|
||||
# --trainable_models "dit" \
|
||||
# --use_gradient_checkpointing
|
||||
30
examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh
Normal file
30
examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh
Normal file
@@ -0,0 +1,30 @@
|
||||
accelerate launch examples/flux2/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.2-klein-base-4B_full" \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing
|
||||
|
||||
# Edit
|
||||
# accelerate launch examples/flux2/model_training/train.py \
|
||||
# --dataset_base_path data/example_image_dataset \
|
||||
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
|
||||
# --data_file_keys "image,edit_image" \
|
||||
# --extra_inputs "edit_image" \
|
||||
# --max_pixels 1048576 \
|
||||
# --dataset_repeat 50 \
|
||||
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
|
||||
# --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
|
||||
# --learning_rate 1e-5 \
|
||||
# --num_epochs 2 \
|
||||
# --remove_prefix_in_ckpt "pipe.dit." \
|
||||
# --output_path "./models/train/FLUX.2-klein-base-4B_full" \
|
||||
# --trainable_models "dit" \
|
||||
# --use_gradient_checkpointing
|
||||
31
examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh
Normal file
31
examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh
Normal file
@@ -0,0 +1,31 @@
|
||||
# This script is tested on 8*A100
|
||||
accelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.2-klein-base-9B_full" \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing
|
||||
|
||||
# Edit
|
||||
# accelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \
|
||||
# --dataset_base_path data/example_image_dataset \
|
||||
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
|
||||
# --data_file_keys "image,edit_image" \
|
||||
# --extra_inputs "edit_image" \
|
||||
# --max_pixels 1048576 \
|
||||
# --dataset_repeat 50 \
|
||||
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
|
||||
# --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
|
||||
# --learning_rate 1e-5 \
|
||||
# --num_epochs 2 \
|
||||
# --remove_prefix_in_ckpt "pipe.dit." \
|
||||
# --output_path "./models/train/FLUX.2-klein-base-9B_full" \
|
||||
# --trainable_models "dit" \
|
||||
# --use_gradient_checkpointing
|
||||
22
examples/flux2/model_training/full/accelerate_config.yaml
Normal file
22
examples/flux2/model_training/full/accelerate_config.yaml
Normal file
@@ -0,0 +1,22 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
deepspeed_config:
|
||||
gradient_accumulation_steps: 1
|
||||
offload_optimizer_device: none
|
||||
offload_param_device: none
|
||||
zero3_init_flag: false
|
||||
zero_stage: 2
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
enable_cpu_affinity: false
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
34
examples/flux2/model_training/lora/FLUX.2-klein-4B.sh
Normal file
34
examples/flux2/model_training/lora/FLUX.2-klein-4B.sh
Normal file
@@ -0,0 +1,34 @@
|
||||
accelerate launch examples/flux2/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.2-klein-4B_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out" \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing
|
||||
|
||||
# Edit
|
||||
# accelerate launch examples/flux2/model_training/train.py \
|
||||
# --dataset_base_path data/example_image_dataset \
|
||||
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
|
||||
# --data_file_keys "image,edit_image" \
|
||||
# --extra_inputs "edit_image" \
|
||||
# --max_pixels 1048576 \
|
||||
# --dataset_repeat 50 \
|
||||
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
|
||||
# --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
|
||||
# --learning_rate 1e-4 \
|
||||
# --num_epochs 5 \
|
||||
# --remove_prefix_in_ckpt "pipe.dit." \
|
||||
# --output_path "./models/train/FLUX.2-klein-4B_lora" \
|
||||
# --lora_base_model "dit" \
|
||||
# --lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out" \
|
||||
# --lora_rank 32 \
|
||||
# --use_gradient_checkpointing
|
||||
34
examples/flux2/model_training/lora/FLUX.2-klein-9B.sh
Normal file
34
examples/flux2/model_training/lora/FLUX.2-klein-9B.sh
Normal file
@@ -0,0 +1,34 @@
|
||||
accelerate launch examples/flux2/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.2-klein-9B_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out" \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing
|
||||
|
||||
# Edit
|
||||
# accelerate launch examples/flux2/model_training/train.py \
|
||||
# --dataset_base_path data/example_image_dataset \
|
||||
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
|
||||
# --data_file_keys "image,edit_image" \
|
||||
# --extra_inputs "edit_image" \
|
||||
# --max_pixels 1048576 \
|
||||
# --dataset_repeat 50 \
|
||||
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
|
||||
# --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
|
||||
# --learning_rate 1e-4 \
|
||||
# --num_epochs 5 \
|
||||
# --remove_prefix_in_ckpt "pipe.dit." \
|
||||
# --output_path "./models/train/FLUX.2-klein-9B_lora" \
|
||||
# --lora_base_model "dit" \
|
||||
# --lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out" \
|
||||
# --lora_rank 32 \
|
||||
# --use_gradient_checkpointing
|
||||
34
examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh
Normal file
34
examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh
Normal file
@@ -0,0 +1,34 @@
|
||||
accelerate launch examples/flux2/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.2-klein-base-4B_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out" \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing
|
||||
|
||||
# Edit
|
||||
# accelerate launch examples/flux2/model_training/train.py \
|
||||
# --dataset_base_path data/example_image_dataset \
|
||||
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
|
||||
# --data_file_keys "image,edit_image" \
|
||||
# --extra_inputs "edit_image" \
|
||||
# --max_pixels 1048576 \
|
||||
# --dataset_repeat 50 \
|
||||
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
|
||||
# --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
|
||||
# --learning_rate 1e-4 \
|
||||
# --num_epochs 5 \
|
||||
# --remove_prefix_in_ckpt "pipe.dit." \
|
||||
# --output_path "./models/train/FLUX.2-klein-base-4B_lora" \
|
||||
# --lora_base_model "dit" \
|
||||
# --lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out" \
|
||||
# --lora_rank 32 \
|
||||
# --use_gradient_checkpointing
|
||||
34
examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh
Normal file
34
examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh
Normal file
@@ -0,0 +1,34 @@
|
||||
accelerate launch examples/flux2/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.2-klein-base-9B_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out" \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing
|
||||
|
||||
# Edit
|
||||
# accelerate launch examples/flux2/model_training/train.py \
|
||||
# --dataset_base_path data/example_image_dataset \
|
||||
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
|
||||
# --data_file_keys "image,edit_image" \
|
||||
# --extra_inputs "edit_image" \
|
||||
# --max_pixels 1048576 \
|
||||
# --dataset_repeat 50 \
|
||||
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
|
||||
# --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
|
||||
# --learning_rate 1e-4 \
|
||||
# --num_epochs 5 \
|
||||
# --remove_prefix_in_ckpt "pipe.dit." \
|
||||
# --output_path "./models/train/FLUX.2-klein-base-9B_lora" \
|
||||
# --lora_base_model "dit" \
|
||||
# --lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out" \
|
||||
# --lora_rank 32 \
|
||||
# --use_gradient_checkpointing
|
||||
@@ -24,7 +24,7 @@ class Flux2ImageTrainingModule(DiffusionTrainingModule):
|
||||
super().__init__()
|
||||
# Load models
|
||||
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
|
||||
tokenizer_config = ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
|
||||
tokenizer_config = self.parse_path_or_model_id(tokenizer_path, default_value=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"))
|
||||
self.pipe = Flux2ImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
|
||||
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
|
||||
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
from diffsynth.core import load_state_dict
|
||||
import torch
|
||||
|
||||
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
state_dict = load_state_dict("./models/train/FLUX.2-klein-4B_full/epoch-1.safetensors", torch_dtype=torch.bfloat16)
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
prompt = "a dog"
|
||||
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,20 @@
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
from diffsynth.core import load_state_dict
|
||||
import torch
|
||||
|
||||
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
state_dict = load_state_dict("./models/train/FLUX.2-klein-9B_full/epoch-1.safetensors", torch_dtype=torch.bfloat16)
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
prompt = "a dog"
|
||||
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,20 @@
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
from diffsynth.core import load_state_dict
|
||||
import torch
|
||||
|
||||
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
state_dict = load_state_dict("./models/train/FLUX.2-klein-base-4B_full/epoch-1.safetensors", torch_dtype=torch.bfloat16)
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
prompt = "a dog"
|
||||
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,20 @@
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
from diffsynth.core import load_state_dict
|
||||
import torch
|
||||
|
||||
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-9B", origin_file_pattern="transformer/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
state_dict = load_state_dict("./models/train/FLUX.2-klein-base-9B_full/epoch-1.safetensors", torch_dtype=torch.bfloat16)
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
prompt = "a dog"
|
||||
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,18 @@
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "./models/train/FLUX.2-klein-4B_lora/epoch-4.safetensors")
|
||||
prompt = "a dog"
|
||||
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,18 @@
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "./models/train/FLUX.2-klein-9B_lora/epoch-4.safetensors")
|
||||
prompt = "a dog"
|
||||
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,18 @@
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "./models/train/FLUX.2-klein-base-4B_lora/epoch-4.safetensors")
|
||||
prompt = "a dog"
|
||||
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,18 @@
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-9B", origin_file_pattern="transformer/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "./models/train/FLUX.2-klein-base-9B_lora/epoch-4.safetensors")
|
||||
prompt = "a dog"
|
||||
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,34 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from modelscope import snapshot_download
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Layered-Control", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
|
||||
)
|
||||
|
||||
snapshot_download(
|
||||
model_id="DiffSynth-Studio/Qwen-Image-Layered-Control",
|
||||
allow_file_pattern="assets/image_1_input.png",
|
||||
local_dir="data/layered_input"
|
||||
)
|
||||
|
||||
prompt = "A cartoon skeleton character wearing a purple hat and holding a gift box"
|
||||
input_image = Image.open("data/layered_input/assets/image_1_input.png").convert("RGBA").resize((1024, 1024))
|
||||
images = pipe(
|
||||
prompt,
|
||||
seed=0,
|
||||
num_inference_steps=30, cfg_scale=4,
|
||||
height=1024, width=1024,
|
||||
layer_input_image=input_image,
|
||||
layer_num=0,
|
||||
)
|
||||
images[0].save("image.png")
|
||||
@@ -0,0 +1,44 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from modelscope import snapshot_download
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": "disk",
|
||||
"offload_device": "disk",
|
||||
"onload_dtype": torch.float8_e4m3fn,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e4m3fn,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Layered-Control", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
|
||||
)
|
||||
|
||||
snapshot_download(
|
||||
model_id="DiffSynth-Studio/Qwen-Image-Layered-Control",
|
||||
allow_file_pattern="assets/image_1_input.png",
|
||||
local_dir="data/layered_input"
|
||||
)
|
||||
|
||||
prompt = "A cartoon skeleton character wearing a purple hat and holding a gift box"
|
||||
input_image = Image.open("data/layered_input/assets/image_1_input.png").convert("RGBA").resize((1024, 1024))
|
||||
images = pipe(
|
||||
prompt,
|
||||
seed=0,
|
||||
num_inference_steps=30, cfg_scale=4,
|
||||
height=1024, width=1024,
|
||||
layer_input_image=input_image,
|
||||
layer_num=0,
|
||||
)
|
||||
images[0].save("image.png")
|
||||
@@ -0,0 +1,18 @@
|
||||
# Example Dataset: https://modelscope.cn/datasets/DiffSynth-Studio/example_image_dataset/tree/master/layer
|
||||
|
||||
accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset/layer \
|
||||
--dataset_metadata_path data/example_image_dataset/layer/metadata_layered_control.json \
|
||||
--data_file_keys "image,layer_input_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "DiffSynth-Studio/Qwen-Image-Layered-Control:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image-Layered:vae/diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image-Layered-Control_full" \
|
||||
--trainable_models "dit" \
|
||||
--extra_inputs "layer_num,layer_input_image" \
|
||||
--use_gradient_checkpointing \
|
||||
--dataset_num_workers 8 \
|
||||
--find_unused_parameters
|
||||
@@ -0,0 +1,20 @@
|
||||
# Example Dataset: https://modelscope.cn/datasets/DiffSynth-Studio/example_image_dataset/tree/master/layer
|
||||
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset/layer \
|
||||
--dataset_metadata_path data/example_image_dataset/layer/metadata_layered_control.json \
|
||||
--data_file_keys "image,layer_input_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "DiffSynth-Studio/Qwen-Image-Layered-Control:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image-Layered:vae/diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image-Layered-Control_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||
--lora_rank 32 \
|
||||
--extra_inputs "layer_num,layer_input_image" \
|
||||
--use_gradient_checkpointing \
|
||||
--dataset_num_workers 8 \
|
||||
--find_unused_parameters
|
||||
@@ -0,0 +1,26 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from diffsynth import load_state_dict
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Layered-Control", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
state_dict = load_state_dict("models/train/Qwen-Image-Layered-Control_full/epoch-1.safetensors")
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
prompt = "Text 'HELLO' and 'Have a great day'"
|
||||
input_image = Image.open("data/example_image_dataset/layer/image.png").convert("RGBA").resize((864, 480))
|
||||
images = pipe(
|
||||
prompt, seed=0,
|
||||
height=480, width=864,
|
||||
layer_input_image=input_image, layer_num=0,
|
||||
)
|
||||
images[0].save("image.png")
|
||||
@@ -0,0 +1,25 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from diffsynth import load_state_dict
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Layered-Control", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/Qwen-Image-Layered-Control_lora/epoch-4.safetensors")
|
||||
prompt = "Text 'HELLO' and 'Have a great day'"
|
||||
input_image = Image.open("data/example_image_dataset/layer/image.png").convert("RGBA").resize((864, 480))
|
||||
images = pipe(
|
||||
prompt, seed=0,
|
||||
height=480, width=864,
|
||||
layer_input_image=input_image, layer_num=0,
|
||||
)
|
||||
images[0].save("image.png")
|
||||
@@ -7,10 +7,11 @@ accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--num_frames 49 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "iic/VACE-Wan2.1-1.3B-Preview:diffusion_pytorch_model*.safetensors,iic/VACE-Wan2.1-1.3B-Preview:models_t5_umt5-xxl-enc-bf16.pth,iic/VACE-Wan2.1-1.3B-Preview:Wan2.1_VAE.pth" \
|
||||
--learning_rate 1e-4 \
|
||||
--learning_rate 5e-5 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.vace." \
|
||||
--output_path "./models/train/Wan2.1-VACE-1.3B-Preview_full" \
|
||||
--trainable_models "vace" \
|
||||
--extra_inputs "vace_video,vace_reference_image" \
|
||||
--use_gradient_checkpointing_offload
|
||||
--use_gradient_checkpointing_offload
|
||||
# The learning rate is kept consistent with the settings in the original paper
|
||||
@@ -7,10 +7,11 @@ accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--num_frames 49 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "Wan-AI/Wan2.1-VACE-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-VACE-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-VACE-1.3B:Wan2.1_VAE.pth" \
|
||||
--learning_rate 1e-4 \
|
||||
--learning_rate 5e-5 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.vace." \
|
||||
--output_path "./models/train/Wan2.1-VACE-1.3B_full" \
|
||||
--trainable_models "vace" \
|
||||
--extra_inputs "vace_video,vace_reference_image" \
|
||||
--use_gradient_checkpointing_offload
|
||||
--use_gradient_checkpointing_offload
|
||||
# The learning rate is kept consistent with the settings in the original paper
|
||||
@@ -7,10 +7,11 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
|
||||
--num_frames 17 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "Wan-AI/Wan2.1-VACE-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-VACE-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-VACE-14B:Wan2.1_VAE.pth" \
|
||||
--learning_rate 1e-4 \
|
||||
--learning_rate 5e-5 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.vace." \
|
||||
--output_path "./models/train/Wan2.1-VACE-14B_full" \
|
||||
--trainable_models "vace" \
|
||||
--extra_inputs "vace_video,vace_reference_image" \
|
||||
--use_gradient_checkpointing_offload
|
||||
--use_gradient_checkpointing_offload
|
||||
# The learning rate is kept consistent with the settings in the original paper
|
||||
@@ -7,7 +7,7 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
|
||||
--num_frames 17 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "PAI/Wan2.2-VACE-Fun-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-VACE-Fun-A14B:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-VACE-Fun-A14B:Wan2.1_VAE.pth" \
|
||||
--learning_rate 1e-4 \
|
||||
--learning_rate 5e-5 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.vace." \
|
||||
--output_path "./models/train/Wan2.2-VACE-Fun-A14B_high_noise_full" \
|
||||
@@ -18,6 +18,7 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
|
||||
--min_timestep_boundary 0 \
|
||||
--initialize_model_on_cpu
|
||||
# boundary corresponds to timesteps [900, 1000]
|
||||
# The learning rate is kept consistent with the settings in the original paper
|
||||
|
||||
|
||||
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
|
||||
@@ -29,7 +30,7 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
|
||||
--num_frames 17 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "PAI/Wan2.2-VACE-Fun-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-VACE-Fun-A14B:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-VACE-Fun-A14B:Wan2.1_VAE.pth" \
|
||||
--learning_rate 1e-4 \
|
||||
--learning_rate 5e-5 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.vace." \
|
||||
--output_path "./models/train/Wan2.2-VACE-Fun-A14B_low_noise_full" \
|
||||
@@ -39,4 +40,5 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
|
||||
--max_timestep_boundary 1 \
|
||||
--min_timestep_boundary 0.358 \
|
||||
--initialize_model_on_cpu
|
||||
# boundary corresponds to timesteps [0, 900]
|
||||
# boundary corresponds to timesteps [0, 900]
|
||||
# The learning rate is kept consistent with the settings in the original paper
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "diffsynth"
|
||||
version = "2.0.1"
|
||||
version = "2.0.2"
|
||||
description = "Enjoy the magic of Diffusion models!"
|
||||
authors = [{name = "ModelScope Team"}]
|
||||
license = {text = "Apache-2.0"}
|
||||
|
||||
Reference in New Issue
Block a user