Compare commits

...

19 Commits

|Author|SHA1|Message|Date|
|-|-|-|-|
|Artiprocher|fddc98ff16|fix mix-precision issues in low-version torch|2026-02-10 11:12:50 +08:00|
|Artiprocher|ff10fde47f|update lora loading in docs|2026-02-10 10:48:44 +08:00|
|Zhongjie Duan|dc94614c80|Merge pull request #1256 from Feng0w0/npu_fused<br>[model][NPU]:Add NPU fusion operator patch to Zimage model to improve performance|2026-02-09 20:08:44 +08:00|
|feng0w0|e56a4d5730|[model][NPU]:Add NPU fusion operator patch to Zimage model to improve performance|2026-02-09 12:31:34 +08:00|
|feng0w0|3f8468893a|[model][NPU]:Add NPU fusion operator patch to Zimage model to improve performance|2026-02-09 09:51:06 +08:00|
|Zhongjie Duan|1b47e1dc22|Merge pull request #1272 from modelscope/zero3-fix<br>Support DeepSpeed ZeRO 3|2026-02-06 16:33:12 +08:00|
|Artiprocher|b0bf78e915|refine code & doc|2026-02-06 16:27:23 +08:00|
|Zhongjie Duan|abdf66d09e|Merge pull request #1265 from lzws/main<br>fix wanS2V bug and update readme|2026-02-06 10:22:48 +08:00|
|lzws|27b1fe240b|add examples|2026-02-05 17:17:10 +08:00|
|lzws|1635897516|update readme|2026-02-05 16:56:39 +08:00|
|lzws|8d172127cd|fix wans2v bug and update readme|2026-02-05 16:52:38 +08:00|
|feng0w0|fccb1ecdd7|Initialize qwen-image on CPU|2026-02-05 11:54:36 +08:00|
|Zhongjie Duan|c0f7e1db7c|Merge pull request #1261 from modelscope/examples-update<br>update examples|2026-02-05 11:11:35 +08:00|
|feng0w0|6886f7ba35|fix wan decoder bug|2026-02-05 10:31:41 +08:00|
|feng0w0|051b957adb|[model][NPU] Add NPU fusion operator patch to Zimage model to improve performance|2026-02-03 19:50:21 +08:00|
|feng0w0|ca9b5e64ea|[feature]:Add adaptation of all models to zero3|2026-02-03 15:44:53 +08:00|
|feng0w0|2070bbd925|[feature]:Add adaptation of all models to zero3|2026-01-31 16:50:18 +08:00|
|feng0w0|3140199c96|[feature]:Add adaptation of all models to zero3|2026-01-27 15:33:42 +08:00|
|feng0w0|4e9db263b0|[feature]:Add adaptation of all models to zero3|2026-01-27 11:24:43 +08:00|
43 changed files with 740 additions and 215 deletions

View File

@@ -760,6 +760,37 @@ Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/)
DiffSynth-Studio is not just an engineered model framework, but also an incubator for innovative achievements. DiffSynth-Studio is not just an engineered model framework, but also an incubator for innovative achievements.
<details>
<summary>Spectral Evolution Search: Efficient Inference-Time Scaling for Reward-Aligned Image Generation</summary>
- Paper: [Spectral Evolution Search: Efficient Inference-Time Scaling for Reward-Aligned Image Generation
](https://arxiv.org/abs/2602.03208)
- Sample Code: coming soon
|FLUX.1-dev|FLUX.1-dev + SES|Qwen-Image|Qwen-Image + SES|
|-|-|-|-|
|![Image](https://github.com/user-attachments/assets/5be15dc6-2805-4822-b04c-2573fc0f45f0)|![Image](https://github.com/user-attachments/assets/e71b8c20-1629-41d9-b0ff-185805c1da4e)|![Image](https://github.com/user-attachments/assets/7a73c968-133a-4545-9aa2-205533861cd4)|![Image](https://github.com/user-attachments/assets/c8390b22-14fe-48a0-a6e6-d6556d31235e)|
</details>
<details>
<summary>VIRAL: Visual In-Context Reasoning via Analogy in Diffusion Transformers</summary>
- Paper: [VIRAL: Visual In-Context Reasoning via Analogy in Diffusion Transformers
](https://arxiv.org/abs/2602.03210)
- Sample code: [/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-ICEdit.py](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-ICEdit.py)
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)
|Example 1|Example 2|Query|Output|
|-|-|-|-|
|![Image](https://github.com/user-attachments/assets/380d2670-47bf-41cd-b5c9-37110cc4a943)|![Image](https://github.com/user-attachments/assets/7ceaf345-0992-46e6-b38f-394c2065b165)|![Image](https://github.com/user-attachments/assets/f7c26c21-6894-4d9e-b570-f1d44ca7c1de)|![Image](https://github.com/user-attachments/assets/c2bebe3b-5984-41ba-94bf-9509f6a8a990)|
</details>
<details> <details>
<summary>AttriCtrl: Attribute Intensity Control for Image Generation Models</summary> <summary>AttriCtrl: Attribute Intensity Control for Image Generation Models</summary>
@@ -770,7 +801,7 @@ DiffSynth-Studio is not just an engineered model framework, but also an incubato
|brightness scale = 0.1|brightness scale = 0.3|brightness scale = 0.5|brightness scale = 0.7|brightness scale = 0.9| |brightness scale = 0.1|brightness scale = 0.3|brightness scale = 0.5|brightness scale = 0.7|brightness scale = 0.9|
|-|-|-|-|-| |-|-|-|-|-|
|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.5.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.7.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.9.jpg)| |![Image](https://github.com/user-attachments/assets/e74b32a5-5b2e-4c87-9df8-487c0f8366b7)|![Image](https://github.com/user-attachments/assets/bfe8bec2-9e55-493d-9a26-7e9cce28e03d)|![Image](https://github.com/user-attachments/assets/b099dfe3-ff1f-4b96-894c-d48bbe92db7a)|![Image](https://github.com/user-attachments/assets/0a6b2982-deab-4b0d-91ad-888782de01c9)|![Image](https://github.com/user-attachments/assets/fcecb755-7d03-4020-b83a-13ad2b38705c)|
</details> </details>
@@ -785,10 +816,10 @@ DiffSynth-Studio is not just an engineered model framework, but also an incubato
||[LoRA 1](https://modelscope.cn/models/cancel13/cxsk)|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2)|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1)|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL)| ||[LoRA 1](https://modelscope.cn/models/cancel13/cxsk)|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2)|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1)|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL)|
|-|-|-|-|-| |-|-|-|-|-|
|[LoRA 1](https://modelscope.cn/models/cancel13/cxsk) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_0.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_3.jpg)| |[LoRA 1](https://modelscope.cn/models/cancel13/cxsk) |![Image](https://github.com/user-attachments/assets/01c54d5a-4f00-4c2e-982a-4ec0a4c6a6e3)|![Image](https://github.com/user-attachments/assets/e6621457-b9f1-437c-bcc8-3e12e41646de)|![Image](https://github.com/user-attachments/assets/4b7f721f-a2e5-416c-af2c-b53ef236c321)|![Image](https://github.com/user-attachments/assets/802d554e-0402-482c-9f28-87605f8fe318)|
|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_3.jpg)| |[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2) |![Image](https://github.com/user-attachments/assets/e6621457-b9f1-437c-bcc8-3e12e41646de)|![Image](https://github.com/user-attachments/assets/43720a9f-aa27-4918-947d-545389375d46)|![Image](https://github.com/user-attachments/assets/418c725b-6d35-41f4-b18f-c7e3867cc142)|![Image](https://github.com/user-attachments/assets/8c8f22fa-9643-4019-b6d7-396d8b7fed9a)|
|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_3.jpg)| |[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) |![Image](https://github.com/user-attachments/assets/4b7f721f-a2e5-416c-af2c-b53ef236c321)|![Image](https://github.com/user-attachments/assets/418c725b-6d35-41f4-b18f-c7e3867cc142)|![Image](https://github.com/user-attachments/assets/041a3f9a-c7b4-4311-8582-cb71a7226d80)|![Image](https://github.com/user-attachments/assets/b54ebaa4-31a7-4536-a2c1-496adba0c013)|
|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_3_3.jpg)| |[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL) |![Image](https://github.com/user-attachments/assets/802d554e-0402-482c-9f28-87605f8fe318)|![Image](https://github.com/user-attachments/assets/8c8f22fa-9643-4019-b6d7-396d8b7fed9a)|![Image](https://github.com/user-attachments/assets/b54ebaa4-31a7-4536-a2c1-496adba0c013)|![Image](https://github.com/user-attachments/assets/a640fd54-3192-49a0-9281-b43d9ba64f09)|
</details> </details>

View File

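@@ -760,6 +760,37 @@ Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/)
DiffSynth-Studio is not just an engineered model framework, but also an incubator for innovative achievements. DiffSynth-Studio is not just an engineered model framework, but also an incubator for innovative achievements.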
<details>
<summary>Spectral Evolution Search: Efficient Inference-Time Scaling for Reward-Aligned Image Generation</summary>
- Paper: [Spectral Evolution Search: Efficient Inference-Time Scaling for Reward-Aligned Image Generation
](https://arxiv.org/abs/2602.03208)
- Sample code: coming soon
|FLUX.1-dev|FLUX.1-dev + SES|Qwen-Image|Qwen-Image + SES|
|-|-|-|-|
|![Image](https://github.com/user-attachments/assets/5be15dc6-2805-4822-b04c-2573fc0f45f0)|![Image](https://github.com/user-attachments/assets/e71b8c20-1629-41d9-b0ff-185805c1da4e)|![Image](https://github.com/user-attachments/assets/7a73c968-133a-4545-9aa2-205533861cd4)|![Image](https://github.com/user-attachments/assets/c8390b22-14fe-48a0-a6e6-d6556d31235e)|
</details>
<details>
<summary>VIRAL: Visual In-Context Reasoning via Analogy in Diffusion Transformers</summary>
- Paper: [VIRAL: Visual In-Context Reasoning via Analogy in Diffusion Transformers
](https://arxiv.org/abs/2602.03210)
- Sample code: [/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-ICEdit.py](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-ICEdit.py)
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)
|Example 1|Example 2|Query|Output|
|-|-|-|-|
|![Image](https://github.com/user-attachments/assets/380d2670-47bf-41cd-b5c9-37110cc4a943)|![Image](https://github.com/user-attachments/assets/7ceaf345-0992-46e6-b38f-394c2065b165)|![Image](https://github.com/user-attachments/assets/f7c26c21-6894-4d9e-b570-f1d44ca7c1de)|![Image](https://github.com/user-attachments/assets/c2bebe3b-5984-41ba-94bf-9509f6a8a990)|
</details>
<details> <details>
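<summary>AttriCtrl: Attribute Intensity Control for Image Generation Models</summary> <summary>AttriCtrl: Attribute Intensity Control for Image Generation Models</summary>
@@ -771,7 +802,7 @@ DiffSynth-Studio is not just an engineered model framework, but also an incubato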
|brightness scale = 0.1|brightness scale = 0.3|brightness scale = 0.5|brightness scale = 0.7|brightness scale = 0.9| |brightness scale = 0.1|brightness scale = 0.3|brightness scale = 0.5|brightness scale = 0.7|brightness scale = 0.9|
|-|-|-|-|-| |-|-|-|-|-|
|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.5.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.7.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.9.jpg)| |![Image](https://github.com/user-attachments/assets/e74b32a5-5b2e-4c87-9df8-487c0f8366b7)|![Image](https://github.com/user-attachments/assets/bfe8bec2-9e55-493d-9a26-7e9cce28e03d)|![Image](https://github.com/user-attachments/assets/b099dfe3-ff1f-4b96-894c-d48bbe92db7a)|![Image](https://github.com/user-attachments/assets/0a6b2982-deab-4b0d-91ad-888782de01c9)|![Image](https://github.com/user-attachments/assets/fcecb755-7d03-4020-b83a-13ad2b38705c)|
</details> </details>
@@ -787,10 +818,10 @@ DiffSynth-Studio is not just an engineered model framework, but also an incubato
||[LoRA 1](https://modelscope.cn/models/cancel13/cxsk)|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2)|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1)|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL)| ||[LoRA 1](https://modelscope.cn/models/cancel13/cxsk)|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2)|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1)|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL)|
|-|-|-|-|-| |-|-|-|-|-|
|[LoRA 1](https://modelscope.cn/models/cancel13/cxsk) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_0.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_3.jpg)| |[LoRA 1](https://modelscope.cn/models/cancel13/cxsk) |![Image](https://github.com/user-attachments/assets/01c54d5a-4f00-4c2e-982a-4ec0a4c6a6e3)|![Image](https://github.com/user-attachments/assets/e6621457-b9f1-437c-bcc8-3e12e41646de)|![Image](https://github.com/user-attachments/assets/4b7f721f-a2e5-416c-af2c-b53ef236c321)|![Image](https://github.com/user-attachments/assets/802d554e-0402-482c-9f28-87605f8fe318)|
|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_3.jpg)| |[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2) |![Image](https://github.com/user-attachments/assets/e6621457-b9f1-437c-bcc8-3e12e41646de)|![Image](https://github.com/user-attachments/assets/43720a9f-aa27-4918-947d-545389375d46)|![Image](https://github.com/user-attachments/assets/418c725b-6d35-41f4-b18f-c7e3867cc142)|![Image](https://github.com/user-attachments/assets/8c8f22fa-9643-4019-b6d7-396d8b7fed9a)|
|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_3.jpg)| |[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) |![Image](https://github.com/user-attachments/assets/4b7f721f-a2e5-416c-af2c-b53ef236c321)|![Image](https://github.com/user-attachments/assets/418c725b-6d35-41f4-b18f-c7e3867cc142)|![Image](https://github.com/user-attachments/assets/041a3f9a-c7b4-4311-8582-cb71a7226d80)|![Image](https://github.com/user-attachments/assets/b54ebaa4-31a7-4536-a2c1-496adba0c013)|
|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_3_3.jpg)| |[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL) |![Image](https://github.com/user-attachments/assets/802d554e-0402-482c-9f28-87605f8fe318)|![Image](https://github.com/user-attachments/assets/8c8f22fa-9643-4019-b6d7-396d8b7fed9a)|![Image](https://github.com/user-attachments/assets/b54ebaa4-31a7-4536-a2c1-496adba0c013)|![Image](https://github.com/user-attachments/assets/a640fd54-3192-49a0-9281-b43d9ba64f09)|
</details> </details>

View File

@@ -3,14 +3,14 @@ from ..vram.disk_map import DiskMap
from ..vram.layers import enable_vram_management from ..vram.layers import enable_vram_management
from .file import load_state_dict from .file import load_state_dict
import torch import torch
from contextlib import contextmanager
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils import ContextManagers
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None): def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None):
config = {} if config is None else config config = {} if config is None else config
# Why do we use `skip_model_initialization`? with ContextManagers(get_init_context(torch_dtype=torch_dtype, device=device)):
# It skips the random initialization of model parameters,
# thereby speeding up model loading and avoiding excessive memory usage.
with skip_model_initialization():
model = model_class(**config) model = model_class(**config)
# What is `module_map`? # What is `module_map`?
# This is a module mapping table for VRAM management. # This is a module mapping table for VRAM management.
@@ -48,7 +48,14 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
state_dict = state_dict_converter(state_dict) state_dict = state_dict_converter(state_dict)
else: else:
state_dict = {i: state_dict[i] for i in state_dict} state_dict = {i: state_dict[i] for i in state_dict}
model.load_state_dict(state_dict, assign=True) # Why does DeepSpeed ZeRO Stage 3 need to be handled separately?
# Because at this stage, model parameters are partitioned across multiple GPUs.
# Loading them directly could lead to excessive GPU memory consumption.
if is_deepspeed_zero3_enabled():
from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model
_load_state_dict_into_zero3_model(model, state_dict)
else:
model.load_state_dict(state_dict, assign=True)
# Why do we call `to()`? # Why do we call `to()`?
# Because some models override the behavior of `to()`, # Because some models override the behavior of `to()`,
# especially those from libraries like Transformers. # especially those from libraries like Transformers.
@@ -79,3 +86,20 @@ def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=tor
} }
enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=80) enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=80)
return model return model
def get_init_context(torch_dtype, device):
    if is_deepspeed_zero3_enabled():
        from transformers.modeling_utils import set_zero3_state
        import deepspeed
        # Why do we use `deepspeed.zero.Init`?
        # The model weights can be partitioned on the CPU side,
        # and then the partitioned weights are loaded onto the accelerator.
        init_contexts = [deepspeed.zero.Init(remote_device=device, dtype=torch_dtype), set_zero3_state()]
    else:
        # Why do we use `skip_model_initialization`?
        # It skips the random initialization of model parameters,
        # thereby speeding up model loading and avoiding excessive memory usage.
        init_contexts = [skip_model_initialization()]
    return init_contexts
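
The `with ContextManagers(get_init_context(...))` line in this hunk enters every context manager in the returned list before the model is constructed. The following is a minimal stdlib-only sketch of that pattern using `contextlib.ExitStack`; `named_context`, `build_model`, and the `zero3_enabled` flag are hypothetical stand-ins for illustration, not the real DeepSpeed or Transformers objects.

```python
from contextlib import ExitStack, contextmanager


@contextmanager
def named_context(name, events):
    # Placeholder for a real init context such as deepspeed.zero.Init (assumption).
    events.append(f"enter {name}")
    try:
        yield
    finally:
        events.append(f"exit {name}")


def get_init_context(zero3_enabled, events):
    # Mirrors the branch in the diff: two contexts under ZeRO-3, one otherwise.
    if zero3_enabled:
        return [named_context("zero_init", events), named_context("zero3_state", events)]
    return [named_context("skip_model_initialization", events)]


def build_model(zero3_enabled):
    events = []
    with ExitStack() as stack:
        # Enter the contexts in list order; they are exited in reverse order.
        for context in get_init_context(zero3_enabled, events):
            stack.enter_context(context)
        events.append("construct model")  # stands in for model_class(**config)
    return events


print(build_model(zero3_enabled=True))
print(build_model(zero3_enabled=False))
```

Returning a list instead of a single context manager keeps the caller identical for both branches, whichever number of contexts is needed.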

View File

@@ -0,0 +1,30 @@
import torch
from ..device.npu_compatible_device import get_device_type
try:
    import torch_npu
except Exception:
    # torch_npu is optional; it is only needed when running on NPU devices.
    pass


def rms_norm_forward_npu(self, hidden_states):
    "npu rms fused operator for RMSNorm.forward from diffsynth/models/general_modules.py"
    if hidden_states.dtype != self.weight.dtype:
        hidden_states = hidden_states.to(self.weight.dtype)
    return torch_npu.npu_rms_norm(hidden_states, self.weight, self.eps)[0]


def rms_norm_forward_transformers_npu(self, hidden_states):
    "npu rms fused operator for transformers"
    if hidden_states.dtype != self.weight.dtype:
        hidden_states = hidden_states.to(self.weight.dtype)
    return torch_npu.npu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]


def rotary_emb_Zimage_npu(self, x_in: torch.Tensor, freqs_cis: torch.Tensor):
    "npu rope fused operator for Zimage"
    with torch.amp.autocast(get_device_type(), enabled=False):
        freqs_cis = freqs_cis.unsqueeze(2)
        cos, sin = torch.chunk(torch.view_as_real(freqs_cis), 2, dim=-1)
        cos = cos.expand(-1, -1, -1, -1, 2).flatten(-2)
        sin = sin.expand(-1, -1, -1, -1, 2).flatten(-2)
        return torch_npu.npu_rotary_mul(x_in, cos, sin, rotary_mode="interleave").to(x_in)
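
For readers without NPU hardware, the computation that the two RMSNorm patches above hand to the fused operator can be sketched in plain Python. This assumes that `torch_npu.npu_rms_norm` implements the standard RMSNorm formula and that the `[0]` in the patch selects the normalized output; `rms_norm_reference` is a hypothetical name, and the sketch works on a single vector rather than a tensor.

```python
import math


def rms_norm_reference(x, weight, eps=1e-6):
    # Standard RMSNorm: scale x by the reciprocal root-mean-square, then by a learned weight.
    mean_square = sum(value * value for value in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_square + eps)
    return [value * inv_rms * w for value, w in zip(x, weight)]


normalized = rms_norm_reference([3.0, 4.0], [1.0, 1.0], eps=0.0)
# With unit weights and eps=0, the output has a mean square of 1.
mean_square_after = sum(value * value for value in normalized) / len(normalized)
print(normalized, mean_square_after)
```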

View File

@@ -18,8 +18,8 @@ class ModelLogger:
def on_epoch_end(self, accelerator: Accelerator, model: torch.nn.Module, epoch_id): def on_epoch_end(self, accelerator: Accelerator, model: torch.nn.Module, epoch_id):
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
state_dict = accelerator.get_state_dict(model)
if accelerator.is_main_process: if accelerator.is_main_process:
state_dict = accelerator.get_state_dict(model)
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt) state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
state_dict = self.state_dict_converter(state_dict) state_dict = self.state_dict_converter(state_dict)
os.makedirs(self.output_path, exist_ok=True) os.makedirs(self.output_path, exist_ok=True)
@@ -34,8 +34,8 @@ class ModelLogger:
def save_model(self, accelerator: Accelerator, model: torch.nn.Module, file_name): def save_model(self, accelerator: Accelerator, model: torch.nn.Module, file_name):
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
state_dict = accelerator.get_state_dict(model)
if accelerator.is_main_process: if accelerator.is_main_process:
state_dict = accelerator.get_state_dict(model)
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt) state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
state_dict = self.state_dict_converter(state_dict) state_dict = self.state_dict_converter(state_dict)
os.makedirs(self.output_path, exist_ok=True) os.makedirs(self.output_path, exist_ok=True)
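
The diff does not say why `accelerator.get_state_dict(model)` moves out of the `is_main_process` guard. A likely reason, given the ZeRO-3 commits in this range, is that gathering partitioned parameters is a collective operation that every rank must join. The sketch below simulates that behaviour with a `threading.Barrier`; `run_ranks` and `gather_state_dict` are hypothetical stand-ins and do not use Accelerate or DeepSpeed.

```python
import threading

WORLD_SIZE = 2


def run_ranks(gather_on_all_ranks, timeout=0.2):
    barrier = threading.Barrier(WORLD_SIZE)
    outcome = {}

    def gather_state_dict():
        # Stand-in for a collective call: it only returns once every rank has joined.
        barrier.wait(timeout=timeout)
        return {"weight": "full tensor"}

    def worker(rank):
        is_main_process = rank == 0
        try:
            if gather_on_all_ranks:
                state_dict = gather_state_dict()  # new order: every rank participates
                if is_main_process:
                    outcome[rank] = f"saved {len(state_dict)} tensor(s)"
            elif is_main_process:
                state_dict = gather_state_dict()  # old order: only rank 0 participates
                outcome[rank] = f"saved {len(state_dict)} tensor(s)"
        except threading.BrokenBarrierError:
            outcome[rank] = "timed out waiting for other ranks"

    threads = [threading.Thread(target=worker, args=(rank,)) for rank in range(WORLD_SIZE)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    return outcome


print(run_ranks(gather_on_all_ranks=True))
print(run_ranks(gather_on_all_ranks=False))
```

In the simulation, the old ordering leaves rank 0 waiting until the timeout, while the new ordering lets every rank join the gather and only rank 0 write the checkpoint.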

View File

@@ -27,7 +27,7 @@ def launch_training_task(
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay) optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer) scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers) dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
model.to(device=accelerator.device)
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler) model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
for epoch_id in range(num_epochs): for epoch_id in range(num_epochs):
@@ -59,6 +59,7 @@ def launch_data_process_task(
num_workers = args.dataset_num_workers num_workers = args.dataset_num_workers
dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers) dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)
model.to(device=accelerator.device)
model, dataloader = accelerator.prepare(model, dataloader) model, dataloader = accelerator.prepare(model, dataloader)
for data_id, data in enumerate(tqdm(dataloader)): for data_id, data in enumerate(tqdm(dataloader)):

View File

@@ -407,6 +407,7 @@ class Flux2AttnProcessor:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
query, key, value = query.to(hidden_states.dtype), key.to(hidden_states.dtype), value.to(hidden_states.dtype)
hidden_states = attention_forward( hidden_states = attention_forward(
query, query,
key, key,
@@ -536,6 +537,7 @@ class Flux2ParallelSelfAttnProcessor:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
query, key, value = query.to(hidden_states.dtype), key.to(hidden_states.dtype), value.to(hidden_states.dtype)
hidden_states = attention_forward( hidden_states = attention_forward(
query, query,
key, key,

View File

@@ -5,6 +5,7 @@ import math
from typing import Tuple, Optional from typing import Tuple, Optional
from einops import rearrange from einops import rearrange
from .wan_video_camera_controller import SimpleAdapter from .wan_video_camera_controller import SimpleAdapter
from ..core.gradient import gradient_checkpoint_forward
try: try:
import flash_attn_interface import flash_attn_interface
@@ -379,27 +380,15 @@ class WanModel(torch.nn.Module):
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device) ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for block in self.blocks: for block in self.blocks:
if self.training and use_gradient_checkpointing: if self.training:
if use_gradient_checkpointing_offload: x = gradient_checkpoint_forward(
with torch.autograd.graph.save_on_cpu(): block,
x = torch.utils.checkpoint.checkpoint( use_gradient_checkpointing,
create_custom_forward(block), use_gradient_checkpointing_offload,
x, context, t_mod, freqs, x, context, t_mod, freqs
use_reentrant=False, )
)
else:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else: else:
x = block(x, context, t_mod, freqs) x = block(x, context, t_mod, freqs)
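
The body of `gradient_checkpoint_forward` is not part of this diff. The sketch below reconstructs a plausible version from the branches that the refactor removes; it is an assumption, and the real helper in `..core.gradient` may differ, for example in how the two flags interact.

```python
import torch
from torch.utils.checkpoint import checkpoint


def gradient_checkpoint_forward(module, use_gradient_checkpointing,
                                use_gradient_checkpointing_offload, *args):
    # Hypothetical re-implementation based on the removed code paths.
    if use_gradient_checkpointing_offload:
        # Keep tensors saved for backward in CPU memory while recomputing activations.
        with torch.autograd.graph.save_on_cpu():
            return checkpoint(module, *args, use_reentrant=False)
    if use_gradient_checkpointing:
        return checkpoint(module, *args, use_reentrant=False)
    return module(*args)


torch.manual_seed(0)
block = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)

plain = gradient_checkpoint_forward(block, False, False, x)
ckpt = gradient_checkpoint_forward(block, True, False, x)
offload = gradient_checkpoint_forward(block, True, True, x)
print(torch.allclose(plain, ckpt), torch.allclose(plain, offload))
```

Centralizing the branching in one helper lets each model's forward loop reduce to a single call per block, as the rewritten loops in this diff do.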

View File

@@ -4,6 +4,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from typing import Tuple from typing import Tuple
from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, Head, CrossAttention, modulate, sinusoidal_embedding_1d from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, Head, CrossAttention, modulate, sinusoidal_embedding_1d
from ..core.gradient import gradient_checkpoint_forward
def torch_dfs(model: nn.Module, parent_name='root'): def torch_dfs(model: nn.Module, parent_name='root'):
@@ -545,46 +546,19 @@ class WanS2VModel(torch.nn.Module):
t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep)) t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
t_mod = self.time_projection(t).unflatten(1, (6, self.dim)).unsqueeze(2).transpose(0, 2) t_mod = self.time_projection(t).unflatten(1, (6, self.dim)).unsqueeze(2).transpose(0, 2)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for block_id, block in enumerate(self.blocks): for block_id, block in enumerate(self.blocks):
if use_gradient_checkpointing_offload: x = gradient_checkpoint_forward(
with torch.autograd.graph.save_on_cpu(): block,
x = torch.utils.checkpoint.checkpoint( use_gradient_checkpointing,
create_custom_forward(block), use_gradient_checkpointing_offload,
x, x, context, t_mod, seq_len_x, pre_compute_freqs[0]
context, )
t_mod, x = gradient_checkpoint_forward(
seq_len_x, lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x),
pre_compute_freqs[0], use_gradient_checkpointing,
use_reentrant=False, use_gradient_checkpointing_offload,
) x
x = torch.utils.checkpoint.checkpoint( )
create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
x,
use_reentrant=False,
)
elif use_gradient_checkpointing:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x,
context,
t_mod,
seq_len_x,
pre_compute_freqs[0],
use_reentrant=False,
)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
x,
use_reentrant=False,
)
else:
x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0])
x = self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)
x = x[:, :seq_len_x] x = x[:, :seq_len_x]
x = self.head(x, t[:-1]) x = self.head(x, t[:-1])

View File

@@ -1,6 +1,6 @@
import torch import torch
from .wan_video_dit import DiTBlock from .wan_video_dit import DiTBlock
from ..core.gradient import gradient_checkpoint_forward
class VaceWanAttentionBlock(DiTBlock): class VaceWanAttentionBlock(DiTBlock):
def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0): def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
@@ -62,26 +62,13 @@ class VaceWanModel(torch.nn.Module):
dim=1) for u in c dim=1) for u in c
]) ])
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for block in self.vace_blocks: for block in self.vace_blocks:
if use_gradient_checkpointing_offload: c = gradient_checkpoint_forward(
with torch.autograd.graph.save_on_cpu(): block,
c = torch.utils.checkpoint.checkpoint( use_gradient_checkpointing,
create_custom_forward(block), use_gradient_checkpointing_offload,
c, x, context, t_mod, freqs, c, x, context, t_mod, freqs
use_reentrant=False, )
)
elif use_gradient_checkpointing:
c = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
c, x, context, t_mod, freqs,
use_reentrant=False,
)
else:
c = block(c, x, context, t_mod, freqs)
hints = torch.unbind(c)[:-1] hints = torch.unbind(c)[:-1]
return hints return hints
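
The three removed branches in this hunk collapse into a single call. Below is a minimal sketch of what a helper with this call shape could look like, reconstructed only from the removed branches. The real `gradient_checkpoint_forward` in `..core.gradient` may differ, and the name `gradient_checkpoint_forward_sketch` is hypothetical.

```python
import torch
import torch.utils.checkpoint


def gradient_checkpoint_forward_sketch(
    model,
    use_gradient_checkpointing,
    use_gradient_checkpointing_offload,
    *args,
):
    """Hypothetical stand-in mirroring the three branches removed in this diff."""
    if use_gradient_checkpointing_offload:
        # Recompute activations in backward and keep saved tensors on the CPU.
        with torch.autograd.graph.save_on_cpu():
            return torch.utils.checkpoint.checkpoint(model, *args, use_reentrant=False)
    if use_gradient_checkpointing:
        # Recompute activations in backward; saved tensors stay on the device.
        return torch.utils.checkpoint.checkpoint(model, *args, use_reentrant=False)
    # Plain forward pass without checkpointing.
    return model(*args)


block = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)
y_plain = gradient_checkpoint_forward_sketch(block, False, False, x)
y_ckpt = gradient_checkpoint_forward_sketch(block, True, False, x)
y_ckpt.sum().backward()
```

With a helper of this shape, each call site passes the module, the two flags, and the positional inputs, which matches the new call sites shown in the right-hand column.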

View File

@@ -171,7 +171,7 @@ class Resample(nn.Module):
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x feat_cache[idx] = cache_x
feat_idx[0] += 1 feat_idx[0] += 1
return x return x, feat_cache, feat_idx
def init_weight(self, conv): def init_weight(self, conv):
conv_weight = conv.weight conv_weight = conv.weight
@@ -298,7 +298,7 @@ class ResidualBlock(nn.Module):
feat_idx[0] += 1 feat_idx[0] += 1
else: else:
x = layer(x) x = layer(x)
return x + h return x + h, feat_cache, feat_idx
class AttentionBlock(nn.Module): class AttentionBlock(nn.Module):
@@ -471,7 +471,7 @@ class Down_ResidualBlock(nn.Module):
for module in self.downsamples: for module in self.downsamples:
x = module(x, feat_cache, feat_idx) x = module(x, feat_cache, feat_idx)
return x + self.avg_shortcut(x_copy) return x + self.avg_shortcut(x_copy), feat_cache, feat_idx
class Up_ResidualBlock(nn.Module): class Up_ResidualBlock(nn.Module):
@@ -511,7 +511,7 @@ class Up_ResidualBlock(nn.Module):
x_shortcut = self.avg_shortcut(x, first_chunk) x_shortcut = self.avg_shortcut(x, first_chunk)
return x_main + x_shortcut return x_main + x_shortcut
else: else:
return x_main return x_main, feat_cache, feat_idx
class Encoder3d(nn.Module): class Encoder3d(nn.Module):
@@ -586,14 +586,14 @@ class Encoder3d(nn.Module):
## downsamples ## downsamples
for layer in self.downsamples: for layer in self.downsamples:
if feat_cache is not None: if feat_cache is not None:
x = layer(x, feat_cache, feat_idx) x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
else: else:
x = layer(x) x = layer(x)
## middle ## middle
for layer in self.middle: for layer in self.middle:
if check_is_instance(layer, ResidualBlock) and feat_cache is not None: if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx) x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
else: else:
x = layer(x) x = layer(x)
@@ -614,7 +614,7 @@ class Encoder3d(nn.Module):
feat_idx[0] += 1 feat_idx[0] += 1
else: else:
x = layer(x) x = layer(x)
return x return x, feat_cache, feat_idx
class Encoder3d_38(nn.Module): class Encoder3d_38(nn.Module):
@@ -698,14 +698,14 @@ class Encoder3d_38(nn.Module):
## downsamples ## downsamples
for layer in self.downsamples: for layer in self.downsamples:
if feat_cache is not None: if feat_cache is not None:
x = layer(x, feat_cache, feat_idx) x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
else: else:
x = layer(x) x = layer(x)
## middle ## middle
for layer in self.middle: for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None: if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx) x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
else: else:
x = layer(x) x = layer(x)
@@ -730,7 +730,7 @@ class Encoder3d_38(nn.Module):
else: else:
x = layer(x) x = layer(x)
return x return x, feat_cache, feat_idx
class Decoder3d(nn.Module): class Decoder3d(nn.Module):
@@ -807,14 +807,14 @@ class Decoder3d(nn.Module):
## middle ## middle
for layer in self.middle: for layer in self.middle:
if check_is_instance(layer, ResidualBlock) and feat_cache is not None: if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx) x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
else: else:
x = layer(x) x = layer(x)
## upsamples ## upsamples
for layer in self.upsamples: for layer in self.upsamples:
if feat_cache is not None: if feat_cache is not None:
x = layer(x, feat_cache, feat_idx) x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
else: else:
x = layer(x) x = layer(x)
@@ -835,7 +835,7 @@ class Decoder3d(nn.Module):
feat_idx[0] += 1 feat_idx[0] += 1
else: else:
x = layer(x) x = layer(x)
return x return x, feat_cache, feat_idx
@@ -906,14 +906,14 @@ class Decoder3d_38(nn.Module):
for layer in self.middle: for layer in self.middle:
if check_is_instance(layer, ResidualBlock) and feat_cache is not None: if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx) x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
else: else:
x = layer(x) x = layer(x)
## upsamples ## upsamples
for layer in self.upsamples: for layer in self.upsamples:
if feat_cache is not None: if feat_cache is not None:
x = layer(x, feat_cache, feat_idx, first_chunk) x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx, first_chunk)
else: else:
x = layer(x) x = layer(x)
@@ -937,7 +937,7 @@ class Decoder3d_38(nn.Module):
feat_idx[0] += 1 feat_idx[0] += 1
else: else:
x = layer(x) x = layer(x)
return x return x, feat_cache, feat_idx
def count_conv3d(model): def count_conv3d(model):
@@ -990,11 +990,11 @@ class VideoVAE_(nn.Module):
for i in range(iter_): for i in range(iter_):
self._enc_conv_idx = [0] self._enc_conv_idx = [0]
if i == 0: if i == 0:
out = self.encoder(x[:, :, :1, :, :], out, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, :1, :, :],
feat_cache=self._enc_feat_map, feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx) feat_idx=self._enc_conv_idx)
else: else:
out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], out_, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map, feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx) feat_idx=self._enc_conv_idx)
out = torch.cat([out, out_], 2) out = torch.cat([out, out_], 2)
@@ -1023,11 +1023,11 @@ class VideoVAE_(nn.Module):
for i in range(iter_): for i in range(iter_):
self._conv_idx = [0] self._conv_idx = [0]
if i == 0: if i == 0:
out = self.decoder(x[:, :, i:i + 1, :, :], out, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map, feat_cache=self._feat_map,
feat_idx=self._conv_idx) feat_idx=self._conv_idx)
else: else:
out_ = self.decoder(x[:, :, i:i + 1, :, :], out_, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map, feat_cache=self._feat_map,
feat_idx=self._conv_idx) feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2) # may add tensor offload out = torch.cat([out, out_], 2) # may add tensor offload
@@ -1303,11 +1303,11 @@ class VideoVAE38_(VideoVAE_):
for i in range(iter_): for i in range(iter_):
self._enc_conv_idx = [0] self._enc_conv_idx = [0]
if i == 0: if i == 0:
out = self.encoder(x[:, :, :1, :, :], out, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, :1, :, :],
feat_cache=self._enc_feat_map, feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx) feat_idx=self._enc_conv_idx)
else: else:
out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], out_, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map, feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx) feat_idx=self._enc_conv_idx)
out = torch.cat([out, out_], 2) out = torch.cat([out, out_], 2)
@@ -1337,12 +1337,12 @@ class VideoVAE38_(VideoVAE_):
for i in range(iter_): for i in range(iter_):
self._conv_idx = [0] self._conv_idx = [0]
if i == 0: if i == 0:
out = self.decoder(x[:, :, i:i + 1, :, :], out, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map, feat_cache=self._feat_map,
feat_idx=self._conv_idx, feat_idx=self._conv_idx,
first_chunk=True) first_chunk=True)
else: else:
out_ = self.decoder(x[:, :, i:i + 1, :, :], out_, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map, feat_cache=self._feat_map,
feat_idx=self._conv_idx) feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2) out = torch.cat([out, out_], 2)
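
Most hunks in this file make the same change: a layer that receives `feat_cache` and `feat_idx` now returns them alongside its output, and the caller rebinds all three. The following is a toy sketch of that calling convention in plain Python. `ToyLayer` and `run_layers` are hypothetical names, and integers stand in for tensors.

```python
class ToyLayer:
    """Hypothetical layer that caches its last input, like the cached layers above."""

    def __call__(self, x, feat_cache, feat_idx):
        idx = feat_idx[0]
        previous = feat_cache[idx]      # value cached by the previous chunk
        feat_cache[idx] = x             # remember this chunk for the next call
        feat_idx[0] += 1                # advance to the next cache slot
        out = x if previous is None else x + previous
        # Return the state explicitly instead of relying only on in-place mutation.
        return out, feat_cache, feat_idx


def run_layers(layers, x, feat_cache, feat_idx):
    for layer in layers:
        x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
    return x, feat_cache, feat_idx


layers = [ToyLayer(), ToyLayer()]
feat_cache = [None] * len(layers)
# First chunk: nothing cached yet. Second chunk: reuses what the first chunk stored.
out1, feat_cache, feat_idx = run_layers(layers, 1, feat_cache, [0])
out2, feat_cache, feat_idx = run_layers(layers, 10, feat_cache, [0])
```

The chunk loop resets the index to `[0]` before each chunk, as the `VideoVAE_` encode and decode loops in the diff do with `self._enc_conv_idx` and `self._conv_idx`.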

View File

@@ -88,6 +88,14 @@ class Attention(torch.nn.Module):
self.norm_q = RMSNorm(head_dim, eps=1e-5) self.norm_q = RMSNorm(head_dim, eps=1e-5)
self.norm_k = RMSNorm(head_dim, eps=1e-5) self.norm_k = RMSNorm(head_dim, eps=1e-5)
# Apply RoPE
def apply_rotary_emb(self, x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
with torch.amp.autocast(get_device_type(), enabled=False):
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
return x_out.type_as(x_in) # todo
def forward(self, hidden_states, freqs_cis, attention_mask): def forward(self, hidden_states, freqs_cis, attention_mask):
query = self.to_q(hidden_states) query = self.to_q(hidden_states)
key = self.to_k(hidden_states) key = self.to_k(hidden_states)
@@ -103,17 +111,9 @@ class Attention(torch.nn.Module):
if self.norm_k is not None: if self.norm_k is not None:
key = self.norm_k(key) key = self.norm_k(key)
# Apply RoPE
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
with torch.amp.autocast(get_device_type(), enabled=False):
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
return x_out.type_as(x_in) # todo
if freqs_cis is not None: if freqs_cis is not None:
query = apply_rotary_emb(query, freqs_cis) query = self.apply_rotary_emb(query, freqs_cis)
key = apply_rotary_emb(key, freqs_cis) key = self.apply_rotary_emb(key, freqs_cis)
# Cast to correct dtype # Cast to correct dtype
dtype = query.dtype dtype = query.dtype
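
The method moved in this hunk applies rotary position embeddings by viewing adjacent feature pairs as complex numbers and multiplying them by `freqs_cis`. Here is a scalar sketch of the same arithmetic in plain Python, assuming `freqs_cis` holds unit complex numbers of the form e^{iθ}. The function name and the angles are illustrative only.

```python
import cmath


def apply_rotary_emb_scalar(features, angles):
    """Rotate consecutive (even, odd) feature pairs by the given angles."""
    assert len(features) == 2 * len(angles)
    out = []
    for k, theta in enumerate(angles):
        # The pair (x[2k], x[2k+1]) is treated as the complex number x[2k] + i*x[2k+1].
        z = complex(features[2 * k], features[2 * k + 1])
        rotated = z * cmath.exp(1j * theta)  # plays the role of x * freqs_cis
        out.extend([rotated.real, rotated.imag])
    return out


# Rotating (1, 0) by 90 degrees and (0, 2) by 180 degrees.
rotated = apply_rotary_emb_scalar([1.0, 0.0, 0.0, 2.0], [cmath.pi / 2, cmath.pi])
```

The tensor version does the same per pair: `torch.view_as_complex` builds the complex view, the product applies the rotation, and `torch.view_as_real(...).flatten(3)` restores the interleaved real layout.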

View File

@@ -348,13 +348,12 @@ class Flux2Unit_Qwen3PromptEmbedder(PipelineUnit):
attention_mask = torch.cat(all_attention_masks, dim=0).to(device) attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
# Forward pass through the model # Forward pass through the model
with torch.inference_mode(): output = text_encoder(
output = text_encoder( input_ids=input_ids,
input_ids=input_ids, attention_mask=attention_mask,
attention_mask=attention_mask, output_hidden_states=True,
output_hidden_states=True, use_cache=False,
use_cache=False, )
)
# Only use outputs from intermediate layers and stack them # Only use outputs from intermediate layers and stack them
out = torch.stack([output.hidden_states[k] for k in self.hidden_states_layers], dim=1) out = torch.stack([output.hidden_states[k] for k in self.hidden_states_layers], dim=1)

View File

@@ -1321,11 +1321,6 @@ def model_fn_wan_video(
if tea_cache_update: if tea_cache_update:
x = tea_cache.update(x) x = tea_cache.update(x)
else: else:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
def create_custom_forward_vap(block, vap): def create_custom_forward_vap(block, vap):
def custom_forward(*inputs): def custom_forward(*inputs):
return vap(block, *inputs) return vap(block, *inputs)
@@ -1339,32 +1334,24 @@ def model_fn_wan_video(
x, x_vap = torch.utils.checkpoint.checkpoint( x, x_vap = torch.utils.checkpoint.checkpoint(
create_custom_forward_vap(block, vap), create_custom_forward_vap(block, vap),
x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
use_reentrant=False, use_reentrant=False
) )
elif use_gradient_checkpointing: elif use_gradient_checkpointing:
x, x_vap = torch.utils.checkpoint.checkpoint( x, x_vap = torch.utils.checkpoint.checkpoint(
create_custom_forward_vap(block, vap), create_custom_forward_vap(block, vap),
x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
use_reentrant=False, use_reentrant=False
) )
else: else:
x, x_vap = vap(block, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id) x, x_vap = vap(block, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id)
else: else:
if use_gradient_checkpointing_offload: x = gradient_checkpoint_forward(
with torch.autograd.graph.save_on_cpu(): block,
x = torch.utils.checkpoint.checkpoint( use_gradient_checkpointing,
create_custom_forward(block), use_gradient_checkpointing_offload,
x, context, t_mod, freqs, x, context, t_mod, freqs
use_reentrant=False, )
)
elif use_gradient_checkpointing:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else:
x = block(x, context, t_mod, freqs)
# VACE # VACE
if vace_context is not None and block_id in vace.vace_layers_mapping: if vace_context is not None and block_id in vace.vace_layers_mapping:
@@ -1487,32 +1474,18 @@ def model_fn_wans2v(
return custom_forward return custom_forward
for block_id, block in enumerate(dit.blocks): for block_id, block in enumerate(dit.blocks):
if use_gradient_checkpointing_offload: x = gradient_checkpoint_forward(
with torch.autograd.graph.save_on_cpu(): block,
x = torch.utils.checkpoint.checkpoint( use_gradient_checkpointing,
create_custom_forward(block), use_gradient_checkpointing_offload,
x, context, t_mod, seq_len_x, pre_compute_freqs[0], x, context, t_mod, seq_len_x, pre_compute_freqs[0]
use_reentrant=False,
)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
x,
use_reentrant=False,
)
elif use_gradient_checkpointing:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, seq_len_x, pre_compute_freqs[0],
use_reentrant=False,
) )
x = torch.utils.checkpoint.checkpoint( x = gradient_checkpoint_forward(
create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x),
x, use_gradient_checkpointing,
use_reentrant=False, use_gradient_checkpointing_offload,
) x
else: )
x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0])
x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x_global, use_unified_sequence_parallel)
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
x = get_sp_group().all_gather(x, dim=1) x = get_sp_group().all_gather(x, dim=1)

View File

@@ -1,4 +1,4 @@
import torch, math import torch, math, warnings
from PIL import Image from PIL import Image
from typing import Union from typing import Union
from tqdm import tqdm from tqdm import tqdm
@@ -6,7 +6,7 @@ from einops import rearrange
import numpy as np import numpy as np
from typing import Union, List, Optional, Tuple, Iterable, Dict from typing import Union, List, Optional, Tuple, Iterable, Dict
from ..core.device.npu_compatible_device import get_device_type from ..core.device.npu_compatible_device import get_device_type, IS_NPU_AVAILABLE
from ..diffusion import FlowMatchScheduler from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward from ..core import ModelConfig, gradient_checkpoint_forward
from ..core.data.operators import ImageCropAndResize from ..core.data.operators import ImageCropAndResize
@@ -63,6 +63,7 @@ class ZImagePipeline(BasePipeline):
model_configs: list[ModelConfig] = [], model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
vram_limit: float = None, vram_limit: float = None,
enable_npu_patch: bool = True,
): ):
# Initialize pipeline # Initialize pipeline
pipe = ZImagePipeline(device=device, torch_dtype=torch_dtype) pipe = ZImagePipeline(device=device, torch_dtype=torch_dtype)
@@ -84,6 +85,8 @@ class ZImagePipeline(BasePipeline):
# VRAM Management # VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state() pipe.vram_management_enabled = pipe.check_vram_management_state()
# NPU patch
apply_npu_patch(enable_npu_patch)
return pipe return pipe
@@ -667,3 +670,19 @@ def model_fn_z_image_turbo(
x = rearrange(x, "C B H W -> B C H W") x = rearrange(x, "C B H W -> B C H W")
x = -x x = -x
return x return x
def apply_npu_patch(enable_npu_patch: bool=True):
if IS_NPU_AVAILABLE and enable_npu_patch:
from ..models.general_modules import RMSNorm
from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm
from ..models.z_image_dit import Attention
from ..core.npu_patch.npu_fused_operator import (
rms_norm_forward_npu,
rms_norm_forward_transformers_npu,
rotary_emb_Zimage_npu
)
warnings.warn("Replacing RMSNorm and RoPE with NPU fusion operators to improve the performance of the model on NPU. Set enable_npu_patch=False to disable this feature.")
RMSNorm.forward = rms_norm_forward_npu
Qwen3RMSNorm.forward = rms_norm_forward_transformers_npu
Attention.apply_rotary_emb = rotary_emb_Zimage_npu
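
Assigning to `Attention.apply_rotary_emb` depends on the earlier hunk that turned `apply_rotary_emb` from a function nested inside `forward` into a method, since a nested function cannot be replaced from outside. Below is a plain-Python sketch of this class-level patching pattern. The names `ToyAttention`, `fused_rotary`, and `apply_patch` are hypothetical, and a boolean argument stands in for `IS_NPU_AVAILABLE`.

```python
import warnings


class ToyAttention:
    def apply_rotary_emb(self, x):
        return ("reference", x)

    def forward(self, x):
        # Looked up on the class at call time, so a patched method is picked up.
        return self.apply_rotary_emb(x)


def fused_rotary(self, x):
    """Hypothetical replacement standing in for a fused operator."""
    return ("fused", x)


def apply_patch(accelerator_available, enable_patch=True):
    if accelerator_available and enable_patch:
        warnings.warn("Replacing apply_rotary_emb with a fused implementation.")
        ToyAttention.apply_rotary_emb = fused_rotary


attn = ToyAttention()
before = attn.forward(1)
apply_patch(accelerator_available=True, enable_patch=False)  # opt-out: no change
unpatched = attn.forward(1)
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    apply_patch(accelerator_available=True, enable_patch=True)
after = attn.forward(1)
```

Because the assignment targets the class, it is global: instances created before the patch, such as `attn` here, also switch to the replacement.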

View File

@@ -9,6 +9,7 @@ from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from ... import IS_NPU_AVAILABLE from ... import IS_NPU_AVAILABLE
from ...core.device import parse_nccl_backend, parse_device_type from ...core.device import parse_nccl_backend, parse_device_type
from ...core.gradient import gradient_checkpoint_forward
def initialize_usp(device_type): def initialize_usp(device_type):
@@ -87,11 +88,6 @@ def usp_dit_forward(self,
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device) ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# Context Parallel # Context Parallel
chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1) chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
@@ -100,20 +96,13 @@ def usp_dit_forward(self,
x = chunks[get_sequence_parallel_rank()] x = chunks[get_sequence_parallel_rank()]
for block in self.blocks: for block in self.blocks:
if self.training and use_gradient_checkpointing: if self.training:
if use_gradient_checkpointing_offload: x = gradient_checkpoint_forward(
with torch.autograd.graph.save_on_cpu(): block,
x = torch.utils.checkpoint.checkpoint( use_gradient_checkpointing,
create_custom_forward(block), use_gradient_checkpointing_offload,
x, context, t_mod, freqs, x, context, t_mod, freqs
use_reentrant=False, )
)
else:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else: else:
x = block(x, context, t_mod, freqs) x = block(x, context, t_mod, freqs)

View File

@@ -107,6 +107,11 @@ Special Training Scripts:
* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md), [code](/examples/qwen_image/model_training/special/split_training/) * Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md), [code](/examples/qwen_image/model_training/special/split_training/)
* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md), [code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh) * End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md), [code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)
DeepSpeed ZeRO Stage 3 Training: The Qwen-Image series models support DeepSpeed ZeRO Stage 3 training, which partitions the model across multiple GPUs. Taking full parameter training of the Qwen-Image model as an example, the following modifications are required:
* `--config_file examples/qwen_image/model_training/full/accelerate_config_zero3.yaml`
* `--initialize_model_on_cpu`
## Model Inference ## Model Inference
Models are loaded via `QwenImagePipeline.from_pretrained`, see [Loading Models](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models). Models are loaded via `QwenImagePipeline.from_pretrained`, see [Loading Models](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models).

View File

@@ -142,6 +142,11 @@ graph LR;
* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md), [code](/examples/wanvideo/model_training/special/split_training/) * Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md), [code](/examples/wanvideo/model_training/special/split_training/)
* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md), [code](/examples/wanvideo/model_training/special/direct_distill/) * End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md), [code](/examples/wanvideo/model_training/special/direct_distill/)
DeepSpeed ZeRO Stage 3 Training: The Wan series models support DeepSpeed ZeRO Stage 3 training, which partitions the model across multiple GPUs. Taking full parameter training of the Wan2.1-T2V-14B model as an example, the following modifications are required:
* `--config_file examples/wanvideo/model_training/full/accelerate_config_zero3.yaml`
* `--initialize_model_on_cpu`
## Model Inference ## Model Inference
Models are loaded via `WanVideoPipeline.from_pretrained`, see [Loading Models](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models). Models are loaded via `WanVideoPipeline.from_pretrained`, see [Loading Models](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models).

View File

@@ -89,4 +89,5 @@ Set 0 or not set: indicates not enabling the binding function
#### Parameters for specific models #### Parameters for specific models
| Model | Parameter | Note | | Model | Parameter | Note |
|----------------|---------------------------|-------------------| |----------------|---------------------------|-------------------|
| Wan 14B series | --initialize_model_on_cpu | The 14B model needs to be initialized on the CPU | | Wan 14B series | --initialize_model_on_cpu | The 14B model needs to be initialized on the CPU |
| Qwen-Image series | --initialize_model_on_cpu | The model needs to be initialized on the CPU |

View File

@@ -102,4 +102,65 @@ image.save("image.jpg")
Each model `Pipeline` has different input parameters. Please refer to the documentation for each model. Each model `Pipeline` has different input parameters. Please refer to the documentation for each model.
If the model parameters are too large, causing insufficient VRAM, please enable [VRAM management](/docs/en/Pipeline_Usage/VRAM_management.md). If the model parameters are too large, causing insufficient VRAM, please enable [VRAM management](/docs/en/Pipeline_Usage/VRAM_management.md).
## Loading LoRA
LoRA is a lightweight model training method that produces a small number of parameters to extend model capabilities. DiffSynth-Studio supports two ways to load LoRA: cold loading and hot loading.
* Cold loading: When the base model does not have [VRAM management](/docs/en/Pipeline_Usage/VRAM_management.md) enabled, LoRA will be fused into the base model weights. In this case, inference speed remains unchanged, but LoRA cannot be unloaded after loading.
```python
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
lora = ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-LoRA-ArtAug-v1", origin_file_pattern="model.safetensors")
pipe.load_lora(pipe.dit, lora, alpha=1)
prompt = "Exquisite portrait, underwater girl, blue dress flowing, hair floating, translucent light, bubbles surrounding, peaceful face, intricate details, dreamy and ethereal."
image = pipe(prompt, seed=0, num_inference_steps=40)
image.save("image.jpg")
```
* Hot loading: When the base model has [VRAM management](/docs/en/Pipeline_Usage/VRAM_management.md) enabled, LoRA will not be fused into the base model weights. In this case, inference speed will be slower, but LoRA can be unloaded through `pipe.clear_lora()` after loading.
```python
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cuda",
"onload_dtype": torch.bfloat16,
"onload_device": "cuda",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
lora = ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-LoRA-ArtAug-v1", origin_file_pattern="model.safetensors")
pipe.load_lora(pipe.dit, lora, alpha=1)
prompt = "Exquisite portrait, underwater girl, blue dress flowing, hair floating, translucent light, bubbles surrounding, peaceful face, intricate details, dreamy and ethereal."
image = pipe(prompt, seed=0, num_inference_steps=40)
image.save("image.jpg")
pipe.clear_lora()
```
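
The difference between the two modes can be illustrated with the usual LoRA formulation, in which the update to a weight matrix `W` is `alpha * B @ A`. This is a conceptual sketch in plain Python with tiny hand-written matrices, not DiffSynth-Studio's implementation. All helper names are hypothetical.

```python
def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def add(a, b, scale=1.0):
    return [[x + scale * y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


W = [[1.0, 0.0], [0.0, 1.0]]  # base weight (2x2)
B = [[1.0], [2.0]]            # LoRA up-projection (2x1)
A = [[0.5, 0.5]]              # LoRA down-projection (1x2)
alpha = 1.0
x = [[2.0], [4.0]]            # input column vector

# Cold loading: fold the update into the weight once; only the fused weight is used.
W_fused = add(W, matmul(B, A), alpha)
y_cold = matmul(W_fused, x)

# Hot loading: keep W untouched and add the LoRA branch on every forward pass.
y_hot = add(matmul(W, x), matmul(B, matmul(A, x)), alpha)
```

Both paths give the same output. The fused path adds no work per forward pass but loses the original weight, while the unfused path pays for extra matrix products on each call and can drop the LoRA branch later, which is consistent with the trade-off described above.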

View File

@@ -25,4 +25,11 @@ Even with suitable hardware conditions, we currently have no plans to support na
* The main challenge of native FP8 precision training is precision overflow caused by gradient explosion. To ensure training stability, the model structure needs to be redesigned accordingly. However, no model developers are willing to do so at present. * The main challenge of native FP8 precision training is precision overflow caused by gradient explosion. To ensure training stability, the model structure needs to be redesigned accordingly. However, no model developers are willing to do so at present.
* Additionally, models trained with native FP8 precision can only be computed with BF16 precision during inference without Hopper architecture GPUs, theoretically resulting in generation quality inferior to FP8. * Additionally, models trained with native FP8 precision can only be computed with BF16 precision during inference without Hopper architecture GPUs, theoretically resulting in generation quality inferior to FP8.
Therefore, native FP8 precision training technology is extremely immature. We will observe the technological developments in the open-source community. Therefore, native FP8 precision training technology is extremely immature. We will observe the technological developments in the open-source community.
## How to dynamically load LoRA models during inference?
We support two loading methods for LoRA models. See [LoRA Loading](/docs/en/Pipeline_Usage/Model_Inference.md#loading-lora) for details:
* Cold Loading: When [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md) is not enabled for the base model, LoRA will be fused into the base model weights. In this case, inference speed remains unchanged, and LoRA cannot be unloaded after loading.
* Hot Loading: When [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md) is enabled for the base model, LoRA will not be fused into the base model weights. In this case, inference speed will slow down, and LoRA can be unloaded after loading via `pipe.clear_lora()`.

View File

@@ -107,6 +107,11 @@ graph LR;
* Two-stage Split Training: [doc](/docs/zh/Training/Split_Training.md), [code](/examples/qwen_image/model_training/special/split_training/) * Two-stage Split Training: [doc](/docs/zh/Training/Split_Training.md), [code](/examples/qwen_image/model_training/special/split_training/)
* End-to-end Direct Distillation: [doc](/docs/zh/Training/Direct_Distill.md), [code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh) * End-to-end Direct Distillation: [doc](/docs/zh/Training/Direct_Distill.md), [code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)
DeepSpeed ZeRO 3 Training: The Qwen-Image series models support DeepSpeed ZeRO 3 training, which partitions the model across multiple GPUs. Taking full-parameter training of the Qwen-Image model as an example, the following modifications are required:
* `--config_file examples/qwen_image/model_training/full/accelerate_config_zero3.yaml`
* `--initialize_model_on_cpu`
## Model Inference ## Model Inference
Models are loaded via `QwenImagePipeline.from_pretrained`; see [Loading Models](/docs/zh/Pipeline_Usage/Model_Inference.md#加载模型). Models are loaded via `QwenImagePipeline.from_pretrained`; see [Loading Models](/docs/zh/Pipeline_Usage/Model_Inference.md#加载模型).

View File

@@ -143,6 +143,11 @@ graph LR;
* Two-stage Split Training: [doc](/docs/zh/Training/Split_Training.md), [code](/examples/wanvideo/model_training/special/split_training/) * Two-stage Split Training: [doc](/docs/zh/Training/Split_Training.md), [code](/examples/wanvideo/model_training/special/split_training/)
* End-to-end Direct Distillation: [doc](/docs/zh/Training/Direct_Distill.md), [code](/examples/wanvideo/model_training/special/direct_distill/) * End-to-end Direct Distillation: [doc](/docs/zh/Training/Direct_Distill.md), [code](/examples/wanvideo/model_training/special/direct_distill/)
DeepSpeed ZeRO 3 Training: The Wan series models support DeepSpeed ZeRO 3 training, which partitions the model across multiple GPUs. Taking full-parameter training of the Wan2.1-T2V-14B model as an example, the following modifications are required:
* `--config_file examples/wanvideo/model_training/full/accelerate_config_zero3.yaml`
* `--initialize_model_on_cpu`
## Model Inference ## Model Inference
Models are loaded via `WanVideoPipeline.from_pretrained`; see [Loading Models](/docs/zh/Pipeline_Usage/Model_Inference.md#加载模型). Models are loaded via `WanVideoPipeline.from_pretrained`; see [Loading Models](/docs/zh/Pipeline_Usage/Model_Inference.md#加载模型).

View File

@@ -88,4 +88,5 @@ export CPU_AFFINITY_CONF=1
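#### Parameters required for specific models #### Parameters required for specific models
| Model | Parameter | Note | | Model | Parameter | Note |
|-----------|------|-------------------| |-----------|------|-------------------|
| Wan 14B series | --initialize_model_on_cpu | The 14B model needs to be initialized on the CPU | | Wan 14B series | --initialize_model_on_cpu | The 14B model needs to be initialized on the CPU |
| Qwen-Image series | --initialize_model_on_cpu | The model needs to be initialized on the CPU |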

View File

@@ -103,3 +103,64 @@ image.save("image.jpg")
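Each model `Pipeline` takes different input parameters; please refer to the documentation for each model. Each model `Pipeline` takes different input parameters; please refer to the documentation for each model.
If the model has too many parameters and VRAM is insufficient, enable [VRAM management](/docs/zh/Pipeline_Usage/VRAM_management.md). If the model has too many parameters and VRAM is insufficient, enable [VRAM management](/docs/zh/Pipeline_Usage/VRAM_management.md).
## Loading LoRA
LoRA is a lightweight model training method that produces a small number of parameters to extend a model's capabilities. DiffSynth-Studio supports two ways to load LoRA: cold loading and hot loading.
* Cold loading: When the base model does not have [VRAM management](/docs/zh/Pipeline_Usage/VRAM_management.md) enabled, LoRA is fused into the base model weights. Inference speed is unchanged, but LoRA cannot be unloaded after loading.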
```python
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
lora = ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-LoRA-ArtAug-v1", origin_file_pattern="model.safetensors")
pipe.load_lora(pipe.dit, lora, alpha=1)
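prompt = "Exquisite portrait, underwater girl, blue dress flowing, hair floating, translucent light, bubbles surrounding, peaceful face, intricate details, dreamy and ethereal."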
image = pipe(prompt, seed=0, num_inference_steps=40)
image.save("image.jpg")
```
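* Hot loading: When the base model has [VRAM management](/docs/zh/Pipeline_Usage/VRAM_management.md) enabled, LoRA is not fused into the base model weights. Inference is slower, but LoRA can be unloaded via `pipe.clear_lora()` after loading.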
```python
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch

vram_config = {
    "offload_dtype": torch.bfloat16,
    "offload_device": "cuda",
    "onload_dtype": torch.bfloat16,
    "onload_device": "cuda",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
lora = ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-LoRA-ArtAug-v1", origin_file_pattern="model.safetensors")
pipe.load_lora(pipe.dit, lora, alpha=1)
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(prompt, seed=0, num_inference_steps=40)
image.save("image.jpg")
pipe.clear_lora()
```
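
The difference between fused (cold) and unfused (hot) LoRA can be sketched on a toy linear layer. This is a conceptual illustration only: `ToyLinear`, `fuse_lora`, `attach_lora` and the list-based matrix helpers are hypothetical names written for this example and are not DiffSynth-Studio internals. The sketch assumes the usual LoRA formulation, in which the update to a weight `W` is `alpha * B @ A`.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def add(a, b, scale=1.0):
    """Return a + scale * b element-wise."""
    return [[x + scale * y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


class ToyLinear:
    """Toy layer computing y = W @ x, with optional LoRA updates."""

    def __init__(self, weight):
        self.weight = weight  # base weight, shape (out, in)
        self.loras = []       # unfused (lora_b, lora_a, alpha) triples

    def fuse_lora(self, lora_b, lora_a, alpha=1.0):
        """Cold style: overwrite W with W + alpha * B @ A; the old W is not kept."""
        self.weight = add(self.weight, matmul(lora_b, lora_a), alpha)

    def attach_lora(self, lora_b, lora_a, alpha=1.0):
        """Hot style: keep B and A next to W and apply them on every forward pass."""
        self.loras.append((lora_b, lora_a, alpha))

    def clear_lora(self):
        """Drop the unfused LoRA triples; fused updates cannot be undone this way."""
        self.loras.clear()

    def forward(self, x):
        """x is a column vector of shape (in, 1)."""
        y = matmul(self.weight, x)
        for lora_b, lora_a, alpha in self.loras:
            # Two extra small matmuls per layer: the source of the slowdown.
            y = add(y, matmul(lora_b, matmul(lora_a, x)), alpha)
        return y


identity = [[1, 0], [0, 1]]
lora_b, lora_a = [[1], [2]], [[3, 4]]  # rank-1 update
x = [[1], [1]]

cold = ToyLinear(identity)
cold.fuse_lora(lora_b, lora_a)

hot = ToyLinear(identity)
hot.attach_lora(lora_b, lora_a)

same_output = cold.forward(x) == hot.forward(x)
hot.clear_lora()
restored = hot.forward(x) == matmul(identity, x)
```

Both variants give the same output while the LoRA is active. The fused layer pays no extra cost per forward pass but has lost its original weights, whereas the unfused layer does extra work on every call and can return to the base model by clearing the list.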


@@ -26,3 +26,10 @@
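* In addition, a model trained natively in FP8 precision can only be run in BF16 at inference time if no Hopper-architecture GPU is available, and in theory its generation quality is then worse than with FP8.
Native FP8 training is therefore still very immature, and we are watching how the open-source community's techniques develop.
## How do I dynamically load LoRA models during inference?
We support two ways of loading LoRA models; see [Loading LoRA](/docs/zh/Pipeline_Usage/Model_Inference.md#加载-lora) for details.
* Cold loading: when [VRAM management](/docs/zh/Pipeline_Usage/VRAM_management.md) is not enabled for the base model, the LoRA is fused into the base model weights. Inference speed is unchanged, and the LoRA cannot be unloaded after loading.
* Hot loading: when [VRAM management](/docs/zh/Pipeline_Usage/VRAM_management.md) is enabled for the base model, the LoRA is not fused into the base model weights. Inference becomes slower, and the LoRA can be unloaded after loading with `pipe.clear_lora()`.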


@@ -0,0 +1,23 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 1
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false


@@ -0,0 +1,23 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 1
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false


@@ -0,0 +1,36 @@
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export CPU_AFFINITY_CONF=1
accelerate launch examples/flux2/model_training/train.py \
  --dataset_base_path data/example_image_dataset \
  --dataset_metadata_path data/example_image_dataset/metadata.csv \
  --max_pixels 1048576 \
  --dataset_repeat 1 \
  --model_id_with_origin_paths "black-forest-labs/FLUX.2-dev:text_encoder/*.safetensors,black-forest-labs/FLUX.2-dev:vae/diffusion_pytorch_model.safetensors" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/FLUX.2-dev-LoRA-splited-cache" \
  --lora_base_model "dit" \
  --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_qkv_mlp_proj,to_out.0,to_add_out,linear_in,linear_out,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out,single_transformer_blocks.24.attn.to_out,single_transformer_blocks.25.attn.to_out,single_transformer_blocks.26.attn.to_out,single_transformer_blocks.27.attn.to_out,single_transformer_blocks.28.attn.to_out,single_transformer_blocks.29.attn.to_out,single_transformer_blocks.30.attn.to_out,single_transformer_blocks.31.attn.to_out,single_transformer_blocks.32.attn.to_out,single_transformer_blocks.33.attn.to_out,single_transformer_blocks.34.attn.to_out,single_transformer_blocks.35.attn.to_out,single_transformer_blocks.36.attn.to_out,single_transformer_blocks.37.attn.to_out,single_transformer_blocks.38.attn.to_out,single_transformer_blocks.39.attn.to_out,single_transformer_blocks.40.attn.to_out,single_transformer_blocks.41.attn.to_out,single_transformer_blocks.42.attn.to_out,single_transformer_blocks.43.attn.to_out,single_transformer_blocks.44.attn.to_out,single_transformer_blocks.45.attn.to_out,single_transformer_blocks.46.attn.to_out,single_transformer_blocks.47.attn.to_out" \
  --lora_rank 32 \
  --use_gradient_checkpointing \
  --dataset_num_workers 8 \
  --task "sft:data_process"
accelerate launch --config_file examples/flux2/model_training/full/accelerate_config_zero3.yaml examples/flux2/model_training/train.py \
  --dataset_base_path "./models/train/FLUX.2-dev-LoRA-splited-cache" \
  --max_pixels 1048576 \
  --dataset_repeat 50 \
  --model_id_with_origin_paths "black-forest-labs/FLUX.2-dev:transformer/*.safetensors" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/FLUX.2-dev-LoRA-splited" \
  --lora_base_model "dit" \
  --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_qkv_mlp_proj,to_out.0,to_add_out,linear_in,linear_out,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out,single_transformer_blocks.24.attn.to_out,single_transformer_blocks.25.attn.to_out,single_transformer_blocks.26.attn.to_out,single_transformer_blocks.27.attn.to_out,single_transformer_blocks.28.attn.to_out,single_transformer_blocks.29.attn.to_out,single_transformer_blocks.30.attn.to_out,single_transformer_blocks.31.attn.to_out,single_transformer_blocks.32.attn.to_out,single_transformer_blocks.33.attn.to_out,single_transformer_blocks.34.attn.to_out,single_transformer_blocks.35.attn.to_out,single_transformer_blocks.36.attn.to_out,single_transformer_blocks.37.attn.to_out,single_transformer_blocks.38.attn.to_out,single_transformer_blocks.39.attn.to_out,single_transformer_blocks.40.attn.to_out,single_transformer_blocks.41.attn.to_out,single_transformer_blocks.42.attn.to_out,single_transformer_blocks.43.attn.to_out,single_transformer_blocks.44.attn.to_out,single_transformer_blocks.45.attn.to_out,single_transformer_blocks.46.attn.to_out,single_transformer_blocks.47.attn.to_out" \
  --lora_rank 32 \
  --use_gradient_checkpointing \
  --dataset_num_workers 8 \
  --initialize_model_on_cpu \
  --task "sft:train"


@@ -0,0 +1,34 @@
# This script is tested on 8*910B(NPU)
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export CPU_AFFINITY_CONF=1
accelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \
  --dataset_base_path data/example_image_dataset \
  --dataset_metadata_path data/example_image_dataset/metadata.csv \
  --max_pixels 1048576 \
  --dataset_repeat 50 \
  --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
  --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
  --learning_rate 1e-5 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/FLUX.2-klein-9B_full" \
  --trainable_models "dit" \
  --use_gradient_checkpointing
# Edit
# accelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \
# --dataset_base_path data/example_image_dataset \
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
# --data_file_keys "image,edit_image" \
# --extra_inputs "edit_image" \
# --max_pixels 1048576 \
# --dataset_repeat 50 \
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
# --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
# --learning_rate 1e-5 \
# --num_epochs 2 \
# --remove_prefix_in_ckpt "pipe.dit." \
# --output_path "./models/train/FLUX.2-klein-9B_full" \
# --trainable_models "dit" \
# --use_gradient_checkpointing


@@ -85,6 +85,7 @@ def flux2_parser():
     parser = add_general_config(parser)
     parser = add_image_size_config(parser)
     parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
+    parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.")
     return parser
@@ -126,7 +127,7 @@ if __name__ == "__main__":
         fp8_models=args.fp8_models,
         offload_models=args.offload_models,
         task=args.task,
-        device=accelerator.device,
+        device="cpu" if args.initialize_model_on_cpu else accelerator.device,
     )
     model_logger = ModelLogger(
         args.output_path,


@@ -0,0 +1,47 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from modelscope import snapshot_download
from PIL import Image
import torch
# Load models
pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image-Edit-2511", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
)
lora = ModelConfig(
    model_id="DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA",
    origin_file_pattern="model.safetensors"
)
pipe.load_lora(pipe.dit, lora)
# Load images
snapshot_download(
    "DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA",
    local_dir="./data",
    allow_file_pattern="assets/*"
)
edit_image = [
    Image.open("data/assets/image1_original.png"),
    Image.open("data/assets/image1_edit_1.png"),
    Image.open("data/assets/image2_original.png")
]
prompt = "Edit image 3 based on the transformation from image 1 to image 2."
negative_prompt = "泛黄AI感不真实丑陋油腻的皮肤异常的肢体不协调的肢体"
# Generate
image_4 = pipe(
    prompt=prompt, negative_prompt=negative_prompt,
    edit_image=edit_image,
    seed=1,
    num_inference_steps=50,
    height=1280,
    width=720,
    zero_cond_t=True,
)
image_4.save("image.png")


@@ -0,0 +1,58 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from modelscope import snapshot_download
from PIL import Image
import torch
# Load models
vram_config = {
    "offload_dtype": "disk",
    "offload_device": "disk",
    "onload_dtype": torch.float8_e4m3fn,
    "onload_device": "cpu",
    "preparing_dtype": torch.float8_e4m3fn,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image-Edit-2511", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
    ],
    processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
)
lora = ModelConfig(
    model_id="DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA",
    origin_file_pattern="model.safetensors"
)
pipe.load_lora(pipe.dit, lora)
# Load images
snapshot_download(
    "DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA",
    local_dir="./data",
    allow_file_pattern="assets/*"
)
edit_image = [
    Image.open("data/assets/image1_original.png"),
    Image.open("data/assets/image1_edit_1.png"),
    Image.open("data/assets/image2_original.png")
]
prompt = "Edit image 3 based on the transformation from image 1 to image 2."
negative_prompt = "泛黄AI感不真实丑陋油腻的皮肤异常的肢体不协调的肢体"
# Generate
image_4 = pipe(
    prompt=prompt, negative_prompt=negative_prompt,
    edit_image=edit_image,
    seed=1,
    num_inference_steps=50,
    height=1280,
    width=720,
    zero_cond_t=True,
)
image_4.save("image.png")


@@ -0,0 +1,23 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 1
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false


@@ -0,0 +1,20 @@
# This script was tested using zero3 and on 8*910B(NPU)
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export CPU_AFFINITY_CONF=1
accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero3.yaml examples/qwen_image/model_training/train.py \
  --dataset_base_path data/example_image_dataset \
  --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
  --data_file_keys "image,edit_image" \
  --extra_inputs "edit_image" \
  --max_pixels 1048576 \
  --dataset_repeat 50 \
  --model_id_with_origin_paths "Qwen/Qwen-Image-Edit-2509:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
  --learning_rate 1e-5 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Qwen-Image-Edit-2509_full" \
  --trainable_models "dit" \
  --use_gradient_checkpointing \
  --find_unused_parameters \
  --initialize_model_on_cpu


@@ -101,6 +101,7 @@ def qwen_image_parser():
     parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
     parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.")
     parser.add_argument("--zero_cond_t", default=False, action="store_true", help="A special parameter introduced by Qwen-Image-Edit-2511. Please enable it for this model.")
+    parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.")
     return parser
@@ -151,7 +152,7 @@ if __name__ == "__main__":
         fp8_models=args.fp8_models,
         offload_models=args.offload_models,
         task=args.task,
-        device=accelerator.device,
+        device="cpu" if args.initialize_model_on_cpu else accelerator.device,
         zero_cond_t=args.zero_cond_t,
     )
     model_logger = ModelLogger(


@@ -7,6 +7,7 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
   --num_frames 81 \
   --dataset_repeat 100 \
   --model_id_with_origin_paths "Wan-AI/Wan2.2-S2V-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/model.safetensors,Wan-AI/Wan2.2-S2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-S2V-14B:Wan2.1_VAE.pth" \
+  --audio_processor_path "Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/" \
   --learning_rate 1e-5 \
   --num_epochs 1 \
   --trainable_models "dit" \


@@ -0,0 +1,23 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 1
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false


@@ -7,6 +7,7 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
   --num_frames 81 \
   --dataset_repeat 100 \
   --model_id_with_origin_paths "Wan-AI/Wan2.2-S2V-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/model.safetensors,Wan-AI/Wan2.2-S2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-S2V-14B:Wan2.1_VAE.pth" \
+  --audio_processor_path "Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/" \
   --learning_rate 1e-4 \
   --num_epochs 5 \
   --remove_prefix_in_ckpt "pipe.dit." \


@@ -33,7 +33,7 @@ class WanTrainingModule(DiffusionTrainingModule):
         # Load models
         model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
         tokenizer_config = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/") if tokenizer_path is None else ModelConfig(tokenizer_path)
-        audio_processor_config = ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/") if audio_processor_path is None else ModelConfig(audio_processor_path)
+        audio_processor_config = self.parse_path_or_model_id(audio_processor_path)
         self.pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config, audio_processor_config=audio_processor_config)
         self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
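
The body of `parse_path_or_model_id` is not shown in this diff. As a rough sketch of how a `model_id:origin_file_pattern` value such as the one passed to `--audio_processor_path` could be separated from a plain local path, the hypothetical helper `split_model_spec` below splits at the first colon. It is an assumption made for illustration and not the repository's implementation.

```python
from typing import Optional, Tuple


def split_model_spec(spec: str) -> Tuple[Optional[str], str]:
    """Hypothetical helper, not part of DiffSynth-Studio.

    Split "model_id:origin_file_pattern" at the first colon. A value
    without a colon is treated as a local path and returned with the
    model ID set to None.
    """
    if ":" in spec:
        model_id, pattern = spec.split(":", 1)
        return model_id, pattern
    return None, spec


remote = split_model_spec("Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/")
local = split_model_spec("./models/wav2vec2-large-xlsr-53-english")
```

Here `remote` holds the model ID and file pattern as separate values, while `local` keeps the path intact with `None` as the model ID.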


@@ -0,0 +1,23 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 1
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false


@@ -13,4 +13,5 @@ accelerate launch --config_file examples/z_image/model_training/full/accelerate_
   --output_path "./models/train/Z-Image-Turbo_full" \
   --trainable_models "dit" \
   --use_gradient_checkpointing \
-  --dataset_num_workers 8
+  --dataset_num_workers 8 \
+  --enable_npu_patch


@@ -20,12 +20,13 @@ class ZImageTrainingModule(DiffusionTrainingModule):
         offload_models=None,
         device="cpu",
         task="sft",
+        enable_npu_patch=True,
     ):
         super().__init__()
         # Load models
         model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
         tokenizer_config = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
-        self.pipe = ZImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
+        self.pipe = ZImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config, enable_npu_patch=enable_npu_patch)
         self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
         # Training mode
@@ -94,6 +95,7 @@ def z_image_parser():
     parser = add_general_config(parser)
     parser = add_image_size_config(parser)
     parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
+    parser.add_argument("--enable_npu_patch", default=False, action="store_true", help="Whether to use npu fused operator patch to improve performance in NPU.")
     return parser
@@ -136,6 +138,7 @@ if __name__ == "__main__":
         offload_models=args.offload_models,
         task=args.task,
         device=accelerator.device,
+        enable_npu_patch=args.enable_npu_patch
     )
     model_logger = ModelLogger(
         args.output_path,