mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
Merge branch 'main' into wanvideo_seq_usp
This commit is contained in:
BIN
.github/workflows/logo.gif
vendored
Normal file
BIN
.github/workflows/logo.gif
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 146 KiB |
490
README.md
490
README.md
@@ -1,52 +1,343 @@
|
||||
# DiffSynth Studio
|
||||
# DiffSynth-Studio
|
||||
|
||||
<a href="https://github.com/modelscope/DiffSynth-Studio"><img src=".github/workflows/logo.gif" title="Logo" style="max-width:100%;" width="55" /></a> <a href="https://trendshift.io/repositories/10946" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10946" alt="modelscope%2FDiffSynth-Studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a></p>
|
||||
|
||||
[](https://pypi.org/project/DiffSynth/)
|
||||
[](https://github.com/modelscope/DiffSynth-Studio/blob/master/LICENSE)
|
||||
[](https://github.com/modelscope/DiffSynth-Studio/issues)
|
||||
[](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
|
||||
[](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
|
||||
[](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
|
||||
|
||||
<p align="center">
|
||||
<a href="https://trendshift.io/repositories/10946" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10946" alt="modelscope%2FDiffSynth-Studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
</p>
|
||||
|
||||
Document: https://diffsynth-studio.readthedocs.io/zh-cn/latest/index.html
|
||||
[切换到中文](./README_zh.md)
|
||||
|
||||
## Introduction
|
||||
|
||||
Welcome to the magic world of Diffusion models!
|
||||
Welcome to the magic world of Diffusion models! DiffSynth-Studio is an open-source Diffusion model engine developed and maintained by [ModelScope](https://www.modelscope.cn/) team. We aim to foster technical innovation through framework development, bring together the power of the open-source community, and explore the limits of generative models!
|
||||
|
||||
DiffSynth consists of two open-source projects:
|
||||
* [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): Focused on aggressive technological exploration. Targeted at academia. Provides more cutting-edge technical support and novel inference capabilities.
|
||||
* [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality.
|
||||
DiffSynth currently includes two open-source projects:
|
||||
* [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): Focused on aggressive technical exploration, for academia, providing support for more cutting-edge model capabilities.
|
||||
* [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): Focused on stable model deployment, for industry, offering higher computing performance and more stable features.
|
||||
|
||||
DiffSynth-Studio is an open-source project aimed at exploring innovations in AIGC technology. We have integrated numerous open-source Diffusion models, including FLUX and Wan, among others. Through this open-source project, we hope to connect models within the open-source community and explore new technologies based on diffusion models.
|
||||
[DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) and [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) are the core projects behind ModelScope [AIGC zone](https://modelscope.cn/aigc/home), offering powerful AI content generation abilities. Come and try our carefully designed features and start your AI creation journey!
|
||||
|
||||
Until now, DiffSynth-Studio has supported the following models:
|
||||
## Installation
|
||||
|
||||
* [Wan-Video](https://github.com/Wan-Video/Wan2.1)
|
||||
* [StepVideo](https://github.com/stepfun-ai/Step-Video-T2V)
|
||||
* [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [HunyuanVideo-I2V]()
|
||||
* [CogVideoX](https://huggingface.co/THUDM/CogVideoX-5b)
|
||||
* [FLUX](https://huggingface.co/black-forest-labs/FLUX.1-dev)
|
||||
* [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
|
||||
* [Kolors](https://huggingface.co/Kwai-Kolors/Kolors)
|
||||
* [Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
|
||||
* [Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt)
|
||||
* [Hunyuan-DiT](https://github.com/Tencent/HunyuanDiT)
|
||||
* [RIFE](https://github.com/hzwer/ECCV2022-RIFE)
|
||||
* [ESRGAN](https://github.com/xinntao/ESRGAN)
|
||||
* [Ip-Adapter](https://github.com/tencent-ailab/IP-Adapter)
|
||||
* [AnimateDiff](https://github.com/guoyww/animatediff/)
|
||||
* [ControlNet](https://github.com/lllyasviel/ControlNet)
|
||||
* [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
||||
* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
|
||||
Install from source (recommended):
|
||||
|
||||
```
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>Other installation methods</summary>
|
||||
|
||||
Install from PyPI (version updates may be delayed; for latest features, install from source)
|
||||
|
||||
```
|
||||
pip install diffsynth
|
||||
```
|
||||
|
||||
If you meet problems during installation, they might be caused by upstream dependencies. Please check the docs of these packages:
|
||||
|
||||
* [torch](https://pytorch.org/get-started/locally/)
|
||||
* [sentencepiece](https://github.com/google/sentencepiece)
|
||||
* [cmake](https://cmake.org)
|
||||
* [cupy](https://docs.cupy.dev/en/stable/install.html)
|
||||
|
||||
</details>
|
||||
|
||||
## Basic Framework
|
||||
|
||||
DiffSynth-Studio redesigns the inference and training pipelines for mainstream Diffusion models (including FLUX, Wan, etc.), enabling efficient memory management and flexible model training.
|
||||
|
||||
### FLUX Series
|
||||
|
||||
Detail page: [./examples/flux/](./examples/flux/)
|
||||
|
||||

|
||||
|
||||
<details>
|
||||
|
||||
<summary>Quick Start</summary>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
image = pipe(prompt="a cat", seed=0)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Model Overview</summary>
|
||||
|
||||
| Model ID | Extra Parameters | Inference | Low VRAM Inference | Full Training | Validate After Full Training | LoRA Training | Validate After LoRA Training |
|
||||
|-|-|-|-|-|-|-|-|
|
||||
|[FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](./examples/flux/model_inference/FLUX.1-dev.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](./examples/flux/model_training/full/FLUX.1-dev.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev.py)|
|
||||
|[FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](./examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](./examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](./examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)|
|
||||
|[FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|
|
||||
|[FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)|
|
||||
|[FLUX.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)|
|
||||
|[FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](./examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)|
|
||||
|[FLUX.1-dev-InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|
|
||||
|[FLUX.1-dev-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](./examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](./examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)|
|
||||
|[FLUX.1-dev-LoRA-Encoder](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](./examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|
|
||||
|[FLUX.1-dev-LoRA-Fusion-Preview](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](./examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|
|
||||
|[Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](./examples/flux/model_inference/Step1X-Edit.py)|[code](./examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](./examples/flux/model_training/full/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](./examples/flux/model_training/lora/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_lora/Step1X-Edit.py)|
|
||||
|[FLEX.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](./examples/flux/model_inference/FLEX.2-preview.py)|[code](./examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](./examples/flux/model_training/full/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](./examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_lora/FLEX.2-preview.py)|
|
||||
|[Nexus-Gen](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](./examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](./examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](./examples/flux/model_training/full/Nexus-Gen.sh)|[code](./examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](./examples/flux/model_training/lora/Nexus-Gen.sh)|[code](./examples/flux/model_training/validate_lora/Nexus-Gen.py)|
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
### Wan Series
|
||||
|
||||
Detail page: [./examples/wanvideo/](./examples/wanvideo/)
|
||||
|
||||
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Quick Start</summary>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth import save_video
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
video = pipe(
|
||||
prompt="A documentary photography style scene: a lively puppy rapidly running on green grass. The puppy has brown-yellow fur, upright ears, and looks focused and joyful. Sunlight shines on its body, making the fur appear soft and shiny. The background is an open field with occasional wildflowers, and faint blue sky and clouds in the distance. Strong sense of perspective captures the motion of the puppy and the vitality of the surrounding grass. Mid-shot side-moving view.",
|
||||
negative_prompt="Bright colors, overexposed, static, blurry details, subtitles, style, artwork, image, still, overall gray, worst quality, low quality, JPEG compression artifacts, ugly, deformed, extra fingers, poorly drawn hands, poorly drawn face, malformed limbs, fused fingers, still frame, messy background, three legs, crowded background people, walking backwards",
|
||||
seed=0, tiled=True,
|
||||
)
|
||||
save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Model Overview</summary>
|
||||
|
||||
| Model ID | Extra Parameters | Inference | Full Training | Validate After Full Training | LoRA Training | Validate After LoRA Training |
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|
||||
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|
||||
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|
||||
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](./examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|
||||
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](./examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|
|
||||
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
|
||||
|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|
|
||||
|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|
|
||||
|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./examples/wanvideo/examples/wanmodel_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|
|
||||
|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|
|
||||
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|
||||
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|
|
||||
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|
||||
|
||||
</details>
|
||||
|
||||
### More Models
|
||||
|
||||
|
||||
|
||||
<details>
|
||||
<summary>Image Generation Models</summary>
|
||||
|
||||
Detail page: [./examples/image_synthesis/](./examples/image_synthesis/)
|
||||
|
||||
|FLUX|Stable Diffusion 3|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
|Kolors|Hunyuan-DiT|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
|Stable Diffusion|Stable Diffusion XL|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
<details>
|
||||
<summary>Video Generation Models</summary>
|
||||
|
||||
- HunyuanVideo: [./examples/HunyuanVideo/](./examples/HunyuanVideo/)
|
||||
|
||||
https://github.com/user-attachments/assets/48dd24bb-0cc6-40d2-88c3-10feed3267e9
|
||||
|
||||
- StepVideo: [./examples/stepvideo/](./examples/stepvideo/)
|
||||
|
||||
https://github.com/user-attachments/assets/5954fdaa-a3cf-45a3-bd35-886e3cc4581b
|
||||
|
||||
- CogVideoX: [./examples/CogVideoX/](./examples/CogVideoX/)
|
||||
|
||||
https://github.com/user-attachments/assets/26b044c1-4a60-44a4-842f-627ff289d006
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
<details>
|
||||
<summary>Image Quality Assessment Models</summary>
|
||||
|
||||
We have integrated a series of image quality assessment models. These models can be used for evaluating image generation models, alignment training, and similar tasks.
|
||||
|
||||
Detail page: [./examples/image_quality_metric/](./examples/image_quality_metric/)
|
||||
|
||||
* [ImageReward](https://github.com/THUDM/ImageReward)
|
||||
* [Aesthetic](https://github.com/christophschuhmann/improved-aesthetic-predictor)
|
||||
* [PickScore](https://github.com/yuvalkirstain/pickscore)
|
||||
* [CLIP](https://github.com/openai/CLIP)
|
||||
* [HPSv2](https://github.com/tgxs002/HPSv2)
|
||||
* [HPSv2.1](https://github.com/tgxs002/HPSv2)
|
||||
* [MPS](https://github.com/Kwai-Kolors/MPS)
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
## Innovative Achievements
|
||||
|
||||
DiffSynth-Studio is not just an engineering model framework, but also a platform for incubating innovative results.
|
||||
|
||||
<details>
|
||||
<summary>Nexus-Gen: Unified Architecture for Image Understanding, Generation, and Editing</summary>
|
||||
|
||||
- Detail page: https://github.com/modelscope/Nexus-Gen
|
||||
- Paper: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
|
||||
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
|
||||
- Dataset: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
|
||||
- Online Demo: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
|
||||
|
||||

|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>ArtAug: Aesthetic Enhancement for Image Generation Models</summary>
|
||||
|
||||
- Detail page: [./examples/ArtAug/](./examples/ArtAug/)
|
||||
- Paper: [ArtAug: Enhancing Text-to-Image Generation through Synthesis-Understanding Interaction](https://arxiv.org/abs/2412.12888)
|
||||
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
|
||||
- Online Demo: [ModelScope AIGC Tab](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0)
|
||||
|
||||
|FLUX.1-dev|FLUX.1-dev + ArtAug LoRA|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>EliGen: Precise Image Region Control</summary>
|
||||
|
||||
- Detail page: [./examples/EntityControl/](./examples/EntityControl/)
|
||||
- Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
|
||||
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
|
||||
- Online Demo: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
|
||||
- Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
|
||||
|
||||
|Entity Control Mask|Generated Image|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>ExVideo: Extended Training for Video Generation Models</summary>
|
||||
|
||||
- Project Page: [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
|
||||
- Paper: [ExVideo: Extending Video Diffusion Models via Parameter-Efficient Post-Tuning](https://arxiv.org/abs/2406.14130)
|
||||
- Code Example: [./examples/ExVideo/](./examples/ExVideo/)
|
||||
- Model: [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
|
||||
|
||||
https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Diffutoon: High-Resolution Anime-Style Video Rendering</summary>
|
||||
|
||||
- Project Page: [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
|
||||
- Paper: [Diffutoon: High-Resolution Editable Toon Shading via Diffusion Models](https://arxiv.org/abs/2401.16224)
|
||||
- Code Example: [./examples/Diffutoon/](./examples/Diffutoon/)
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-4709-be5e-b39af82404dd
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>DiffSynth: The Initial Version of This Project</summary>
|
||||
|
||||
- Project Page: [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/)
|
||||
- Paper: [DiffSynth: Latent In-Iteration Deflickering for Realistic Video Synthesis](https://arxiv.org/abs/2308.03463)
|
||||
- Code Example: [./examples/diffsynth/](./examples/diffsynth/)
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
## Update History
|
||||
|
||||
- **July 28, 2025** 🔥🔥🔥 With the open-sourcing of Wan 2.2, we immediately provided comprehensive support, including low-GPU-memory layer-by-layer offload, FP8 quantization, sequence parallelism, LoRA training, full training. See [./examples/wanvideo/](./examples/wanvideo/).
|
||||
|
||||
- **July 11, 2025** We propose Nexus-Gen, a unified model that synergizes the language reasoning capabilities of LLMs with the image synthesis power of diffusion models. This framework enables seamless image understanding, generation, and editing tasks.
|
||||
- Paper: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
|
||||
- Github Repo: https://github.com/modelscope/Nexus-Gen
|
||||
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
|
||||
- Training Dataset: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
|
||||
- Online Demo: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
|
||||
|
||||
## News
|
||||
- **June 15, 2025** ModelScope's official evaluation framework, [EvalScope](https://github.com/modelscope/evalscope), now supports text-to-image generation evaluation. Try it with the [Best Practices](https://evalscope.readthedocs.io/zh-cn/latest/best_practice/t2i_eval.html) guide.
|
||||
|
||||
- **March 31, 2025** We support InfiniteYou, an identity preserving method for FLUX. Please refer to [./examples/InfiniteYou/](./examples/InfiniteYou/) for more details.
|
||||
- **March 25, 2025** Our new open-source project, [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine), is now open-sourced! Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality.
|
||||
|
||||
- **March 25, 2025** 🔥🔥🔥 Our new open-source project, [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine), is now open-sourced! Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality.
|
||||
<details>
|
||||
<summary>More</summary>
|
||||
|
||||
- **March 31, 2025** We support InfiniteYou, an identity preserving method for FLUX. Please refer to [./examples/InfiniteYou/](./examples/InfiniteYou/) for more details.
|
||||
|
||||
- **March 13, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of HunyuanVideo open-sourced by Tencent. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.
|
||||
|
||||
@@ -123,135 +414,4 @@ Until now, DiffSynth-Studio has supported the following models:
|
||||
- The source codes are released in [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth).
|
||||
- The technical report (ECML PKDD 2024) is released on [arXiv](https://arxiv.org/abs/2308.03463).
|
||||
|
||||
|
||||
## Installation
|
||||
|
||||
Install from source code (recommended):
|
||||
|
||||
```
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Or install from pypi (There is a delay in the update. If you want to experience the latest features, please do not use this installation method.):
|
||||
|
||||
```
|
||||
pip install diffsynth
|
||||
```
|
||||
|
||||
If you encounter issues during installation, it may be caused by the packages we depend on. Please refer to the documentation of the package that caused the problem.
|
||||
|
||||
* [torch](https://pytorch.org/get-started/locally/)
|
||||
* [sentencepiece](https://github.com/google/sentencepiece)
|
||||
* [cmake](https://cmake.org)
|
||||
* [cupy](https://docs.cupy.dev/en/stable/install.html)
|
||||
|
||||
## Usage (in Python code)
|
||||
|
||||
The Python examples are in [`examples`](./examples/). We provide an overview here.
|
||||
|
||||
### Download Models
|
||||
|
||||
Download the pre-set models. Model IDs can be found in [config file](/diffsynth/configs/model_config.py).
|
||||
|
||||
```python
|
||||
from diffsynth import download_models
|
||||
|
||||
download_models(["FLUX.1-dev", "Kolors"])
|
||||
```
|
||||
|
||||
Download your own models.
|
||||
|
||||
```python
|
||||
from diffsynth.models.downloader import download_from_huggingface, download_from_modelscope
|
||||
|
||||
# From Modelscope (recommended)
|
||||
download_from_modelscope("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.fp16.bin", "models/kolors/Kolors/vae")
|
||||
# From Huggingface
|
||||
download_from_huggingface("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.fp16.safetensors", "models/kolors/Kolors/vae")
|
||||
```
|
||||
|
||||
### Video Synthesis
|
||||
|
||||
#### Text-to-video using CogVideoX-5B
|
||||
|
||||
CogVideoX-5B is released by ZhiPu. We provide an improved pipeline, supporting text-to-video, video editing, self-upscaling and video interpolation. [`examples/video_synthesis`](./examples/video_synthesis/)
|
||||
|
||||
The video on the left is generated using the original text-to-video pipeline, while the video on the right is the result after editing and frame interpolation.
|
||||
|
||||
https://github.com/user-attachments/assets/26b044c1-4a60-44a4-842f-627ff289d006
|
||||
|
||||
#### Long Video Synthesis
|
||||
|
||||
We trained extended video synthesis models, which can generate 128 frames. [`examples/ExVideo`](./examples/ExVideo/)
|
||||
|
||||
https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
|
||||
|
||||
https://github.com/user-attachments/assets/321ee04b-8c17-479e-8a95-8cbcf21f8d7e
|
||||
|
||||
#### Toon Shading
|
||||
|
||||
Render realistic videos in a flatten style and enable video editing features. [`examples/Diffutoon`](./examples/Diffutoon/)
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-4709-be5e-b39af82404dd
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/20528af5-5100-474a-8cdc-440b9efdd86c
|
||||
|
||||
#### Video Stylization
|
||||
|
||||
Video stylization without video models. [`examples/diffsynth`](./examples/diffsynth/)
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
|
||||
|
||||
### Image Synthesis
|
||||
|
||||
Generate high-resolution images, by breaking the limitation of diffusion models! [`examples/image_synthesis`](./examples/image_synthesis/).
|
||||
|
||||
LoRA fine-tuning is supported in [`examples/train`](./examples/train/).
|
||||
|
||||
|FLUX|Stable Diffusion 3|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
|Kolors|Hunyuan-DiT|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
|Stable Diffusion|Stable Diffusion XL|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
## Usage (in WebUI)
|
||||
|
||||
Create stunning images using the painter, with assistance from AI!
|
||||
|
||||
https://github.com/user-attachments/assets/95265d21-cdd6-4125-a7cb-9fbcf6ceb7b0
|
||||
|
||||
**This video is not rendered in real-time.**
|
||||
|
||||
Before launching the WebUI, please download models to the folder `./models`. See [here](#download-models).
|
||||
|
||||
* `Gradio` version
|
||||
|
||||
```
|
||||
pip install gradio
|
||||
```
|
||||
|
||||
```
|
||||
python apps/gradio/DiffSynth_Studio.py
|
||||
```
|
||||
|
||||

|
||||
|
||||
* `Streamlit` version
|
||||
|
||||
```
|
||||
pip install streamlit streamlit-drawable-canvas
|
||||
```
|
||||
|
||||
```
|
||||
python -m streamlit run apps/streamlit/DiffSynth_Studio.py
|
||||
```
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/93085557-73f3-4eee-a205-9829591ef954
|
||||
</details>
|
||||
433
README_zh.md
Normal file
433
README_zh.md
Normal file
@@ -0,0 +1,433 @@
|
||||
# DiffSynth-Studio
|
||||
|
||||
<a href="https://github.com/modelscope/DiffSynth-Studio"><img src=".github/workflows/logo.gif" title="Logo" style="max-width:100%;" width="55" /></a> <a href="https://trendshift.io/repositories/10946" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10946" alt="modelscope%2FDiffSynth-Studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a></p>
|
||||
|
||||
[](https://pypi.org/project/DiffSynth/)
|
||||
[](https://github.com/modelscope/DiffSynth-Studio/blob/master/LICENSE)
|
||||
[](https://github.com/modelscope/DiffSynth-Studio/issues)
|
||||
[](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
|
||||
[](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
|
||||
|
||||
[Switch to English](./README.md)
|
||||
|
||||
## 简介
|
||||
|
||||
欢迎来到 Diffusion 模型的魔法世界!DiffSynth-Studio 是由[魔搭社区](https://www.modelscope.cn/)团队开发和维护的开源 Diffusion 模型引擎。我们期望以框架建设孵化技术创新,凝聚开源社区的力量,探索生成式模型技术的边界!
|
||||
|
||||
DiffSynth 目前包括两个开源项目:
|
||||
* [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): 聚焦于激进的技术探索,面向学术界,提供更前沿的模型能力支持。
|
||||
* [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): 聚焦于稳定的模型部署,面向工业界,提供更高的计算性能与更稳定的功能。
|
||||
|
||||
[DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) 与 [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) 作为魔搭社区 [AIGC 专区](https://modelscope.cn/aigc/home) 的核心技术支撑,提供了强大的AI生成内容能力。欢迎体验我们精心打造的产品化功能,开启您的AI创作之旅!
|
||||
|
||||
## 安装
|
||||
|
||||
从源码安装(推荐):
|
||||
|
||||
```
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>其他安装方式</summary>
|
||||
|
||||
从 pypi 安装(存在版本更新延迟,如需使用最新功能,请从源码安装)
|
||||
|
||||
```
|
||||
pip install diffsynth
|
||||
```
|
||||
|
||||
如果在安装过程中遇到问题,可能是由上游依赖包导致的,请参考这些包的文档:
|
||||
|
||||
* [torch](https://pytorch.org/get-started/locally/)
|
||||
* [sentencepiece](https://github.com/google/sentencepiece)
|
||||
* [cmake](https://cmake.org)
|
||||
* [cupy](https://docs.cupy.dev/en/stable/install.html)
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
## 基础框架
|
||||
|
||||
DiffSynth-Studio 为主流 Diffusion 模型(包括 FLUX、Wan 等)重新设计了推理和训练流水线,能够实现高效的显存管理、灵活的模型训练。
|
||||
|
||||
### FLUX 系列
|
||||
|
||||
详细页面:[./examples/flux/](./examples/flux/)
|
||||
|
||||

|
||||
|
||||
<details>
|
||||
|
||||
<summary>快速开始</summary>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
image = pipe(prompt="a cat", seed=0)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>模型总览</summary>
|
||||
|
||||
|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|-|
|
||||
|[FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](./examples/flux/model_inference/FLUX.1-dev.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](./examples/flux/model_training/full/FLUX.1-dev.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev.py)|
|
||||
|[FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](./examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](./examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](./examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)|
|
||||
|[FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|
|
||||
|[FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)|
|
||||
|[FLUX.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)|
|
||||
|[FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](./examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)|
|
||||
|[FLUX.1-dev-InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|
|
||||
|[FLUX.1-dev-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](./examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](./examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)|
|
||||
|[FLUX.1-dev-LoRA-Encoder](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](./examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|
|
||||
|[FLUX.1-dev-LoRA-Fusion-Preview](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](./examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|
|
||||
|[Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](./examples/flux/model_inference/Step1X-Edit.py)|[code](./examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](./examples/flux/model_training/full/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](./examples/flux/model_training/lora/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_lora/Step1X-Edit.py)|
|
||||
|[FLEX.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](./examples/flux/model_inference/FLEX.2-preview.py)|[code](./examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](./examples/flux/model_training/full/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](./examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_lora/FLEX.2-preview.py)|
|
||||
|[Nexus-Gen](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](./examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](./examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](./examples/flux/model_training/full/Nexus-Gen.sh)|[code](./examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](./examples/flux/model_training/lora/Nexus-Gen.sh)|[code](./examples/flux/model_training/validate_lora/Nexus-Gen.py)|
|
||||
|
||||
</details>
|
||||
|
||||
### Wan 系列
|
||||
|
||||
详细页面:[./examples/wanvideo/](./examples/wanvideo/)
|
||||
|
||||
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
|
||||
|
||||
<details>
|
||||
|
||||
<summary>快速开始</summary>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth import save_video
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
video = pipe(
|
||||
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
)
|
||||
save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>模型总览</summary>
|
||||
|
||||
|模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|
||||
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|
||||
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|
||||
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](./examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|
||||
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](./examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|
|
||||
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
|
||||
|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|
|
||||
|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|
|
||||
|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./examples/wanvideo/examples/wanmodel_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|
|
||||
|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|
|
||||
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|
||||
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|
|
||||
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
### 更多模型
|
||||
|
||||
|
||||
|
||||
<details>
|
||||
<summary>图像生成模型</summary>
|
||||
|
||||
详细页面:[./examples/image_synthesis/](./examples/image_synthesis/)
|
||||
|
||||
|FLUX|Stable Diffusion 3|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
|Kolors|Hunyuan-DiT|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
|Stable Diffusion|Stable Diffusion XL|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
<details>
|
||||
<summary>视频生成模型</summary>
|
||||
|
||||
- HunyuanVideo:[./examples/HunyuanVideo/](./examples/HunyuanVideo/)
|
||||
|
||||
https://github.com/user-attachments/assets/48dd24bb-0cc6-40d2-88c3-10feed3267e9
|
||||
|
||||
- StepVideo:[./examples/stepvideo/](./examples/stepvideo/)
|
||||
|
||||
https://github.com/user-attachments/assets/5954fdaa-a3cf-45a3-bd35-886e3cc4581b
|
||||
|
||||
- CogVideoX:[./examples/CogVideoX/](./examples/CogVideoX/)
|
||||
|
||||
https://github.com/user-attachments/assets/26b044c1-4a60-44a4-842f-627ff289d006
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
<details>
|
||||
<summary>图像质量评估模型</summary>
|
||||
|
||||
我们集成了一系列图像质量评估模型,这些模型可以用于图像生成模型的评测、对齐训练等场景中。
|
||||
|
||||
详细页面:[./examples/image_quality_metric/](./examples/image_quality_metric/)
|
||||
|
||||
* [ImageReward](https://github.com/THUDM/ImageReward)
|
||||
* [Aesthetic](https://github.com/christophschuhmann/improved-aesthetic-predictor)
|
||||
* [PickScore](https://github.com/yuvalkirstain/pickscore)
|
||||
* [CLIP](https://github.com/openai/CLIP)
|
||||
* [HPSv2](https://github.com/tgxs002/HPSv2)
|
||||
* [HPSv2.1](https://github.com/tgxs002/HPSv2)
|
||||
* [MPS](https://github.com/Kwai-Kolors/MPS)
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
## 创新成果
|
||||
|
||||
DiffSynth-Studio 不仅仅是一个工程化的模型框架,更是创新成果的孵化器。
|
||||
|
||||
<details>
|
||||
<summary>Nexus-Gen: 统一架构的图像理解、生成、编辑</summary>
|
||||
|
||||
- 详细页面:https://github.com/modelscope/Nexus-Gen
|
||||
- 论文:[Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
|
||||
- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
|
||||
- 数据集:[ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
|
||||
- 在线体验:[ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
|
||||
|
||||

|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
<details>
|
||||
<summary>ArtAug: 图像生成模型的美学提升</summary>
|
||||
|
||||
- 详细页面:[./examples/ArtAug/](./examples/ArtAug/)
|
||||
- 论文:[ArtAug: Enhancing Text-to-Image Generation through Synthesis-Understanding Interaction](https://arxiv.org/abs/2412.12888)
|
||||
- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
|
||||
- 在线体验:[ModelScope AIGC Tab](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0)
|
||||
|
||||
|FLUX.1-dev|FLUX.1-dev + ArtAug LoRA|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>EliGen: 精准的图像分区控制</summary>
|
||||
|
||||
- 详细页面:[./examples/EntityControl/](./examples/EntityControl/)
|
||||
- 论文:[EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
|
||||
- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
|
||||
- 在线体验:[ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
|
||||
- 数据集:[EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
|
||||
|
||||
|实体控制区域|生成图像|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>ExVideo: 视频生成模型的扩展训练</summary>
|
||||
|
||||
- 项目页面:[Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
|
||||
- 论文:[ExVideo: Extending Video Diffusion Models via Parameter-Efficient Post-Tuning](https://arxiv.org/abs/2406.14130)
|
||||
- 代码样例:[./examples/ExVideo/](./examples/ExVideo/)
|
||||
- 模型:[ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
|
||||
|
||||
https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Diffutoon: 高分辨率动漫风格视频渲染</summary>
|
||||
|
||||
- 项目页面:[Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
|
||||
- 论文:[Diffutoon: High-Resolution Editable Toon Shading via Diffusion Models](https://arxiv.org/abs/2401.16224)
|
||||
- 代码样例:[./examples/Diffutoon/](./examples/Diffutoon/)
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-4709-be5e-b39af82404dd
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>DiffSynth: 本项目的初代版本</summary>
|
||||
|
||||
- 项目页面:[Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/)
|
||||
- 论文:[DiffSynth: Latent In-Iteration Deflickering for Realistic Video Synthesis](https://arxiv.org/abs/2308.03463)
|
||||
- 代码样例:[./examples/diffsynth/](./examples/diffsynth/)
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
## 更新历史
|
||||
|
||||
- **2025年7月28日** 🔥🔥🔥 Wan 2.2 开源,我们第一时间提供了全方位支持,包括低显存逐层 offload、FP8 量化、序列并行、LoRA 训练、全量训练。详细信息请参考 [./examples/wanvideo/](./examples/wanvideo/)。
|
||||
|
||||
- **2025年7月11日** 我们提出 Nexus-Gen,一个将大语言模型(LLM)的语言推理能力与扩散模型的图像生成能力相结合的统一框架。该框架支持无缝的图像理解、生成和编辑任务。
|
||||
- 论文: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
|
||||
- Github 仓库: https://github.com/modelscope/Nexus-Gen
|
||||
- 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
|
||||
- 训练数据集: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
|
||||
- 在线体验: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
|
||||
|
||||
- **2025年6月15日** ModelScope 官方评测框架 [EvalScope](https://github.com/modelscope/evalscope) 现已支持文生图生成评测。请参考[最佳实践](https://evalscope.readthedocs.io/zh-cn/latest/best_practice/t2i_eval.html)指南进行尝试。
|
||||
|
||||
- **2025年3月25日** 我们的新开源项目 [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) 现已开源!专注于稳定的模型部署,面向工业界,提供更好的工程支持、更高的计算性能和更稳定的功能。
|
||||
|
||||
<details>
|
||||
<summary>更多</summary>
|
||||
|
||||
- **2025年3月31日** 我们支持 InfiniteYou,一种用于 FLUX 的人脸特征保留方法。更多细节请参考 [./examples/InfiniteYou/](./examples/InfiniteYou/)。
|
||||
|
||||
- **2025年3月13日** 我们支持 HunyuanVideo-I2V,即腾讯开源的 HunyuanVideo 的图像到视频生成版本。更多细节请参考 [./examples/HunyuanVideo/](./examples/HunyuanVideo/)。
|
||||
|
||||
- **2025年2月25日** 我们支持 Wan-Video,这是阿里巴巴开源的一系列最先进的视频合成模型。详见 [./examples/wanvideo/](./examples/wanvideo/)。
|
||||
|
||||
- **2025年2月17日** 我们支持 [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary)!先进的视频合成模型!详见 [./examples/stepvideo](./examples/stepvideo/)。
|
||||
|
||||
- **2024年12月31日** 我们提出 EliGen,一种用于精确实体级别控制的文本到图像生成的新框架,并辅以修复融合管道,将其能力扩展到图像修复任务。EliGen 可以无缝集成现有的社区模型,如 IP-Adapter 和 In-Context LoRA,提升其通用性。更多详情,请见 [./examples/EntityControl](./examples/EntityControl/)。
|
||||
- 论文: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
|
||||
- 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
|
||||
- 在线体验: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
|
||||
- 训练数据集: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
|
||||
|
||||
- **2024年12月19日** 我们为 HunyuanVideo 实现了高级显存管理,使得在 24GB 显存下可以生成分辨率为 129x720x1280 的视频,或在仅 6GB 显存下生成分辨率为 129x512x384 的视频。更多细节请参考 [./examples/HunyuanVideo/](./examples/HunyuanVideo/)。
|
||||
|
||||
- **2024年12月18日** 我们提出 ArtAug,一种通过合成-理解交互来改进文生图模型的方法。我们以 LoRA 格式为 FLUX.1-dev 训练了一个 ArtAug 增强模块。该模型将 Qwen2-VL-72B 的美学理解融入 FLUX.1-dev,从而提升了生成图像的质量。
|
||||
- 论文: https://arxiv.org/abs/2412.12888
|
||||
- 示例: https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/ArtAug
|
||||
- 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
|
||||
- 演示: [ModelScope](https://modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0), HuggingFace (即将上线)
|
||||
|
||||
- **2024年10月25日** 我们提供了广泛的 FLUX ControlNet 支持。该项目支持许多不同的 ControlNet 模型,并且可以自由组合,即使它们的结构不同。此外,ControlNet 模型兼容高分辨率优化和分区控制技术,能够实现非常强大的可控图像生成。详见 [`./examples/ControlNet/`](./examples/ControlNet/)。
|
||||
|
||||
- **2024年10月8日** 我们发布了基于 CogVideoX-5B 和 ExVideo 的扩展 LoRA。您可以从 [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) 或 [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) 下载此模型。
|
||||
|
||||
- **2024年8月22日** 本项目现已支持 CogVideoX-5B。详见 [此处](/examples/video_synthesis/)。我们为这个文生视频模型提供了几个有趣的功能,包括:
|
||||
- 文本到视频
|
||||
- 视频编辑
|
||||
- 自我超分
|
||||
- 视频插帧
|
||||
|
||||
- **2024年8月22日** 我们实现了一个有趣的画笔功能,支持所有文生图模型。现在,您可以在 AI 的辅助下使用画笔创作惊艳的图像了!
|
||||
- 在我们的 [WebUI](#usage-in-webui) 中使用它。
|
||||
|
||||
- **2024年8月21日** DiffSynth-Studio 现已支持 FLUX。
|
||||
- 启用 CFG 和高分辨率修复以提升视觉质量。详见 [此处](/examples/image_synthesis/README.md)
|
||||
- LoRA、ControlNet 和其他附加模型将很快推出。
|
||||
|
||||
- **2024年6月21日** 我们提出 ExVideo,一种旨在增强视频生成模型能力的后训练微调技术。我们将 Stable Video Diffusion 进行了扩展,实现了长达 128 帧的长视频生成。
|
||||
- [项目页面](https://ecnu-cilab.github.io/ExVideoProjectPage/)
|
||||
- 源代码已在此仓库中发布。详见 [`examples/ExVideo`](./examples/ExVideo/)。
|
||||
- 模型已发布于 [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) 和 [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1)。
|
||||
- 技术报告已发布于 [arXiv](https://arxiv.org/abs/2406.14130)。
|
||||
- 您可以在此 [演示](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1) 中试用 ExVideo!
|
||||
|
||||
- **2024年6月13日** DiffSynth Studio 已迁移至 ModelScope。开发团队也从“我”转变为“我们”。当然,我仍会参与后续的开发和维护工作。
|
||||
|
||||
- **2024年1月29日** 我们提出 Diffutoon,这是一个出色的卡通着色解决方案。
|
||||
- [项目页面](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
|
||||
- 源代码已在此项目中发布。
|
||||
- 技术报告(IJCAI 2024)已发布于 [arXiv](https://arxiv.org/abs/2401.16224)。
|
||||
|
||||
- **2023年12月8日** 我们决定启动一个新项目,旨在释放扩散模型的潜力,尤其是在视频合成方面。该项目的开发工作正式开始。
|
||||
|
||||
- **2023年11月15日** 我们提出 FastBlend,一种强大的视频去闪烁算法。
|
||||
- sd-webui 扩展已发布于 [GitHub](https://github.com/Artiprocher/sd-webui-fastblend)。
|
||||
- 演示视频已在 Bilibili 上展示,包含三个任务:
|
||||
- [视频去闪烁](https://www.bilibili.com/video/BV1d94y1W7PE)
|
||||
- [视频插帧](https://www.bilibili.com/video/BV1Lw411m71p)
|
||||
- [图像驱动的视频渲染](https://www.bilibili.com/video/BV1RB4y1Z7LF)
|
||||
- 技术报告已发布于 [arXiv](https://arxiv.org/abs/2311.09265)。
|
||||
- 其他用户开发的非官方 ComfyUI 扩展已发布于 [GitHub](https://github.com/AInseven/ComfyUI-fastblend)。
|
||||
|
||||
- **2023年10月1日** 我们发布了该项目的早期版本,名为 FastSDXL。这是构建一个扩散引擎的初步尝试。
|
||||
- 源代码已发布于 [GitHub](https://github.com/Artiprocher/FastSDXL)。
|
||||
- FastSDXL 包含一个可训练的 OLSS 调度器,以提高效率。
|
||||
- OLSS 的原始仓库位于 [此处](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler)。
|
||||
- 技术报告(CIKM 2023)已发布于 [arXiv](https://arxiv.org/abs/2305.14677)。
|
||||
- 演示视频已发布于 [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj)。
|
||||
- 由于 OLSS 需要额外训练,我们未在本项目中实现它。
|
||||
|
||||
- **2023年8月29日** 我们提出 DiffSynth,一个视频合成框架。
|
||||
- [项目页面](https://ecnu-cilab.github.io/DiffSynth.github.io/)。
|
||||
- 源代码已发布在 [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth)。
|
||||
- 技术报告(ECML PKDD 2024)已发布于 [arXiv](https://arxiv.org/abs/2308.03463)。
|
||||
|
||||
</details>
|
||||
@@ -58,14 +58,19 @@ from ..models.stepvideo_dit import StepVideoModel
|
||||
from ..models.wan_video_dit import WanModel
|
||||
from ..models.wan_video_text_encoder import WanTextEncoder
|
||||
from ..models.wan_video_image_encoder import WanImageEncoder
|
||||
from ..models.wan_video_vae import WanVideoVAE
|
||||
from ..models.wan_video_vae import WanVideoVAE, WanVideoVAE38
|
||||
from ..models.wan_video_motion_controller import WanMotionControllerModel
|
||||
from ..models.wan_video_vace import VaceWanModel
|
||||
|
||||
from ..models.step1x_connector import Qwen2Connector
|
||||
|
||||
from ..lora.flux_lora import FluxLoraPatcher
|
||||
from ..models.flux_value_control import SingleValueEncoder
|
||||
|
||||
from ..lora.flux_lora import FluxLoraPatcher
|
||||
from ..models.flux_lora_encoder import FluxLoRAEncoder
|
||||
|
||||
from ..models.nexus_gen_projector import NexusGenAdapter, NexusGenImageEmbeddingMerger
|
||||
from ..models.nexus_gen import NexusGenAutoregressiveModel
|
||||
|
||||
model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
@@ -104,6 +109,7 @@ model_loader_configs = [
|
||||
(None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "d02f41c13549fa5093d3521f62a5570a", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"),
|
||||
(None, "0629116fce1472503a66992f96f3eb1a", ["flux_value_controller"], [SingleValueEncoder], "civitai"),
|
||||
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
|
||||
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
|
||||
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
|
||||
@@ -137,6 +143,8 @@ model_loader_configs = [
|
||||
(None, "26bde73488a92e64cc20b0a7485b9e5b", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "ac6a5aa74f4a0aab6f64eb9a72f19901", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "b61c605c2adbd23124d152ed28e049ae", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "1f5ab7703c6fc803fdded85ff040c316", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "5b013604280dd715f8457c6ed6d6a626", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
|
||||
(None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
|
||||
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
|
||||
@@ -144,9 +152,14 @@ model_loader_configs = [
|
||||
(None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
|
||||
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
||||
(None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
||||
(None, "e1de6c02cdac79f8b739f4d3698cd216", ["wan_video_vae"], [WanVideoVAE38], "civitai"),
|
||||
(None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"),
|
||||
(None, "d30fb9e02b1dbf4e509142f05cf7dd50", ["flux_dit", "step1x_connector"], [FluxDiT, Qwen2Connector], "civitai"),
|
||||
(None, "30143afb2dea73d1ac580e0787628f8c", ["flux_lora_patcher"], [FluxLoraPatcher], "civitai"),
|
||||
(None, "77c2e4dd2440269eb33bfaa0d004f6ab", ["flux_lora_encoder"], [FluxLoRAEncoder], "civitai"),
|
||||
(None, "3e6c61b0f9471135fc9c6d6a98e98b6d", ["flux_dit", "nexus_gen_generation_adapter"], [FluxDiT, NexusGenAdapter], "civitai"),
|
||||
(None, "63c969fd37cce769a90aa781fbff5f81", ["flux_dit", "nexus_gen_editing_adapter"], [FluxDiT, NexusGenImageEmbeddingMerger], "civitai"),
|
||||
(None, "2bd19e845116e4f875a0a048e27fc219", ["nexus_gen_llm"], [NexusGenAutoregressiveModel], "civitai"),
|
||||
]
|
||||
huggingface_model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
|
||||
@@ -1,18 +1,58 @@
|
||||
import torch, math
|
||||
from diffsynth.lora import GeneralLoRALoader
|
||||
from diffsynth.models.lora import FluxLoRAFromCivitai
|
||||
from . import GeneralLoRALoader
|
||||
from ..utils import ModelConfig
|
||||
from ..models.utils import load_state_dict
|
||||
from typing import Union
|
||||
|
||||
|
||||
class FluxLoRALoader(GeneralLoRALoader):
|
||||
def __init__(self, device="cpu", torch_dtype=torch.float32):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
|
||||
self.diffusers_rename_dict = {
|
||||
"transformer.single_transformer_blocks.blockid.attn.to_k.lora_A.weight":"single_blocks.blockid.a_to_k.lora_A.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.attn.to_k.lora_B.weight":"single_blocks.blockid.a_to_k.lora_B.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.attn.to_q.lora_A.weight":"single_blocks.blockid.a_to_q.lora_A.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.attn.to_q.lora_B.weight":"single_blocks.blockid.a_to_q.lora_B.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.attn.to_v.lora_A.weight":"single_blocks.blockid.a_to_v.lora_A.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.attn.to_v.lora_B.weight":"single_blocks.blockid.a_to_v.lora_B.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.norm.linear.lora_A.weight":"single_blocks.blockid.norm.linear.lora_A.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.norm.linear.lora_B.weight":"single_blocks.blockid.norm.linear.lora_B.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.proj_mlp.lora_A.weight":"single_blocks.blockid.proj_in_besides_attn.lora_A.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.proj_mlp.lora_B.weight":"single_blocks.blockid.proj_in_besides_attn.lora_B.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.proj_out.lora_A.weight":"single_blocks.blockid.proj_out.lora_A.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.proj_out.lora_B.weight":"single_blocks.blockid.proj_out.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.add_k_proj.lora_A.weight":"blocks.blockid.attn.b_to_k.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.add_k_proj.lora_B.weight":"blocks.blockid.attn.b_to_k.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.add_q_proj.lora_A.weight":"blocks.blockid.attn.b_to_q.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.add_q_proj.lora_B.weight":"blocks.blockid.attn.b_to_q.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.add_v_proj.lora_A.weight":"blocks.blockid.attn.b_to_v.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.add_v_proj.lora_B.weight":"blocks.blockid.attn.b_to_v.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_add_out.lora_A.weight":"blocks.blockid.attn.b_to_out.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_add_out.lora_B.weight":"blocks.blockid.attn.b_to_out.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_k.lora_A.weight":"blocks.blockid.attn.a_to_k.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_k.lora_B.weight":"blocks.blockid.attn.a_to_k.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_out.0.lora_A.weight":"blocks.blockid.attn.a_to_out.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_out.0.lora_B.weight":"blocks.blockid.attn.a_to_out.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_q.lora_A.weight":"blocks.blockid.attn.a_to_q.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_q.lora_B.weight":"blocks.blockid.attn.a_to_q.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_v.lora_A.weight":"blocks.blockid.attn.a_to_v.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_v.lora_B.weight":"blocks.blockid.attn.a_to_v.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.ff.net.0.proj.lora_A.weight":"blocks.blockid.ff_a.0.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.ff.net.0.proj.lora_B.weight":"blocks.blockid.ff_a.0.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.ff.net.2.lora_A.weight":"blocks.blockid.ff_a.2.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.ff.net.2.lora_B.weight":"blocks.blockid.ff_a.2.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_A.weight":"blocks.blockid.ff_b.0.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_B.weight":"blocks.blockid.ff_b.0.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.ff_context.net.2.lora_A.weight":"blocks.blockid.ff_b.2.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.ff_context.net.2.lora_B.weight":"blocks.blockid.ff_b.2.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.norm1.linear.lora_A.weight":"blocks.blockid.norm1_a.linear.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.norm1.linear.lora_B.weight":"blocks.blockid.norm1_a.linear.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.norm1_context.linear.lora_A.weight":"blocks.blockid.norm1_b.linear.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.norm1_context.linear.lora_B.weight":"blocks.blockid.norm1_b.linear.lora_B.default.weight",
|
||||
}
|
||||
|
||||
def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
|
||||
super().load(model, state_dict_lora, alpha)
|
||||
|
||||
def convert_state_dict(self, state_dict):
|
||||
# TODO: support other lora format
|
||||
rename_dict = {
|
||||
self.civitai_rename_dict = {
|
||||
"lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight",
|
||||
@@ -40,25 +80,57 @@ class FluxLoRALoader(GeneralLoRALoader):
|
||||
"lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight",
|
||||
"lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight",
|
||||
}
|
||||
def guess_block_id(name):
|
||||
names = name.split("_")
|
||||
for i in names:
|
||||
if i.isdigit():
|
||||
return i, name.replace(f"_{i}_", "_blockid_")
|
||||
|
||||
def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
|
||||
super().load(model, state_dict_lora, alpha)
|
||||
|
||||
|
||||
def convert_state_dict(self,state_dict):
|
||||
|
||||
def guess_block_id(name,model_resource):
|
||||
if model_resource == 'civitai':
|
||||
names = name.split("_")
|
||||
for i in names:
|
||||
if i.isdigit():
|
||||
return i, name.replace(f"_{i}_", "_blockid_")
|
||||
if model_resource == 'diffusers':
|
||||
names = name.split(".")
|
||||
for i in names:
|
||||
if i.isdigit():
|
||||
return i, name.replace(f"transformer_blocks.{i}.", "transformer_blocks.blockid.")
|
||||
return None, None
|
||||
|
||||
def guess_resource(state_dict):
|
||||
for k in state_dict:
|
||||
if "lora_unet_" in k:
|
||||
return 'civitai'
|
||||
elif k.startswith("transformer."):
|
||||
return 'diffusers'
|
||||
else:
|
||||
None
|
||||
|
||||
model_resource = guess_resource(state_dict)
|
||||
if model_resource is None:
|
||||
return state_dict
|
||||
|
||||
rename_dict = self.diffusers_rename_dict if model_resource == 'diffusers' else self.civitai_rename_dict
|
||||
def guess_alpha(state_dict):
|
||||
for name, param in state_dict.items():
|
||||
if ".alpha" in name:
|
||||
name_ = name.replace(".alpha", ".lora_down.weight")
|
||||
if name_ in state_dict:
|
||||
lora_alpha = param.item() / state_dict[name_].shape[0]
|
||||
lora_alpha = math.sqrt(lora_alpha)
|
||||
return lora_alpha
|
||||
return 1
|
||||
for name, param in state_dict.items():
|
||||
if ".alpha" in name:
|
||||
for suffix in [".lora_down.weight", ".lora_A.weight"]:
|
||||
name_ = name.replace(".alpha", suffix)
|
||||
if name_ in state_dict:
|
||||
lora_alpha = param.item() / state_dict[name_].shape[0]
|
||||
lora_alpha = math.sqrt(lora_alpha)
|
||||
return lora_alpha
|
||||
|
||||
return 1
|
||||
|
||||
alpha = guess_alpha(state_dict)
|
||||
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
block_id, source_name = guess_block_id(name)
|
||||
block_id, source_name = guess_block_id(name,model_resource)
|
||||
if alpha != 1:
|
||||
param *= alpha
|
||||
if source_name in rename_dict:
|
||||
@@ -67,6 +139,72 @@ class FluxLoRALoader(GeneralLoRALoader):
|
||||
state_dict_[target_name] = param
|
||||
else:
|
||||
state_dict_[name] = param
|
||||
|
||||
if model_resource == 'diffusers':
|
||||
for name in list(state_dict_.keys()):
|
||||
if "single_blocks." in name and ".a_to_q." in name:
|
||||
mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
|
||||
if mlp is None:
|
||||
dim = 4
|
||||
if 'lora_A' in name:
|
||||
dim = 1
|
||||
mlp = torch.zeros(dim * state_dict_[name].shape[0],
|
||||
*state_dict_[name].shape[1:],
|
||||
dtype=state_dict_[name].dtype)
|
||||
else:
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
|
||||
if 'lora_A' in name:
|
||||
param = torch.concat([
|
||||
state_dict_.pop(name),
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
|
||||
mlp,
|
||||
], dim=0)
|
||||
elif 'lora_B' in name:
|
||||
d, r = state_dict_[name].shape
|
||||
param = torch.zeros((3*d+mlp.shape[0], 3*r+mlp.shape[1]), dtype=state_dict_[name].dtype, device=state_dict_[name].device)
|
||||
param[:d, :r] = state_dict_.pop(name)
|
||||
param[d:2*d, r:2*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_k."))
|
||||
param[2*d:3*d, 2*r:3*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_v."))
|
||||
param[3*d:, 3*r:] = mlp
|
||||
else:
|
||||
param = torch.concat([
|
||||
state_dict_.pop(name),
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
|
||||
mlp,
|
||||
], dim=0)
|
||||
name_ = name.replace(".a_to_q.", ".to_qkv_mlp.")
|
||||
state_dict_[name_] = param
|
||||
for name in list(state_dict_.keys()):
|
||||
for component in ["a", "b"]:
|
||||
if f".{component}_to_q." in name:
|
||||
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
|
||||
concat_dim = 0
|
||||
if 'lora_A' in name:
|
||||
param = torch.concat([
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
|
||||
], dim=0)
|
||||
elif 'lora_B' in name:
|
||||
origin = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")]
|
||||
d, r = origin.shape
|
||||
# print(d, r)
|
||||
param = torch.zeros((3*d, 3*r), dtype=origin.dtype, device=origin.device)
|
||||
param[:d, :r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")]
|
||||
param[d:2*d, r:2*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")]
|
||||
param[2*d:3*d, 2*r:3*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")]
|
||||
else:
|
||||
param = torch.concat([
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
|
||||
], dim=0)
|
||||
state_dict_[name_] = param
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
|
||||
return state_dict_
|
||||
|
||||
|
||||
@@ -140,3 +278,47 @@ class FluxLoraPatcherStateDictConverter:
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
|
||||
class FluxLoRAFuser:
|
||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
|
||||
def Matrix_Decomposition_lowrank(self, A, k):
|
||||
U, S, V = torch.svd_lowrank(A.float(), q=k)
|
||||
S_k = torch.diag(S[:k])
|
||||
U_hat = U @ S_k
|
||||
return U_hat, V.t()
|
||||
|
||||
def LoRA_State_Dicts_Decomposition(self, lora_state_dicts=[], q=4):
|
||||
lora_1 = lora_state_dicts[0]
|
||||
state_dict_ = {}
|
||||
for k,v in lora_1.items():
|
||||
if 'lora_A.' in k:
|
||||
lora_B_name = k.replace('lora_A.', 'lora_B.')
|
||||
lora_B = lora_1[lora_B_name]
|
||||
weight = torch.mm(lora_B, v)
|
||||
for lora_dict in lora_state_dicts[1:]:
|
||||
lora_A_ = lora_dict[k]
|
||||
lora_B_ = lora_dict[lora_B_name]
|
||||
weight_ = torch.mm(lora_B_, lora_A_)
|
||||
weight += weight_
|
||||
new_B, new_A = self.Matrix_Decomposition_lowrank(weight, q)
|
||||
state_dict_[lora_B_name] = new_B.to(dtype=torch.bfloat16)
|
||||
state_dict_[k] = new_A.to(dtype=torch.bfloat16)
|
||||
return state_dict_
|
||||
|
||||
def __call__(self, lora_configs: list[Union[ModelConfig, str]]):
|
||||
loras = []
|
||||
loader = FluxLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
|
||||
for lora_config in lora_configs:
|
||||
if isinstance(lora_config, str):
|
||||
lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
|
||||
else:
|
||||
lora_config.download_if_necessary()
|
||||
lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
|
||||
lora = loader.convert_state_dict(lora)
|
||||
loras.append(lora)
|
||||
lora = self.LoRA_State_Dicts_Decomposition(loras)
|
||||
return lora
|
||||
|
||||
@@ -2,7 +2,7 @@ import torch
|
||||
from .sd3_dit import TimestepEmbeddings, AdaLayerNorm, RMSNorm
|
||||
from einops import rearrange
|
||||
from .tiler import TileWorker
|
||||
from .utils import init_weights_on_device
|
||||
from .utils import init_weights_on_device, hash_state_dict_keys
|
||||
|
||||
def interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0):
|
||||
batch_size, num_tokens = hidden_states.shape[0:2]
|
||||
@@ -662,6 +662,9 @@ class FluxDiTStateDictConverter:
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
if hash_state_dict_keys(state_dict, with_shape=True) in ["3e6c61b0f9471135fc9c6d6a98e98b6d", "63c969fd37cce769a90aa781fbff5f81"]:
|
||||
dit_state_dict = {key.replace("pipe.dit.", ""): value for key, value in state_dict.items() if key.startswith('pipe.dit.')}
|
||||
return dit_state_dict
|
||||
rename_dict = {
|
||||
"time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias",
|
||||
"time_in.in_layer.weight": "time_embedder.timestep_embedder.0.weight",
|
||||
|
||||
@@ -104,6 +104,7 @@ class InfiniteYouImageProjector(nn.Module):
|
||||
def forward(self, x):
|
||||
|
||||
latents = self.latents.repeat(x.size(0), 1, 1)
|
||||
latents = latents.to(dtype=x.dtype, device=x.device)
|
||||
|
||||
x = self.proj_in(x)
|
||||
|
||||
|
||||
111
diffsynth/models/flux_lora_encoder.py
Normal file
111
diffsynth/models/flux_lora_encoder.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import torch
|
||||
from .sd_text_encoder import CLIPEncoderLayer
|
||||
|
||||
|
||||
class LoRALayerBlock(torch.nn.Module):
|
||||
def __init__(self, L, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.x = torch.nn.Parameter(torch.randn(1, L, dim_in))
|
||||
self.layer_norm = torch.nn.LayerNorm(dim_out)
|
||||
|
||||
def forward(self, lora_A, lora_B):
|
||||
x = self.x @ lora_A.T @ lora_B.T
|
||||
x = self.layer_norm(x)
|
||||
return x
|
||||
|
||||
|
||||
class LoRAEmbedder(torch.nn.Module):
|
||||
def __init__(self, lora_patterns=None, L=1, out_dim=2048):
|
||||
super().__init__()
|
||||
if lora_patterns is None:
|
||||
lora_patterns = self.default_lora_patterns()
|
||||
|
||||
model_dict = {}
|
||||
for lora_pattern in lora_patterns:
|
||||
name, dim = lora_pattern["name"], lora_pattern["dim"]
|
||||
model_dict[name.replace(".", "___")] = LoRALayerBlock(L, dim[0], dim[1])
|
||||
self.model_dict = torch.nn.ModuleDict(model_dict)
|
||||
|
||||
proj_dict = {}
|
||||
for lora_pattern in lora_patterns:
|
||||
layer_type, dim = lora_pattern["type"], lora_pattern["dim"]
|
||||
if layer_type not in proj_dict:
|
||||
proj_dict[layer_type.replace(".", "___")] = torch.nn.Linear(dim[1], out_dim)
|
||||
self.proj_dict = torch.nn.ModuleDict(proj_dict)
|
||||
|
||||
self.lora_patterns = lora_patterns
|
||||
|
||||
|
||||
def default_lora_patterns(self):
|
||||
lora_patterns = []
|
||||
lora_dict = {
|
||||
"attn.a_to_qkv": (3072, 9216), "attn.a_to_out": (3072, 3072), "ff_a.0": (3072, 12288), "ff_a.2": (12288, 3072), "norm1_a.linear": (3072, 18432),
|
||||
"attn.b_to_qkv": (3072, 9216), "attn.b_to_out": (3072, 3072), "ff_b.0": (3072, 12288), "ff_b.2": (12288, 3072), "norm1_b.linear": (3072, 18432),
|
||||
}
|
||||
for i in range(19):
|
||||
for suffix in lora_dict:
|
||||
lora_patterns.append({
|
||||
"name": f"blocks.{i}.{suffix}",
|
||||
"dim": lora_dict[suffix],
|
||||
"type": suffix,
|
||||
})
|
||||
lora_dict = {"to_qkv_mlp": (3072, 21504), "proj_out": (15360, 3072), "norm.linear": (3072, 9216)}
|
||||
for i in range(38):
|
||||
for suffix in lora_dict:
|
||||
lora_patterns.append({
|
||||
"name": f"single_blocks.{i}.{suffix}",
|
||||
"dim": lora_dict[suffix],
|
||||
"type": suffix,
|
||||
})
|
||||
return lora_patterns
|
||||
|
||||
def forward(self, lora):
|
||||
lora_emb = []
|
||||
for lora_pattern in self.lora_patterns:
|
||||
name, layer_type = lora_pattern["name"], lora_pattern["type"]
|
||||
lora_A = lora[name + ".lora_A.default.weight"]
|
||||
lora_B = lora[name + ".lora_B.default.weight"]
|
||||
lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B)
|
||||
lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out)
|
||||
lora_emb.append(lora_out)
|
||||
lora_emb = torch.concat(lora_emb, dim=1)
|
||||
return lora_emb
|
||||
|
||||
|
||||
class FluxLoRAEncoder(torch.nn.Module):
|
||||
def __init__(self, embed_dim=4096, encoder_intermediate_size=8192, num_encoder_layers=1, num_embeds_per_lora=16, num_special_embeds=1):
|
||||
super().__init__()
|
||||
self.num_embeds_per_lora = num_embeds_per_lora
|
||||
# embedder
|
||||
self.embedder = LoRAEmbedder(L=num_embeds_per_lora, out_dim=embed_dim)
|
||||
|
||||
# encoders
|
||||
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=32, head_dim=128) for _ in range(num_encoder_layers)])
|
||||
|
||||
# special embedding
|
||||
self.special_embeds = torch.nn.Parameter(torch.randn(1, num_special_embeds, embed_dim))
|
||||
self.num_special_embeds = num_special_embeds
|
||||
|
||||
# final layer
|
||||
self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
|
||||
self.final_linear = torch.nn.Linear(embed_dim, embed_dim)
|
||||
|
||||
def forward(self, lora):
|
||||
lora_embeds = self.embedder(lora)
|
||||
special_embeds = self.special_embeds.to(dtype=lora_embeds.dtype, device=lora_embeds.device)
|
||||
embeds = torch.concat([special_embeds, lora_embeds], dim=1)
|
||||
for encoder_id, encoder in enumerate(self.encoders):
|
||||
embeds = encoder(embeds)
|
||||
embeds = embeds[:, :self.num_special_embeds]
|
||||
embeds = self.final_layer_norm(embeds)
|
||||
embeds = self.final_linear(embeds)
|
||||
return embeds
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FluxLoRAEncoderStateDictConverter()
|
||||
|
||||
|
||||
class FluxLoRAEncoderStateDictConverter:
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict
|
||||
60
diffsynth/models/flux_value_control.py
Normal file
60
diffsynth/models/flux_value_control.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import torch
|
||||
from diffsynth.models.svd_unet import TemporalTimesteps
|
||||
|
||||
|
||||
class MultiValueEncoder(torch.nn.Module):
|
||||
def __init__(self, encoders=()):
|
||||
super().__init__()
|
||||
self.encoders = torch.nn.ModuleList(encoders)
|
||||
|
||||
def __call__(self, values, dtype):
|
||||
emb = []
|
||||
for encoder, value in zip(self.encoders, values):
|
||||
if value is not None:
|
||||
value = value.unsqueeze(0)
|
||||
emb.append(encoder(value, dtype))
|
||||
emb = torch.concat(emb, dim=0)
|
||||
return emb
|
||||
|
||||
|
||||
class SingleValueEncoder(torch.nn.Module):
|
||||
def __init__(self, dim_in=256, dim_out=4096, prefer_len=32, computation_device=None):
|
||||
super().__init__()
|
||||
self.prefer_len = prefer_len
|
||||
self.prefer_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device)
|
||||
self.prefer_value_embedder = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
|
||||
)
|
||||
self.positional_embedding = torch.nn.Parameter(
|
||||
torch.randn(self.prefer_len, dim_out)
|
||||
)
|
||||
self._initialize_weights()
|
||||
|
||||
def _initialize_weights(self):
|
||||
last_linear = self.prefer_value_embedder[-1]
|
||||
torch.nn.init.zeros_(last_linear.weight)
|
||||
torch.nn.init.zeros_(last_linear.bias)
|
||||
|
||||
def forward(self, value, dtype):
|
||||
value = value * 1000
|
||||
emb = self.prefer_proj(value).to(dtype)
|
||||
emb = self.prefer_value_embedder(emb).squeeze(0)
|
||||
base_embeddings = emb.expand(self.prefer_len, -1)
|
||||
positional_embedding = self.positional_embedding.to(dtype=base_embeddings.dtype, device=base_embeddings.device)
|
||||
learned_embeddings = base_embeddings + positional_embedding
|
||||
return learned_embeddings
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SingleValueEncoderStateDictConverter()
|
||||
|
||||
|
||||
class SingleValueEncoderStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict
|
||||
@@ -277,7 +277,7 @@ class FluxLoRAConverter:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def align_to_opensource_format(state_dict, alpha=1.0):
|
||||
def align_to_opensource_format(state_dict, alpha=None):
|
||||
prefix_rename_dict = {
|
||||
"single_blocks": "lora_unet_single_blocks",
|
||||
"blocks": "lora_unet_double_blocks",
|
||||
@@ -316,7 +316,8 @@ class FluxLoRAConverter:
|
||||
rename = prefix_rename_dict[prefix] + "_" + block_id + "_" + middle_rename_dict[middle] + "." + suffix_rename_dict[suffix]
|
||||
state_dict_[rename] = param
|
||||
if rename.endswith("lora_up.weight"):
|
||||
state_dict_[rename.replace("lora_up.weight", "alpha")] = torch.tensor((alpha,))[0]
|
||||
lora_alpha = alpha if alpha is not None else param.shape[-1]
|
||||
state_dict_[rename.replace("lora_up.weight", "alpha")] = torch.tensor((lora_alpha,))[0]
|
||||
return state_dict_
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -426,7 +426,7 @@ class ModelManager:
|
||||
self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)
|
||||
|
||||
|
||||
def fetch_model(self, model_name, file_path=None, require_model_path=False):
|
||||
def fetch_model(self, model_name, file_path=None, require_model_path=False, index=None):
|
||||
fetched_models = []
|
||||
fetched_model_paths = []
|
||||
for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
|
||||
@@ -440,12 +440,25 @@ class ModelManager:
|
||||
return None
|
||||
if len(fetched_models) == 1:
|
||||
print(f"Using {model_name} from {fetched_model_paths[0]}.")
|
||||
model = fetched_models[0]
|
||||
path = fetched_model_paths[0]
|
||||
else:
|
||||
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
|
||||
if index is None:
|
||||
model = fetched_models[0]
|
||||
path = fetched_model_paths[0]
|
||||
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
|
||||
elif isinstance(index, int):
|
||||
model = fetched_models[:index]
|
||||
path = fetched_model_paths[:index]
|
||||
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[:index]}.")
|
||||
else:
|
||||
model = fetched_models
|
||||
path = fetched_model_paths
|
||||
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths}.")
|
||||
if require_model_path:
|
||||
return fetched_models[0], fetched_model_paths[0]
|
||||
return model, path
|
||||
else:
|
||||
return fetched_models[0]
|
||||
return model
|
||||
|
||||
|
||||
def to(self, device):
|
||||
|
||||
161
diffsynth/models/nexus_gen.py
Normal file
161
diffsynth/models/nexus_gen.py
Normal file
@@ -0,0 +1,161 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class NexusGenAutoregressiveModel(torch.nn.Module):
|
||||
def __init__(self, max_length=1024, max_pixels=262640):
|
||||
super(NexusGenAutoregressiveModel, self).__init__()
|
||||
from .nexus_gen_ar_model import Qwen2_5_VLForConditionalGeneration
|
||||
from transformers import Qwen2_5_VLConfig
|
||||
self.max_length = max_length
|
||||
self.max_pixels = max_pixels
|
||||
model_config = Qwen2_5_VLConfig(**{
|
||||
"_name_or_path": "DiffSynth-Studio/Nexus-GenV2",
|
||||
"architectures": [
|
||||
"Qwen2_5_VLForConditionalGeneration"
|
||||
],
|
||||
"attention_dropout": 0.0,
|
||||
"auto_map": {
|
||||
"AutoConfig": "configuration_qwen2_5_vl.Qwen2_5_VLConfig",
|
||||
"AutoModel": "modeling_qwen2_5_vl.Qwen2_5_VLModel",
|
||||
"AutoModelForCausalLM": "modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration"
|
||||
},
|
||||
"bos_token_id": 151643,
|
||||
"eos_token_id": 151645,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 3584,
|
||||
"image_token_id": 151655,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 18944,
|
||||
"max_position_embeddings": 128000,
|
||||
"max_window_layers": 28,
|
||||
"model_type": "qwen2_5_vl",
|
||||
"num_attention_heads": 28,
|
||||
"num_hidden_layers": 28,
|
||||
"num_key_value_heads": 4,
|
||||
"pad_token_id": 151643,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_scaling": {
|
||||
"mrope_section": [
|
||||
16,
|
||||
24,
|
||||
24
|
||||
],
|
||||
"rope_type": "default",
|
||||
"type": "default"
|
||||
},
|
||||
"rope_theta": 1000000.0,
|
||||
"sliding_window": 32768,
|
||||
"tie_word_embeddings": False,
|
||||
"torch_dtype": "bfloat16",
|
||||
"transformers_version": "4.49.0",
|
||||
"use_cache": False,
|
||||
"use_sliding_window": False,
|
||||
"video_token_id": 151656,
|
||||
"vision_config": {
|
||||
"hidden_size": 1280,
|
||||
"in_chans": 3,
|
||||
"model_type": "qwen2_5_vl",
|
||||
"spatial_patch_size": 14,
|
||||
"tokens_per_second": 2,
|
||||
"torch_dtype": "bfloat16"
|
||||
},
|
||||
"vision_end_token_id": 151653,
|
||||
"vision_start_token_id": 151652,
|
||||
"vision_token_id": 151654,
|
||||
"vocab_size": 152064
|
||||
})
|
||||
self.model = Qwen2_5_VLForConditionalGeneration(model_config)
|
||||
self.processor = None
|
||||
|
||||
|
||||
def load_processor(self, path):
|
||||
from .nexus_gen_ar_model import Qwen2_5_VLProcessor
|
||||
self.processor = Qwen2_5_VLProcessor.from_pretrained(path)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return NexusGenAutoregressiveModelStateDictConverter()
|
||||
|
||||
def bound_image(self, image, max_pixels=262640):
|
||||
from qwen_vl_utils import smart_resize
|
||||
resized_height, resized_width = smart_resize(
|
||||
image.height,
|
||||
image.width,
|
||||
max_pixels=max_pixels,
|
||||
)
|
||||
return image.resize((resized_width, resized_height))
|
||||
|
||||
def get_editing_msg(self, instruction):
|
||||
if '<image>' not in instruction:
|
||||
instruction = '<image> ' + instruction
|
||||
messages = [{"role":"user", "content":instruction}, {"role":"assistant", "content":"Here is the image: <image>"}]
|
||||
return messages
|
||||
|
||||
def get_generation_msg(self, instruction):
|
||||
instruction = "Generate an image according to the following description: {}".format(instruction)
|
||||
messages = [{"role":"user", "content":instruction}, {"role":"assistant", "content":"Here is an image based on the description: <image>"}]
|
||||
return messages
|
||||
|
||||
def forward(self, instruction, ref_image=None, num_img_tokens=81):
|
||||
"""
|
||||
Generate target embeddings for the given instruction and reference image.
|
||||
"""
|
||||
if ref_image is not None:
|
||||
messages = self.get_editing_msg(instruction)
|
||||
images = [self.bound_image(ref_image)] + [Image.new(mode='RGB', size=(252, 252), color=(255, 255, 255))]
|
||||
output_image_embeddings = self.get_target_embeddings(images, messages, self.processor, self.model, num_img_tokens)
|
||||
else:
|
||||
messages = self.get_generation_msg(instruction)
|
||||
images = [Image.new(mode='RGB', size=(252, 252), color=(255, 255, 255))]
|
||||
output_image_embeddings = self.get_target_embeddings(images, messages, self.processor, self.model, num_img_tokens)
|
||||
|
||||
return output_image_embeddings
|
||||
|
||||
def get_target_embeddings(self, images, messages, processor, model, num_img_tokens=81):
|
||||
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
|
||||
text = text.replace('<image>', '<|vision_start|><|image_pad|><|vision_end|>')
|
||||
inputs = processor(
|
||||
text=[text],
|
||||
images=images,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = inputs.to(model.device)
|
||||
|
||||
input_embeds = model.model.embed_tokens(inputs['input_ids'])
|
||||
image_embeds = model.visual(inputs['pixel_values'], grid_thw=inputs['image_grid_thw'])
|
||||
ground_truth_image_embeds = image_embeds[-num_img_tokens:]
|
||||
input_image_embeds = image_embeds[:-num_img_tokens]
|
||||
|
||||
image_mask = inputs['input_ids'] == model.config.image_token_id
|
||||
indices = image_mask.cumsum(dim=1)
|
||||
input_image_mask = torch.logical_and(indices <= (image_embeds.shape[0] - ground_truth_image_embeds.shape[0]), image_mask)
|
||||
gt_image_mask = torch.logical_and(image_mask, ~input_image_mask)
|
||||
input_image_mask = input_image_mask.unsqueeze(-1).expand_as(input_embeds)
|
||||
input_embeds = input_embeds.masked_scatter(input_image_mask, input_image_embeds)
|
||||
|
||||
image_prefill_embeds = model.image_prefill_embeds(
|
||||
torch.arange(81, device=model.device).long()
|
||||
)
|
||||
input_embeds = input_embeds.masked_scatter(gt_image_mask.unsqueeze(-1).expand_as(input_embeds), image_prefill_embeds)
|
||||
|
||||
position_ids, _ = model.get_rope_index(
|
||||
inputs['input_ids'],
|
||||
inputs['image_grid_thw'],
|
||||
attention_mask=inputs['attention_mask'])
|
||||
position_ids = position_ids.contiguous()
|
||||
outputs = model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=inputs['attention_mask'], return_dict=True)
|
||||
output_image_embeddings = outputs.image_embeddings[:, :-1, :]
|
||||
output_image_embeddings = output_image_embeddings[gt_image_mask[:, 1:]]
|
||||
return output_image_embeddings, input_image_embeds, inputs['image_grid_thw']
|
||||
|
||||
|
||||
class NexusGenAutoregressiveModelStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
state_dict = {"model." + key: value for key, value in state_dict.items()}
|
||||
return state_dict
|
||||
1143
diffsynth/models/nexus_gen_ar_model.py
Normal file
1143
diffsynth/models/nexus_gen_ar_model.py
Normal file
File diff suppressed because it is too large
Load Diff
417
diffsynth/models/nexus_gen_projector.py
Normal file
417
diffsynth/models/nexus_gen_projector.py
Normal file
@@ -0,0 +1,417 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
|
||||
mrope_section = mrope_section * 2
|
||||
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
|
||||
unsqueeze_dim
|
||||
)
|
||||
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
|
||||
unsqueeze_dim
|
||||
)
|
||||
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
class Qwen2_5_VLRotaryEmbedding(nn.Module):
|
||||
def __init__(self, config, device=None):
|
||||
super().__init__()
|
||||
# BC: "rope_type" was originally "type"
|
||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.config = config
|
||||
from transformers.modeling_rope_utils import _compute_default_rope_parameters
|
||||
self.rope_init_fn = _compute_default_rope_parameters
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
|
||||
def _dynamic_frequency_update(self, position_ids, device):
|
||||
"""
|
||||
dynamic RoPE layers should recompute `inv_freq` in the following situations:
|
||||
1 - growing beyond the cached sequence length (allow scaling)
|
||||
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
|
||||
"""
|
||||
seq_len = torch.max(position_ids) + 1
|
||||
if seq_len > self.max_seq_len_cached: # growth
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
||||
)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||
self.max_seq_len_cached = seq_len
|
||||
|
||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||
self.max_seq_len_cached = self.original_max_seq_len
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x, position_ids):
|
||||
if "dynamic" in self.rope_type:
|
||||
self._dynamic_frequency_update(position_ids, device=x.device)
|
||||
|
||||
# Core RoPE block. In contrast to other models, Qwen2_5_VL has different position ids for the grids
|
||||
# So we expand the inv_freq to shape (3, ...)
|
||||
inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
|
||||
position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
|
||||
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
||||
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
|
||||
cos = cos * self.attention_scaling
|
||||
sin = sin * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
class Qwen2_5_VLAttention(nn.Module):
|
||||
def __init__(self, config, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
self.is_causal = True
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.rope_scaling = config.rope_scaling
|
||||
|
||||
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||
raise ValueError(
|
||||
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
||||
f" and `num_heads`: {self.num_heads})."
|
||||
)
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_multimodal_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
|
||||
)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
# Fix precision issues in Qwen2-VL float16 inference
|
||||
# Replace inf values with zeros in attention weights to prevent NaN propagation
|
||||
if query_states.dtype == torch.float16:
|
||||
attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights)
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
class Qwen2MLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
from transformers.activations import ACT2FN
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, x):
|
||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
|
||||
|
||||
class Qwen2RMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
Qwen2RMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||
|
||||
|
||||
class Qwen2_5_VLDecoderLayer(nn.Module):
|
||||
def __init__(self, config, layer_idx):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = Qwen2_5_VLAttention(config, layer_idx)
|
||||
|
||||
self.mlp = Qwen2MLP(config)
|
||||
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class NexusGenImageEmbeddingMerger(nn.Module):
|
||||
def __init__(self, num_layers=1, out_channel=4096, expand_ratio=4, device='cpu'):
|
||||
super().__init__()
|
||||
from transformers import Qwen2_5_VLConfig
|
||||
from transformers.activations import ACT2FN
|
||||
config = Qwen2_5_VLConfig(**{
|
||||
"_name_or_path": "DiffSynth-Studio/Nexus-GenV2",
|
||||
"architectures": [
|
||||
"Qwen2_5_VLForConditionalGeneration"
|
||||
],
|
||||
"attention_dropout": 0.0,
|
||||
"auto_map": {
|
||||
"AutoConfig": "configuration_qwen2_5_vl.Qwen2_5_VLConfig",
|
||||
"AutoModel": "modeling_qwen2_5_vl.Qwen2_5_VLModel",
|
||||
"AutoModelForCausalLM": "modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration"
|
||||
},
|
||||
"bos_token_id": 151643,
|
||||
"eos_token_id": 151645,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 3584,
|
||||
"image_token_id": 151655,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 18944,
|
||||
"max_position_embeddings": 128000,
|
||||
"max_window_layers": 28,
|
||||
"model_type": "qwen2_5_vl",
|
||||
"num_attention_heads": 28,
|
||||
"num_hidden_layers": 28,
|
||||
"num_key_value_heads": 4,
|
||||
"pad_token_id": 151643,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_scaling": {
|
||||
"mrope_section": [
|
||||
16,
|
||||
24,
|
||||
24
|
||||
],
|
||||
"rope_type": "default",
|
||||
"type": "default"
|
||||
},
|
||||
"rope_theta": 1000000.0,
|
||||
"sliding_window": 32768,
|
||||
"tie_word_embeddings": False,
|
||||
"torch_dtype": "bfloat16",
|
||||
"transformers_version": "4.49.0",
|
||||
"use_cache": False,
|
||||
"use_sliding_window": False,
|
||||
"video_token_id": 151656,
|
||||
"vision_config": {
|
||||
"hidden_size": 1280,
|
||||
"in_chans": 3,
|
||||
"model_type": "qwen2_5_vl",
|
||||
"spatial_patch_size": 14,
|
||||
"tokens_per_second": 2,
|
||||
"torch_dtype": "bfloat16"
|
||||
},
|
||||
"vision_end_token_id": 151653,
|
||||
"vision_start_token_id": 151652,
|
||||
"vision_token_id": 151654,
|
||||
"vocab_size": 152064
|
||||
})
|
||||
self.config = config
|
||||
self.num_layers = num_layers
|
||||
self.layers = nn.ModuleList([Qwen2_5_VLDecoderLayer(config, layer_idx) for layer_idx in range(num_layers)])
|
||||
self.projector = nn.Sequential(Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps),
|
||||
nn.Linear(config.hidden_size, out_channel * expand_ratio),
|
||||
Qwen2RMSNorm(out_channel * expand_ratio, eps=config.rms_norm_eps),
|
||||
ACT2FN[config.hidden_act], nn.Linear(out_channel * expand_ratio, out_channel),
|
||||
Qwen2RMSNorm(out_channel, eps=config.rms_norm_eps))
|
||||
self.base_grid = torch.tensor([[1, 72, 72]], device=device)
|
||||
self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config, device=device)
|
||||
|
||||
def get_position_ids(self, image_grid_thw):
|
||||
"""
|
||||
Generates position ids for the input embeddings grid.
|
||||
modified from the qwen2_vl mrope.
|
||||
"""
|
||||
batch_size = image_grid_thw.shape[0]
|
||||
spatial_merge_size = self.config.vision_config.spatial_merge_size
|
||||
t, h, w = (
|
||||
image_grid_thw[0][0],
|
||||
image_grid_thw[0][1],
|
||||
image_grid_thw[0][2],
|
||||
)
|
||||
llm_grid_t, llm_grid_h, llm_grid_w = (
|
||||
t.item(),
|
||||
h.item() // spatial_merge_size,
|
||||
w.item() // spatial_merge_size,
|
||||
)
|
||||
scale_h = self.base_grid[0][1].item() / h.item()
|
||||
scale_w = self.base_grid[0][2].item() / w.item()
|
||||
|
||||
range_tensor = torch.arange(llm_grid_t).view(-1, 1)
|
||||
expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)
|
||||
time_tensor = expanded_range * self.config.vision_config.tokens_per_second
|
||||
t_index = time_tensor.long().flatten().to(image_grid_thw.device)
|
||||
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten().to(image_grid_thw.device) * scale_h
|
||||
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten().to(image_grid_thw.device) * scale_w
|
||||
# 3, B, L
|
||||
position_ids = torch.stack([t_index, h_index, w_index]).unsqueeze(0).repeat(batch_size, 1, 1).permute(1, 0, 2)
|
||||
return position_ids
|
||||
|
||||
def forward(self, embeds, embeds_grid, ref_embeds=None, ref_embeds_grid=None):
|
||||
position_ids = self.get_position_ids(embeds_grid)
|
||||
hidden_states = embeds
|
||||
if ref_embeds is not None:
|
||||
position_ids_ref_embeds = self.get_position_ids(ref_embeds_grid)
|
||||
position_ids = torch.cat((position_ids, position_ids_ref_embeds), dim=-1)
|
||||
hidden_states = torch.cat((embeds, ref_embeds), dim=1)
|
||||
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
for layer in self.layers:
|
||||
hidden_states = layer(hidden_states, position_embeddings)
|
||||
|
||||
hidden_states = self.projector(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return NexusGenMergerStateDictConverter()
|
||||
|
||||
|
||||
class NexusGenMergerStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
merger_state_dict = {key.replace("embedding_merger.", ""): value for key, value in state_dict.items() if key.startswith('embedding_merger.')}
|
||||
return merger_state_dict
|
||||
|
||||
|
||||
class NexusGenAdapter(nn.Module):
|
||||
"""
|
||||
Adapter for Nexus-Gen generation decoder.
|
||||
"""
|
||||
def __init__(self, input_dim=3584, output_dim=4096):
|
||||
super(NexusGenAdapter, self).__init__()
|
||||
self.adapter = nn.Sequential(nn.Linear(input_dim, output_dim),
|
||||
nn.LayerNorm(output_dim), nn.ReLU(),
|
||||
nn.Linear(output_dim, output_dim),
|
||||
nn.LayerNorm(output_dim))
|
||||
|
||||
def forward(self, x):
|
||||
return self.adapter(x)
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return NexusGenAdapterStateDictConverter()
|
||||
|
||||
|
||||
class NexusGenAdapterStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
adapter_state_dict = {key: value for key, value in state_dict.items() if key.startswith('adapter.')}
|
||||
return adapter_state_dict
|
||||
@@ -162,7 +162,7 @@ class TimestepEmbedder(nn.Module):
|
||||
def forward(self, t):
|
||||
t_freq = self.timestep_embedding(
|
||||
t, self.frequency_embedding_size, self.max_period
|
||||
).type(self.mlp[0].weight.dtype) # type: ignore
|
||||
).type(t.dtype) # type: ignore
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
@@ -656,7 +656,7 @@ class Qwen2Connector(torch.nn.Module):
|
||||
mask_float = mask.unsqueeze(-1) # [b, s1, 1]
|
||||
x_mean = (x * mask_float).sum(
|
||||
dim=1
|
||||
) / mask_float.sum(dim=1) * (1 + self.scale_factor)
|
||||
) / mask_float.sum(dim=1) * (1 + self.scale_factor.to(dtype=x.dtype, device=x.device))
|
||||
|
||||
global_out=self.global_proj_out(x_mean)
|
||||
encoder_hidden_states = self.S(x,t,mask)
|
||||
|
||||
@@ -71,7 +71,7 @@ def load_state_dict(file_path, torch_dtype=None, device="cpu"):
|
||||
|
||||
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
|
||||
state_dict = {}
|
||||
with safe_open(file_path, framework="pt", device=device) as f:
|
||||
with safe_open(file_path, framework="pt", device=str(device)) as f:
|
||||
for k in f.keys():
|
||||
state_dict[k] = f.get_tensor(k)
|
||||
if torch_dtype is not None:
|
||||
|
||||
@@ -212,9 +212,16 @@ class DiTBlock(nn.Module):
|
||||
self.gate = GateModule()
|
||||
|
||||
def forward(self, x, context, t_mod, freqs):
|
||||
has_seq = len(t_mod.shape) == 4
|
||||
chunk_dim = 2 if has_seq else 1
|
||||
# msa: multi-head self-attention mlp: multi-layer perceptron
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
|
||||
self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=chunk_dim)
|
||||
if has_seq:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
shift_msa.squeeze(2), scale_msa.squeeze(2), gate_msa.squeeze(2),
|
||||
shift_mlp.squeeze(2), scale_mlp.squeeze(2), gate_mlp.squeeze(2),
|
||||
)
|
||||
input_x = modulate(self.norm1(x), shift_msa, scale_msa)
|
||||
x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
|
||||
x = x + self.cross_attn(self.norm3(x), context)
|
||||
@@ -253,8 +260,12 @@ class Head(nn.Module):
|
||||
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
|
||||
|
||||
def forward(self, x, t_mod):
|
||||
shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1)
|
||||
x = (self.head(self.norm(x) * (1 + scale) + shift))
|
||||
if len(t_mod.shape) == 3:
|
||||
shift, scale = (self.modulation.unsqueeze(0).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(2)).chunk(2, dim=2)
|
||||
x = (self.head(self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2)))
|
||||
else:
|
||||
shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1)
|
||||
x = (self.head(self.norm(x) * (1 + scale) + shift))
|
||||
return x
|
||||
|
||||
|
||||
@@ -276,12 +287,20 @@ class WanModel(torch.nn.Module):
|
||||
has_ref_conv: bool = False,
|
||||
add_control_adapter: bool = False,
|
||||
in_dim_control_adapter: int = 24,
|
||||
seperated_timestep: bool = False,
|
||||
require_vae_embedding: bool = True,
|
||||
require_clip_embedding: bool = True,
|
||||
fuse_vae_embedding_in_latents: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.freq_dim = freq_dim
|
||||
self.has_image_input = has_image_input
|
||||
self.patch_size = patch_size
|
||||
self.seperated_timestep = seperated_timestep
|
||||
self.require_vae_embedding = require_vae_embedding
|
||||
self.require_clip_embedding = require_clip_embedding
|
||||
self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
|
||||
|
||||
self.patch_embedding = nn.Conv3d(
|
||||
in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
||||
@@ -659,6 +678,41 @@ class WanModelStateDictConverter:
|
||||
"add_control_adapter": True,
|
||||
"in_dim_control_adapter": 24,
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "1f5ab7703c6fc803fdded85ff040c316":
|
||||
# Wan-AI/Wan2.2-TI2V-5B
|
||||
config = {
|
||||
"has_image_input": False,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 48,
|
||||
"dim": 3072,
|
||||
"ffn_dim": 14336,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 48,
|
||||
"num_heads": 24,
|
||||
"num_layers": 30,
|
||||
"eps": 1e-6,
|
||||
"seperated_timestep": True,
|
||||
"require_clip_embedding": False,
|
||||
"require_vae_embedding": False,
|
||||
"fuse_vae_embedding_in_latents": True,
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "5b013604280dd715f8457c6ed6d6a626":
|
||||
# Wan-AI/Wan2.2-I2V-A14B
|
||||
config = {
|
||||
"has_image_input": False,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 36,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6,
|
||||
"require_clip_embedding": False,
|
||||
}
|
||||
else:
|
||||
config = {}
|
||||
return state_dict, config
|
||||
|
||||
@@ -195,6 +195,75 @@ class Resample(nn.Module):
|
||||
nn.init.zeros_(conv.bias.data)
|
||||
|
||||
|
||||
|
||||
def patchify(x, patch_size):
|
||||
if patch_size == 1:
|
||||
return x
|
||||
if x.dim() == 4:
|
||||
x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
|
||||
elif x.dim() == 5:
|
||||
x = rearrange(x,
|
||||
"b c f (h q) (w r) -> b (c r q) f h w",
|
||||
q=patch_size,
|
||||
r=patch_size)
|
||||
else:
|
||||
raise ValueError(f"Invalid input shape: {x.shape}")
|
||||
return x
|
||||
|
||||
|
||||
def unpatchify(x, patch_size):
|
||||
if patch_size == 1:
|
||||
return x
|
||||
if x.dim() == 4:
|
||||
x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
|
||||
elif x.dim() == 5:
|
||||
x = rearrange(x,
|
||||
"b (c r q) f h w -> b c f (h q) (w r)",
|
||||
q=patch_size,
|
||||
r=patch_size)
|
||||
return x
|
||||
|
||||
|
||||
class Resample38(Resample):
|
||||
|
||||
def __init__(self, dim, mode):
|
||||
assert mode in (
|
||||
"none",
|
||||
"upsample2d",
|
||||
"upsample3d",
|
||||
"downsample2d",
|
||||
"downsample3d",
|
||||
)
|
||||
super(Resample, self).__init__()
|
||||
self.dim = dim
|
||||
self.mode = mode
|
||||
|
||||
# layers
|
||||
if mode == "upsample2d":
|
||||
self.resample = nn.Sequential(
|
||||
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
|
||||
nn.Conv2d(dim, dim, 3, padding=1),
|
||||
)
|
||||
elif mode == "upsample3d":
|
||||
self.resample = nn.Sequential(
|
||||
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
|
||||
nn.Conv2d(dim, dim, 3, padding=1),
|
||||
)
|
||||
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
||||
elif mode == "downsample2d":
|
||||
self.resample = nn.Sequential(
|
||||
nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
|
||||
)
|
||||
elif mode == "downsample3d":
|
||||
self.resample = nn.Sequential(
|
||||
nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
|
||||
)
|
||||
self.time_conv = CausalConv3d(
|
||||
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
|
||||
)
|
||||
else:
|
||||
self.resample = nn.Identity()
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
|
||||
def __init__(self, in_dim, out_dim, dropout=0.0):
|
||||
@@ -273,6 +342,178 @@ class AttentionBlock(nn.Module):
|
||||
return x + identity
|
||||
|
||||
|
||||
class AvgDown3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
factor_t,
|
||||
factor_s=1,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.factor_t = factor_t
|
||||
self.factor_s = factor_s
|
||||
self.factor = self.factor_t * self.factor_s * self.factor_s
|
||||
|
||||
assert in_channels * self.factor % out_channels == 0
|
||||
self.group_size = in_channels * self.factor // out_channels
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
|
||||
pad = (0, 0, 0, 0, pad_t, 0)
|
||||
x = F.pad(x, pad)
|
||||
B, C, T, H, W = x.shape
|
||||
x = x.view(
|
||||
B,
|
||||
C,
|
||||
T // self.factor_t,
|
||||
self.factor_t,
|
||||
H // self.factor_s,
|
||||
self.factor_s,
|
||||
W // self.factor_s,
|
||||
self.factor_s,
|
||||
)
|
||||
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
|
||||
x = x.view(
|
||||
B,
|
||||
C * self.factor,
|
||||
T // self.factor_t,
|
||||
H // self.factor_s,
|
||||
W // self.factor_s,
|
||||
)
|
||||
x = x.view(
|
||||
B,
|
||||
self.out_channels,
|
||||
self.group_size,
|
||||
T // self.factor_t,
|
||||
H // self.factor_s,
|
||||
W // self.factor_s,
|
||||
)
|
||||
x = x.mean(dim=2)
|
||||
return x
|
||||
|
||||
|
||||
class DupUp3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
factor_t,
|
||||
factor_s=1,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.factor_t = factor_t
|
||||
self.factor_s = factor_s
|
||||
self.factor = self.factor_t * self.factor_s * self.factor_s
|
||||
|
||||
assert out_channels * self.factor % in_channels == 0
|
||||
self.repeats = out_channels * self.factor // in_channels
|
||||
|
||||
def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
|
||||
x = x.repeat_interleave(self.repeats, dim=1)
|
||||
x = x.view(
|
||||
x.size(0),
|
||||
self.out_channels,
|
||||
self.factor_t,
|
||||
self.factor_s,
|
||||
self.factor_s,
|
||||
x.size(2),
|
||||
x.size(3),
|
||||
x.size(4),
|
||||
)
|
||||
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
|
||||
x = x.view(
|
||||
x.size(0),
|
||||
self.out_channels,
|
||||
x.size(2) * self.factor_t,
|
||||
x.size(4) * self.factor_s,
|
||||
x.size(6) * self.factor_s,
|
||||
)
|
||||
if first_chunk:
|
||||
x = x[:, :, self.factor_t - 1 :, :, :]
|
||||
return x
|
||||
|
||||
|
||||
class Down_ResidualBlock(nn.Module):
|
||||
def __init__(
|
||||
self, in_dim, out_dim, dropout, mult, temperal_downsample=False, down_flag=False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Shortcut path with downsample
|
||||
self.avg_shortcut = AvgDown3D(
|
||||
in_dim,
|
||||
out_dim,
|
||||
factor_t=2 if temperal_downsample else 1,
|
||||
factor_s=2 if down_flag else 1,
|
||||
)
|
||||
|
||||
# Main path with residual blocks and downsample
|
||||
downsamples = []
|
||||
for _ in range(mult):
|
||||
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
||||
in_dim = out_dim
|
||||
|
||||
# Add the final downsample block
|
||||
if down_flag:
|
||||
mode = "downsample3d" if temperal_downsample else "downsample2d"
|
||||
downsamples.append(Resample38(out_dim, mode=mode))
|
||||
|
||||
self.downsamples = nn.Sequential(*downsamples)
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
x_copy = x.clone()
|
||||
for module in self.downsamples:
|
||||
x = module(x, feat_cache, feat_idx)
|
||||
|
||||
return x + self.avg_shortcut(x_copy)
|
||||
|
||||
|
||||
class Up_ResidualBlock(nn.Module):
|
||||
def __init__(
|
||||
self, in_dim, out_dim, dropout, mult, temperal_upsample=False, up_flag=False
|
||||
):
|
||||
super().__init__()
|
||||
# Shortcut path with upsample
|
||||
if up_flag:
|
||||
self.avg_shortcut = DupUp3D(
|
||||
in_dim,
|
||||
out_dim,
|
||||
factor_t=2 if temperal_upsample else 1,
|
||||
factor_s=2 if up_flag else 1,
|
||||
)
|
||||
else:
|
||||
self.avg_shortcut = None
|
||||
|
||||
# Main path with residual blocks and upsample
|
||||
upsamples = []
|
||||
for _ in range(mult):
|
||||
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
||||
in_dim = out_dim
|
||||
|
||||
# Add the final upsample block
|
||||
if up_flag:
|
||||
mode = "upsample3d" if temperal_upsample else "upsample2d"
|
||||
upsamples.append(Resample38(out_dim, mode=mode))
|
||||
|
||||
self.upsamples = nn.Sequential(*upsamples)
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
|
||||
x_main = x.clone()
|
||||
for module in self.upsamples:
|
||||
x_main = module(x_main, feat_cache, feat_idx)
|
||||
if self.avg_shortcut is not None:
|
||||
x_shortcut = self.avg_shortcut(x, first_chunk)
|
||||
return x_main + x_shortcut
|
||||
else:
|
||||
return x_main
|
||||
|
||||
|
||||
class Encoder3d(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
@@ -376,6 +617,122 @@ class Encoder3d(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class Encoder3d_38(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[False, True, True],
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_downsample = temperal_downsample
|
||||
|
||||
# dimensions
|
||||
dims = [dim * u for u in [1] + dim_mult]
|
||||
scale = 1.0
|
||||
|
||||
# init block
|
||||
self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)
|
||||
|
||||
# downsample blocks
|
||||
downsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
t_down_flag = (
|
||||
temperal_downsample[i] if i < len(temperal_downsample) else False
|
||||
)
|
||||
downsamples.append(
|
||||
Down_ResidualBlock(
|
||||
in_dim=in_dim,
|
||||
out_dim=out_dim,
|
||||
dropout=dropout,
|
||||
mult=num_res_blocks,
|
||||
temperal_downsample=t_down_flag,
|
||||
down_flag=i != len(dim_mult) - 1,
|
||||
)
|
||||
)
|
||||
scale /= 2.0
|
||||
self.downsamples = nn.Sequential(*downsamples)
|
||||
|
||||
# middle blocks
|
||||
self.middle = nn.Sequential(
|
||||
ResidualBlock(out_dim, out_dim, dropout),
|
||||
AttentionBlock(out_dim),
|
||||
ResidualBlock(out_dim, out_dim, dropout),
|
||||
)
|
||||
|
||||
# # output blocks
|
||||
self.head = nn.Sequential(
|
||||
RMS_norm(out_dim, images=False),
|
||||
nn.SiLU(),
|
||||
CausalConv3d(out_dim, z_dim, 3, padding=1),
|
||||
)
|
||||
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = torch.cat(
|
||||
[
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
|
||||
cache_x,
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
|
||||
## downsamples
|
||||
for layer in self.downsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## head
|
||||
for layer in self.head:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = torch.cat(
|
||||
[
|
||||
feat_cache[idx][:, :, -1, :, :]
|
||||
.unsqueeze(2)
|
||||
.to(cache_x.device),
|
||||
cache_x,
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Decoder3d(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
@@ -481,10 +838,112 @@ class Decoder3d(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class Decoder3d_38(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_upsample=[False, True, True],
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_upsample = temperal_upsample
|
||||
|
||||
# dimensions
|
||||
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
||||
scale = 1.0 / 2 ** (len(dim_mult) - 2)
|
||||
# init block
|
||||
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
||||
|
||||
# middle blocks
|
||||
self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout),
|
||||
AttentionBlock(dims[0]),
|
||||
ResidualBlock(dims[0], dims[0], dropout))
|
||||
|
||||
# upsample blocks
|
||||
upsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
t_up_flag = temperal_upsample[i] if i < len(temperal_upsample) else False
|
||||
upsamples.append(
|
||||
Up_ResidualBlock(in_dim=in_dim,
|
||||
out_dim=out_dim,
|
||||
dropout=dropout,
|
||||
mult=num_res_blocks + 1,
|
||||
temperal_upsample=t_up_flag,
|
||||
up_flag=i != len(dim_mult) - 1))
|
||||
self.upsamples = nn.Sequential(*upsamples)
|
||||
|
||||
# output blocks
|
||||
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(out_dim, 12, 3, padding=1))
|
||||
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = torch.cat(
|
||||
[
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
|
||||
cache_x,
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
|
||||
for layer in self.middle:
|
||||
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## upsamples
|
||||
for layer in self.upsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx, first_chunk)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## head
|
||||
for layer in self.head:
|
||||
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = torch.cat(
|
||||
[
|
||||
feat_cache[idx][:, :, -1, :, :]
|
||||
.unsqueeze(2)
|
||||
.to(cache_x.device),
|
||||
cache_x,
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
def count_conv3d(model):
|
||||
count = 0
|
||||
for m in model.modules():
|
||||
if check_is_instance(m, CausalConv3d):
|
||||
if isinstance(m, CausalConv3d):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
@@ -616,6 +1075,7 @@ class WanVideoVAE(nn.Module):
|
||||
# init model
|
||||
self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False)
|
||||
self.upsampling_factor = 8
|
||||
self.z_dim = z_dim
|
||||
|
||||
|
||||
def build_1d_mask(self, length, left_bound, right_bound, border_width):
|
||||
@@ -711,7 +1171,7 @@ class WanVideoVAE(nn.Module):
|
||||
|
||||
out_T = (T + 3) // 4
|
||||
weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
|
||||
values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
|
||||
values = torch.zeros((1, self.z_dim, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
|
||||
|
||||
for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
|
||||
hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
|
||||
@@ -762,8 +1222,8 @@ class WanVideoVAE(nn.Module):
|
||||
for video in videos:
|
||||
video = video.unsqueeze(0)
|
||||
if tiled:
|
||||
tile_size = (tile_size[0] * 8, tile_size[1] * 8)
|
||||
tile_stride = (tile_stride[0] * 8, tile_stride[1] * 8)
|
||||
tile_size = (tile_size[0] * self.upsampling_factor, tile_size[1] * self.upsampling_factor)
|
||||
tile_stride = (tile_stride[0] * self.upsampling_factor, tile_stride[1] * self.upsampling_factor)
|
||||
hidden_state = self.tiled_encode(video, device, tile_size, tile_stride)
|
||||
else:
|
||||
hidden_state = self.single_encode(video, device)
|
||||
@@ -798,3 +1258,119 @@ class WanVideoVAEStateDictConverter:
|
||||
for name in state_dict:
|
||||
state_dict_['model.' + name] = state_dict[name]
|
||||
return state_dict_
|
||||
|
||||
|
||||
class VideoVAE38_(VideoVAE_):
|
||||
|
||||
def __init__(self,
|
||||
dim=160,
|
||||
z_dim=48,
|
||||
dec_dim=256,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[False, True, True],
|
||||
dropout=0.0):
|
||||
super(VideoVAE_, self).__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_downsample = temperal_downsample
|
||||
self.temperal_upsample = temperal_downsample[::-1]
|
||||
|
||||
# modules
|
||||
self.encoder = Encoder3d_38(dim, z_dim * 2, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_downsample, dropout)
|
||||
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
||||
self.decoder = Decoder3d_38(dec_dim, z_dim, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_upsample, dropout)
|
||||
|
||||
|
||||
def encode(self, x, scale):
|
||||
self.clear_cache()
|
||||
x = patchify(x, patch_size=2)
|
||||
t = x.shape[2]
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
for i in range(iter_):
|
||||
self._enc_conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.encoder(x[:, :, :1, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
else:
|
||||
out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
||||
if isinstance(scale[0], torch.Tensor):
|
||||
scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale]
|
||||
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
|
||||
1, self.z_dim, 1, 1, 1)
|
||||
else:
|
||||
scale = scale.to(dtype=mu.dtype, device=mu.device)
|
||||
mu = (mu - scale[0]) * scale[1]
|
||||
self.clear_cache()
|
||||
return mu
|
||||
|
||||
|
||||
def decode(self, z, scale):
|
||||
self.clear_cache()
|
||||
if isinstance(scale[0], torch.Tensor):
|
||||
scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
|
||||
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
|
||||
1, self.z_dim, 1, 1, 1)
|
||||
else:
|
||||
scale = scale.to(dtype=z.dtype, device=z.device)
|
||||
z = z / scale[1] + scale[0]
|
||||
iter_ = z.shape[2]
|
||||
x = self.conv2(z)
|
||||
for i in range(iter_):
|
||||
self._conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.decoder(x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx,
|
||||
first_chunk=True)
|
||||
else:
|
||||
out_ = self.decoder(x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
out = unpatchify(out, patch_size=2)
|
||||
self.clear_cache()
|
||||
return out
|
||||
|
||||
|
||||
class WanVideoVAE38(WanVideoVAE):
|
||||
|
||||
def __init__(self, z_dim=48, dim=160):
|
||||
super(WanVideoVAE, self).__init__()
|
||||
|
||||
mean = [
|
||||
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
|
||||
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
|
||||
-0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
|
||||
-0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
|
||||
-0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
|
||||
0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667
|
||||
]
|
||||
std = [
|
||||
0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
|
||||
0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
|
||||
0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
|
||||
0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
|
||||
0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
|
||||
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744
|
||||
]
|
||||
self.mean = torch.tensor(mean)
|
||||
self.std = torch.tensor(std)
|
||||
self.scale = [self.mean, 1.0 / self.std]
|
||||
|
||||
# init model
|
||||
self.model = VideoVAE38_(z_dim=z_dim, dim=dim).eval().requires_grad_(False)
|
||||
self.upsampling_factor = 16
|
||||
self.z_dim = z_dim
|
||||
|
||||
@@ -18,12 +18,15 @@ from ..models import ModelManager, load_state_dict, SD3TextEncoder1, FluxTextEnc
|
||||
from ..models.step1x_connector import Qwen2Connector
|
||||
from ..models.flux_controlnet import FluxControlNet
|
||||
from ..models.flux_ipadapter import FluxIpAdapter
|
||||
from ..models.flux_value_control import MultiValueEncoder
|
||||
from ..models.flux_infiniteyou import InfiniteYouImageProjector
|
||||
from ..models.flux_lora_encoder import FluxLoRAEncoder, LoRALayerBlock
|
||||
from ..models.tiler import FastTileWorker
|
||||
from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
|
||||
from ..lora.flux_lora import FluxLoRALoader, FluxLoraPatcher
|
||||
from ..models.nexus_gen import NexusGenAutoregressiveModel
|
||||
from ..models.nexus_gen_projector import NexusGenAdapter, NexusGenImageEmbeddingMerger
|
||||
from ..utils import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
|
||||
from ..lora.flux_lora import FluxLoRALoader, FluxLoraPatcher, FluxLoRAFuser
|
||||
|
||||
from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense
|
||||
from ..models.flux_dit import RMSNorm
|
||||
from ..vram_management import gradient_checkpoint_forward, enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
||||
|
||||
@@ -93,9 +96,14 @@ class FluxImagePipeline(BasePipeline):
|
||||
self.ipadapter_image_encoder = None
|
||||
self.qwenvl = None
|
||||
self.step1x_connector: Qwen2Connector = None
|
||||
self.nexus_gen: NexusGenAutoregressiveModel = None
|
||||
self.nexus_gen_generation_adapter: NexusGenAdapter = None
|
||||
self.nexus_gen_editing_adapter: NexusGenImageEmbeddingMerger = None
|
||||
self.value_controller: MultiValueEncoder = None
|
||||
self.infinityou_processor: InfinitYou = None
|
||||
self.image_proj_model: InfiniteYouImageProjector = None
|
||||
self.lora_patcher: FluxLoraPatcher = None
|
||||
self.lora_encoder: FluxLoRAEncoder = None
|
||||
self.unit_runner = PipelineUnitRunner()
|
||||
self.in_iteration_models = ("dit", "step1x_connector", "controlnet", "lora_patcher")
|
||||
self.units = [
|
||||
@@ -110,9 +118,12 @@ class FluxImagePipeline(BasePipeline):
|
||||
FluxImageUnit_ControlNet(),
|
||||
FluxImageUnit_IPAdapter(),
|
||||
FluxImageUnit_EntityControl(),
|
||||
FluxImageUnit_NexusGen(),
|
||||
FluxImageUnit_TeaCache(),
|
||||
FluxImageUnit_Flex(),
|
||||
FluxImageUnit_Step1x(),
|
||||
FluxImageUnit_ValueControl(),
|
||||
FluxImageUnit_LoRAEncode(),
|
||||
]
|
||||
self.model_fn = model_fn_flux_image
|
||||
|
||||
@@ -120,18 +131,20 @@ class FluxImagePipeline(BasePipeline):
|
||||
def load_lora(
|
||||
self,
|
||||
module: torch.nn.Module,
|
||||
lora_config: Union[ModelConfig, str],
|
||||
lora_config: Union[ModelConfig, str] = None,
|
||||
alpha=1,
|
||||
hotload=False,
|
||||
local_model_path="./models",
|
||||
skip_download=False
|
||||
state_dict=None,
|
||||
):
|
||||
if isinstance(lora_config, str):
|
||||
lora_config = ModelConfig(path=lora_config)
|
||||
if state_dict is None:
|
||||
if isinstance(lora_config, str):
|
||||
lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
|
||||
else:
|
||||
lora_config.download_if_necessary()
|
||||
lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
|
||||
else:
|
||||
lora_config.download_if_necessary(local_model_path, skip_download=skip_download)
|
||||
lora = state_dict
|
||||
loader = FluxLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
|
||||
lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
|
||||
lora = loader.convert_state_dict(lora)
|
||||
if hotload:
|
||||
for name, module in module.named_modules():
|
||||
@@ -145,19 +158,21 @@ class FluxImagePipeline(BasePipeline):
|
||||
loader.load(module, lora, alpha=alpha)
|
||||
|
||||
|
||||
def enable_lora_patcher(self):
|
||||
if not (hasattr(self, "vram_management_enabled") and self.vram_management_enabled):
|
||||
print("Please enable VRAM management using `enable_vram_management()` before `enable_lora_patcher()`.")
|
||||
return
|
||||
if self.lora_patcher is None:
|
||||
print("Please load lora patcher models before `enable_lora_patcher()`.")
|
||||
return
|
||||
for name, module in self.dit.named_modules():
|
||||
if isinstance(module, AutoWrappedLinear):
|
||||
merger_name = name.replace(".", "___")
|
||||
if merger_name in self.lora_patcher.model_dict:
|
||||
module.lora_merger = self.lora_patcher.model_dict[merger_name]
|
||||
|
||||
def load_loras(
|
||||
self,
|
||||
module: torch.nn.Module,
|
||||
lora_configs: list[Union[ModelConfig, str]],
|
||||
alpha=1,
|
||||
hotload=False,
|
||||
extra_fused_lora=False,
|
||||
):
|
||||
for lora_config in lora_configs:
|
||||
self.load_lora(module, lora_config, hotload=hotload, alpha=alpha)
|
||||
if extra_fused_lora:
|
||||
lora_fuser = FluxLoRAFuser(device="cuda", torch_dtype=torch.bfloat16)
|
||||
fused_lora = lora_fuser(lora_configs)
|
||||
self.load_lora(module, state_dict=fused_lora, hotload=hotload, alpha=alpha)
|
||||
|
||||
|
||||
def clear_lora(self):
|
||||
for name, module in self.named_modules():
|
||||
@@ -182,22 +197,19 @@ class FluxImagePipeline(BasePipeline):
|
||||
return loss
|
||||
|
||||
|
||||
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5):
|
||||
self.vram_management_enabled = True
|
||||
if num_persistent_param_in_dit is not None:
|
||||
vram_limit = None
|
||||
else:
|
||||
if vram_limit is None:
|
||||
vram_limit = self.get_vram()
|
||||
vram_limit = vram_limit - vram_buffer
|
||||
if self.text_encoder_1 is not None:
|
||||
dtype = next(iter(self.text_encoder_1.parameters())).dtype
|
||||
def _enable_vram_management_with_default_config(self, model, vram_limit):
|
||||
if model is not None:
|
||||
dtype = next(iter(model.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.text_encoder_1,
|
||||
model,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Embedding: AutoWrappedModule,
|
||||
torch.nn.LayerNorm: AutoWrappedModule,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
torch.nn.GroupNorm: AutoWrappedModule,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
LoRALayerBlock: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
@@ -209,7 +221,52 @@ class FluxImagePipeline(BasePipeline):
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
|
||||
|
||||
def enable_lora_magic(self):
|
||||
if self.dit is not None:
|
||||
if not (hasattr(self.dit, "vram_management_enabled") and self.dit.vram_management_enabled):
|
||||
dtype = next(iter(self.dit.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.dit,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device=self.device,
|
||||
onload_dtype=dtype,
|
||||
onload_device=self.device,
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
vram_limit=None,
|
||||
)
|
||||
if self.lora_patcher is not None:
|
||||
for name, module in self.dit.named_modules():
|
||||
if isinstance(module, AutoWrappedLinear):
|
||||
merger_name = name.replace(".", "___")
|
||||
if merger_name in self.lora_patcher.model_dict:
|
||||
module.lora_merger = self.lora_patcher.model_dict[merger_name]
|
||||
|
||||
|
||||
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5):
|
||||
self.vram_management_enabled = True
|
||||
if num_persistent_param_in_dit is not None:
|
||||
vram_limit = None
|
||||
else:
|
||||
if vram_limit is None:
|
||||
vram_limit = self.get_vram()
|
||||
vram_limit = vram_limit - vram_buffer
|
||||
|
||||
# Default config
|
||||
default_vram_management_models = ["text_encoder_1", "vae_decoder", "vae_encoder", "controlnet", "image_proj_model", "ipadapter", "lora_patcher", "value_controller", "step1x_connector", "lora_encoder"]
|
||||
for model_name in default_vram_management_models:
|
||||
self._enable_vram_management_with_default_config(getattr(self, model_name), vram_limit)
|
||||
|
||||
# Special config
|
||||
if self.text_encoder_2 is not None:
|
||||
from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense
|
||||
dtype = next(iter(self.text_encoder_2.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.text_encoder_2,
|
||||
@@ -258,14 +315,18 @@ class FluxImagePipeline(BasePipeline):
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
if self.vae_decoder is not None:
|
||||
dtype = next(iter(self.vae_decoder.parameters())).dtype
|
||||
if self.ipadapter_image_encoder is not None:
|
||||
from transformers.models.siglip.modeling_siglip import SiglipVisionEmbeddings, SiglipEncoder, SiglipMultiheadAttentionPoolingHead
|
||||
dtype = next(iter(self.ipadapter_image_encoder.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.vae_decoder,
|
||||
self.ipadapter_image_encoder,
|
||||
module_map = {
|
||||
SiglipVisionEmbeddings: AutoWrappedModule,
|
||||
SiglipEncoder: AutoWrappedModule,
|
||||
SiglipMultiheadAttentionPoolingHead: AutoWrappedModule,
|
||||
torch.nn.MultiheadAttention: AutoWrappedModule,
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
torch.nn.GroupNorm: AutoWrappedModule,
|
||||
torch.nn.LayerNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
@@ -277,14 +338,25 @@ class FluxImagePipeline(BasePipeline):
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
if self.vae_encoder is not None:
|
||||
dtype = next(iter(self.vae_encoder.parameters())).dtype
|
||||
if self.qwenvl is not None:
|
||||
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
||||
Qwen2_5_VisionPatchEmbed, Qwen2_5_VLVisionBlock, Qwen2_5_VLPatchMerger,
|
||||
Qwen2_5_VLDecoderLayer, Qwen2_5_VisionRotaryEmbedding, Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm
|
||||
)
|
||||
dtype = next(iter(self.qwenvl.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.vae_encoder,
|
||||
self.qwenvl,
|
||||
module_map = {
|
||||
Qwen2_5_VisionPatchEmbed: AutoWrappedModule,
|
||||
Qwen2_5_VLVisionBlock: AutoWrappedModule,
|
||||
Qwen2_5_VLPatchMerger: AutoWrappedModule,
|
||||
Qwen2_5_VLDecoderLayer: AutoWrappedModule,
|
||||
Qwen2_5_VisionRotaryEmbedding: AutoWrappedModule,
|
||||
Qwen2_5_VLRotaryEmbedding: AutoWrappedModule,
|
||||
Qwen2RMSNorm: AutoWrappedModule,
|
||||
torch.nn.Embedding: AutoWrappedModule,
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
torch.nn.GroupNorm: AutoWrappedModule,
|
||||
torch.nn.LayerNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
@@ -303,16 +375,12 @@ class FluxImagePipeline(BasePipeline):
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
model_configs: list[ModelConfig] = [],
|
||||
tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"),
|
||||
local_model_path: str = "./models",
|
||||
skip_download: bool = False,
|
||||
redirect_common_files: bool = True,
|
||||
use_usp=False,
|
||||
nexus_gen_processor_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor/"),
|
||||
):
|
||||
# Download and load models
|
||||
model_manager = ModelManager()
|
||||
for model_config in model_configs:
|
||||
model_config.download_if_necessary(local_model_path, skip_download=skip_download)
|
||||
model_config.download_if_necessary()
|
||||
model_manager.load_model(
|
||||
model_config.path,
|
||||
device=model_config.offload_device or device,
|
||||
@@ -335,13 +403,29 @@ class FluxImagePipeline(BasePipeline):
|
||||
if pipe.image_proj_model is not None:
|
||||
pipe.infinityou_processor = InfinitYou(device=device)
|
||||
pipe.lora_patcher = model_manager.fetch_model("flux_lora_patcher")
|
||||
pipe.lora_encoder = model_manager.fetch_model("flux_lora_encoder")
|
||||
pipe.nexus_gen = model_manager.fetch_model("nexus_gen_llm")
|
||||
pipe.nexus_gen_generation_adapter = model_manager.fetch_model("nexus_gen_generation_adapter")
|
||||
pipe.nexus_gen_editing_adapter = model_manager.fetch_model("nexus_gen_editing_adapter")
|
||||
if nexus_gen_processor_config is not None and pipe.nexus_gen is not None:
|
||||
nexus_gen_processor_config.download_if_necessary()
|
||||
pipe.nexus_gen.load_processor(nexus_gen_processor_config.path)
|
||||
|
||||
# ControlNet
|
||||
controlnets = []
|
||||
for model_name, model in zip(model_manager.model_name, model_manager.model):
|
||||
if model_name == "flux_controlnet":
|
||||
controlnets.append(model)
|
||||
pipe.controlnet = MultiControlNet(controlnets)
|
||||
if len(controlnets) > 0:
|
||||
pipe.controlnet = MultiControlNet(controlnets)
|
||||
|
||||
# Value Controller
|
||||
value_controllers = []
|
||||
for model_name, model in zip(model_manager.model_name, model_manager.model):
|
||||
if model_name == "flux_value_controller":
|
||||
value_controllers.append(model)
|
||||
if len(value_controllers) > 0:
|
||||
pipe.value_controller = MultiValueEncoder(value_controllers)
|
||||
|
||||
return pipe
|
||||
|
||||
@@ -393,8 +477,15 @@ class FluxImagePipeline(BasePipeline):
|
||||
flex_control_image: Image.Image = None,
|
||||
flex_control_strength: float = 0.5,
|
||||
flex_control_stop: float = 0.5,
|
||||
# Value Controller
|
||||
value_controller_inputs: Union[list[float], float] = None,
|
||||
# Step1x
|
||||
step1x_reference_image: Image.Image = None,
|
||||
# NexusGen
|
||||
nexus_gen_reference_image: Image.Image = None,
|
||||
# LoRA Encoder
|
||||
lora_encoder_inputs: Union[list[ModelConfig], ModelConfig, str] = None,
|
||||
lora_encoder_scale: float = 1.0,
|
||||
# TeaCache
|
||||
tea_cache_l1_thresh: float = None,
|
||||
# Tile
|
||||
@@ -426,7 +517,10 @@ class FluxImagePipeline(BasePipeline):
|
||||
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative, "eligen_enable_inpaint": eligen_enable_inpaint,
|
||||
"infinityou_id_image": infinityou_id_image, "infinityou_guidance": infinityou_guidance,
|
||||
"flex_inpaint_image": flex_inpaint_image, "flex_inpaint_mask": flex_inpaint_mask, "flex_control_image": flex_control_image, "flex_control_strength": flex_control_strength, "flex_control_stop": flex_control_stop,
|
||||
"value_controller_inputs": value_controller_inputs,
|
||||
"step1x_reference_image": step1x_reference_image,
|
||||
"nexus_gen_reference_image": nexus_gen_reference_image,
|
||||
"lora_encoder_inputs": lora_encoder_inputs, "lora_encoder_scale": lora_encoder_scale,
|
||||
"tea_cache_l1_thresh": tea_cache_l1_thresh,
|
||||
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
|
||||
"progress_bar_cmd": progress_bar_cmd,
|
||||
@@ -677,15 +771,70 @@ class FluxImageUnit_EntityControl(PipelineUnit):
|
||||
if eligen_entity_prompts is None or eligen_entity_masks is None:
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
eligen_enable_on_negative = inputs_shared.get("eligen_enable_on_negative", False)
|
||||
eligen_kwargs_posi, eligen_kwargs_nega = self.prepare_eligen(pipe, inputs_nega,
|
||||
eligen_entity_prompts, eligen_entity_masks, inputs_shared["width"], inputs_shared["height"],
|
||||
inputs_shared["t5_sequence_length"], inputs_shared["eligen_enable_on_negative"], inputs_shared["cfg_scale"])
|
||||
inputs_shared["t5_sequence_length"], eligen_enable_on_negative, inputs_shared["cfg_scale"])
|
||||
inputs_posi.update(eligen_kwargs_posi)
|
||||
if inputs_shared.get("cfg_scale", 1.0) != 1.0:
|
||||
inputs_nega.update(eligen_kwargs_nega)
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
class FluxImageUnit_NexusGen(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
take_over=True,
|
||||
onload_model_names=("nexus_gen", "nexus_gen_generation_adapter", "nexus_gen_editing_adapter"),
|
||||
)
|
||||
|
||||
def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega):
|
||||
if pipe.nexus_gen is None:
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
if inputs_shared.get("nexus_gen_reference_image", None) is None:
|
||||
assert pipe.nexus_gen_generation_adapter is not None, "NexusGen requires a generation adapter to be set."
|
||||
embed = pipe.nexus_gen(inputs_posi["prompt"])[0].unsqueeze(0)
|
||||
inputs_posi["prompt_emb"] = pipe.nexus_gen_generation_adapter(embed)
|
||||
inputs_posi['text_ids'] = torch.zeros(embed.shape[0], embed.shape[1], 3).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
else:
|
||||
assert pipe.nexus_gen_editing_adapter is not None, "NexusGen requires an editing adapter to be set."
|
||||
embed, ref_embed, grids = pipe.nexus_gen(inputs_posi["prompt"], inputs_shared["nexus_gen_reference_image"])
|
||||
embeds_grid = grids[1:2].to(device=pipe.device, dtype=torch.long)
|
||||
ref_embeds_grid = grids[0:1].to(device=pipe.device, dtype=torch.long)
|
||||
|
||||
inputs_posi["prompt_emb"] = pipe.nexus_gen_editing_adapter(embed.unsqueeze(0), embeds_grid, ref_embed.unsqueeze(0), ref_embeds_grid)
|
||||
inputs_posi["text_ids"] = self.get_editing_text_ids(
|
||||
inputs_shared["latents"],
|
||||
embeds_grid[0][1].item(), embeds_grid[0][2].item(),
|
||||
ref_embeds_grid[0][1].item(), ref_embeds_grid[0][2].item(),
|
||||
)
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
def get_editing_text_ids(self, latents, target_embed_height, target_embed_width, ref_embed_height, ref_embed_width):
|
||||
# prepare text ids for target and reference embeddings
|
||||
batch_size, height, width = latents.shape[0], target_embed_height, target_embed_width
|
||||
embed_ids = torch.zeros(height // 2, width // 2, 3)
|
||||
scale_factor_height, scale_factor_width = latents.shape[-2] / height, latents.shape[-1] / width
|
||||
embed_ids[..., 1] = embed_ids[..., 1] + torch.arange(height // 2)[:, None] * scale_factor_height
|
||||
embed_ids[..., 2] = embed_ids[..., 2] + torch.arange(width // 2)[None, :] * scale_factor_width
|
||||
embed_ids = embed_ids[None, :].repeat(batch_size, 1, 1, 1).reshape(batch_size, height // 2 * width // 2, 3)
|
||||
embed_text_ids = embed_ids.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
batch_size, height, width = latents.shape[0], ref_embed_height, ref_embed_width
|
||||
ref_embed_ids = torch.zeros(height // 2, width // 2, 3)
|
||||
scale_factor_height, scale_factor_width = latents.shape[-2] / height, latents.shape[-1] / width
|
||||
ref_embed_ids[..., 0] = ref_embed_ids[..., 0] + 1.0
|
||||
ref_embed_ids[..., 1] = ref_embed_ids[..., 1] + torch.arange(height // 2)[:, None] * scale_factor_height
|
||||
ref_embed_ids[..., 2] = ref_embed_ids[..., 2] + torch.arange(width // 2)[None, :] * scale_factor_width
|
||||
ref_embed_ids = ref_embed_ids[None, :].repeat(batch_size, 1, 1, 1).reshape(batch_size, height // 2 * width // 2, 3)
|
||||
ref_embed_text_ids = ref_embed_ids.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
text_ids = torch.cat([embed_text_ids, ref_embed_text_ids], dim=1)
|
||||
return text_ids
|
||||
|
||||
|
||||
class FluxImageUnit_Step1x(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(take_over=True,onload_model_names=("qwenvl","vae_encoder"))
|
||||
@@ -704,7 +853,8 @@ class FluxImageUnit_Step1x(PipelineUnit):
|
||||
image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
image = pipe.vae_encoder(image)
|
||||
inputs_posi.update({"step1x_llm_embedding": embs[0:1], "step1x_mask": masks[0:1], "step1x_reference_latents": image})
|
||||
inputs_nega.update({"step1x_llm_embedding": embs[1:2], "step1x_mask": masks[1:2], "step1x_reference_latents": image})
|
||||
if inputs_shared.get("cfg_scale", 1) != 1:
|
||||
inputs_nega.update({"step1x_llm_embedding": embs[1:2], "step1x_mask": masks[1:2], "step1x_reference_latents": image})
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
@@ -723,10 +873,12 @@ class FluxImageUnit_Flex(PipelineUnit):
|
||||
super().__init__(
|
||||
input_params=("latents", "flex_inpaint_image", "flex_inpaint_mask", "flex_control_image", "flex_control_strength", "flex_control_stop", "tiled", "tile_size", "tile_stride"),
|
||||
onload_model_names=("vae_encoder",)
|
||||
)
|
||||
)
|
||||
|
||||
def process(self, pipe: FluxImagePipeline, latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image, flex_control_strength, flex_control_stop, tiled, tile_size, tile_stride):
|
||||
if pipe.dit.input_dim == 196:
|
||||
if flex_control_stop is None:
|
||||
flex_control_stop = 1
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
if flex_inpaint_image is None:
|
||||
flex_inpaint_image = torch.zeros_like(latents)
|
||||
@@ -756,18 +908,53 @@ class FluxImageUnit_Flex(PipelineUnit):
|
||||
|
||||
class FluxImageUnit_InfiniteYou(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(input_params=("infinityou_id_image", "infinityou_guidance"))
|
||||
super().__init__(
|
||||
input_params=("infinityou_id_image", "infinityou_guidance"),
|
||||
onload_model_names=("infinityou_processor",)
|
||||
)
|
||||
|
||||
def process(self, pipe: FluxImagePipeline, infinityou_id_image, infinityou_guidance):
|
||||
pipe.load_models_to_device("infinityou_processor")
|
||||
if infinityou_id_image is not None:
|
||||
return pipe.infinityou_processor.prepare_infinite_you(pipe.image_proj_model, infinityou_id_image, infinityou_guidance)
|
||||
return pipe.infinityou_processor.prepare_infinite_you(pipe.image_proj_model, infinityou_id_image, infinityou_guidance, pipe.device)
|
||||
else:
|
||||
return {}
|
||||
|
||||
|
||||
|
||||
class InfinitYou:
|
||||
class FluxImageUnit_ValueControl(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params_posi={"prompt_emb": "prompt_emb", "text_ids": "text_ids"},
|
||||
input_params_nega={"prompt_emb": "prompt_emb", "text_ids": "text_ids"},
|
||||
input_params=("value_controller_inputs",),
|
||||
onload_model_names=("value_controller",)
|
||||
)
|
||||
|
||||
def add_to_text_embedding(self, prompt_emb, text_ids, value_emb):
|
||||
prompt_emb = torch.concat([prompt_emb, value_emb], dim=1)
|
||||
extra_text_ids = torch.zeros((value_emb.shape[0], value_emb.shape[1], 3), device=value_emb.device, dtype=value_emb.dtype)
|
||||
text_ids = torch.concat([text_ids, extra_text_ids], dim=1)
|
||||
return prompt_emb, text_ids
|
||||
|
||||
def process(self, pipe: FluxImagePipeline, prompt_emb, text_ids, value_controller_inputs):
|
||||
if value_controller_inputs is None:
|
||||
return {}
|
||||
if not isinstance(value_controller_inputs, list):
|
||||
value_controller_inputs = [value_controller_inputs]
|
||||
value_controller_inputs = torch.tensor(value_controller_inputs).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
pipe.load_models_to_device(["value_controller"])
|
||||
value_emb = pipe.value_controller(value_controller_inputs, pipe.torch_dtype)
|
||||
value_emb = value_emb.unsqueeze(0)
|
||||
prompt_emb, text_ids = self.add_to_text_embedding(prompt_emb, text_ids, value_emb)
|
||||
return {"prompt_emb": prompt_emb, "text_ids": text_ids}
|
||||
|
||||
|
||||
|
||||
class InfinitYou(torch.nn.Module):
|
||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||
super().__init__()
|
||||
from facexlib.recognition import init_recognition_model
|
||||
from insightface.app import FaceAnalysis
|
||||
self.device = device
|
||||
@@ -779,7 +966,7 @@ class InfinitYou:
|
||||
self.app_320.prepare(ctx_id=0, det_size=(320, 320))
|
||||
self.app_160 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
||||
self.app_160.prepare(ctx_id=0, det_size=(160, 160))
|
||||
self.arcface_model = init_recognition_model('arcface', device=self.device)
|
||||
self.arcface_model = init_recognition_model('arcface', device=self.device).to(torch_dtype)
|
||||
|
||||
def _detect_face(self, id_image_cv2):
|
||||
face_info = self.app_640.get(id_image_cv2)
|
||||
@@ -791,16 +978,16 @@ class InfinitYou:
|
||||
face_info = self.app_160.get(id_image_cv2)
|
||||
return face_info
|
||||
|
||||
def extract_arcface_bgr_embedding(self, in_image, landmark):
|
||||
def extract_arcface_bgr_embedding(self, in_image, landmark, device):
|
||||
from insightface.utils import face_align
|
||||
arc_face_image = face_align.norm_crop(in_image, landmark=np.array(landmark), image_size=112)
|
||||
arc_face_image = torch.from_numpy(arc_face_image).unsqueeze(0).permute(0, 3, 1, 2) / 255.
|
||||
arc_face_image = 2 * arc_face_image - 1
|
||||
arc_face_image = arc_face_image.contiguous().to(self.device)
|
||||
arc_face_image = arc_face_image.contiguous().to(device=device, dtype=self.torch_dtype)
|
||||
face_emb = self.arcface_model(arc_face_image)[0] # [512], normalized
|
||||
return face_emb
|
||||
|
||||
def prepare_infinite_you(self, model, id_image, infinityou_guidance):
|
||||
def prepare_infinite_you(self, model, id_image, infinityou_guidance, device):
|
||||
import cv2
|
||||
if id_image is None:
|
||||
return {'id_emb': None}
|
||||
@@ -809,12 +996,72 @@ class InfinitYou:
|
||||
if len(face_info) == 0:
|
||||
raise ValueError('No face detected in the input ID image')
|
||||
landmark = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]['kps'] # only use the maximum face
|
||||
id_emb = self.extract_arcface_bgr_embedding(id_image_cv2, landmark)
|
||||
id_emb = self.extract_arcface_bgr_embedding(id_image_cv2, landmark, device)
|
||||
id_emb = model(id_emb.unsqueeze(0).reshape([1, -1, 512]).to(dtype=self.torch_dtype))
|
||||
infinityou_guidance = torch.Tensor([infinityou_guidance]).to(device=self.device, dtype=self.torch_dtype)
|
||||
infinityou_guidance = torch.Tensor([infinityou_guidance]).to(device=device, dtype=self.torch_dtype)
|
||||
return {'id_emb': id_emb, 'infinityou_guidance': infinityou_guidance}
|
||||
|
||||
|
||||
|
||||
class FluxImageUnit_LoRAEncode(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
take_over=True,
|
||||
onload_model_names=("lora_encoder",)
|
||||
)
|
||||
|
||||
def parse_lora_encoder_inputs(self, lora_encoder_inputs):
|
||||
if not isinstance(lora_encoder_inputs, list):
|
||||
lora_encoder_inputs = [lora_encoder_inputs]
|
||||
lora_configs = []
|
||||
for lora_encoder_input in lora_encoder_inputs:
|
||||
if isinstance(lora_encoder_input, str):
|
||||
lora_encoder_input = ModelConfig(path=lora_encoder_input)
|
||||
lora_encoder_input.download_if_necessary()
|
||||
lora_configs.append(lora_encoder_input)
|
||||
return lora_configs
|
||||
|
||||
def load_lora(self, lora_config, dtype, device):
|
||||
loader = FluxLoRALoader(torch_dtype=dtype, device=device)
|
||||
lora = load_state_dict(lora_config.path, torch_dtype=dtype, device=device)
|
||||
lora = loader.convert_state_dict(lora)
|
||||
return lora
|
||||
|
||||
def lora_embedding(self, pipe, lora_encoder_inputs):
|
||||
lora_emb = []
|
||||
for lora_config in self.parse_lora_encoder_inputs(lora_encoder_inputs):
|
||||
lora = self.load_lora(lora_config, pipe.torch_dtype, pipe.device)
|
||||
lora_emb.append(pipe.lora_encoder(lora))
|
||||
lora_emb = torch.concat(lora_emb, dim=1)
|
||||
return lora_emb
|
||||
|
||||
def add_to_text_embedding(self, prompt_emb, text_ids, lora_emb):
|
||||
prompt_emb = torch.concat([prompt_emb, lora_emb], dim=1)
|
||||
extra_text_ids = torch.zeros((lora_emb.shape[0], lora_emb.shape[1], 3), device=lora_emb.device, dtype=lora_emb.dtype)
|
||||
text_ids = torch.concat([text_ids, extra_text_ids], dim=1)
|
||||
return prompt_emb, text_ids
|
||||
|
||||
def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega):
|
||||
if inputs_shared.get("lora_encoder_inputs", None) is None:
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
# Encode
|
||||
pipe.load_models_to_device(["lora_encoder"])
|
||||
lora_encoder_inputs = inputs_shared["lora_encoder_inputs"]
|
||||
lora_emb = self.lora_embedding(pipe, lora_encoder_inputs)
|
||||
|
||||
# Scale
|
||||
lora_encoder_scale = inputs_shared.get("lora_encoder_scale", None)
|
||||
if lora_encoder_scale is not None:
|
||||
lora_emb = lora_emb * lora_encoder_scale
|
||||
|
||||
# Add to prompt embedding
|
||||
inputs_posi["prompt_emb"], inputs_posi["text_ids"] = self.add_to_text_embedding(
|
||||
inputs_posi["prompt_emb"], inputs_posi["text_ids"], lora_emb)
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
|
||||
class TeaCache:
|
||||
def __init__(self, num_inference_steps, rel_l1_thresh):
|
||||
self.num_inference_steps = num_inference_steps
|
||||
@@ -984,6 +1231,7 @@ def model_fn_flux_image(
|
||||
|
||||
hidden_states = dit.x_embedder(hidden_states)
|
||||
|
||||
# EliGen
|
||||
if entity_prompt_emb is not None and entity_masks is not None:
|
||||
prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
|
||||
else:
|
||||
|
||||
@@ -12,6 +12,7 @@ from tqdm import tqdm
|
||||
from typing import Optional
|
||||
from typing_extensions import Literal
|
||||
|
||||
from ..utils import BasePipeline, ModelConfig, PipelineUnit, PipelineUnitRunner
|
||||
from ..models import ModelManager, load_state_dict
|
||||
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
|
||||
from ..models.wan_video_text_encoder import WanTextEncoder, T5RelativeEmbedding, T5LayerNorm
|
||||
@@ -26,194 +27,6 @@ from ..lora import GeneralLoRALoader
|
||||
|
||||
|
||||
|
||||
class BasePipeline(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device="cuda", torch_dtype=torch.float16,
|
||||
height_division_factor=64, width_division_factor=64,
|
||||
time_division_factor=None, time_division_remainder=None,
|
||||
):
|
||||
super().__init__()
|
||||
# The device and torch_dtype is used for the storage of intermediate variables, not models.
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
# The following parameters are used for shape check.
|
||||
self.height_division_factor = height_division_factor
|
||||
self.width_division_factor = width_division_factor
|
||||
self.time_division_factor = time_division_factor
|
||||
self.time_division_remainder = time_division_remainder
|
||||
self.vram_management_enabled = False
|
||||
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
|
||||
if device is not None:
|
||||
self.device = device
|
||||
if dtype is not None:
|
||||
self.torch_dtype = dtype
|
||||
super().to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
|
||||
def check_resize_height_width(self, height, width, num_frames=None):
|
||||
# Shape check
|
||||
if height % self.height_division_factor != 0:
|
||||
height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
|
||||
print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
|
||||
if width % self.width_division_factor != 0:
|
||||
width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
|
||||
print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
|
||||
if num_frames is None:
|
||||
return height, width
|
||||
else:
|
||||
if num_frames % self.time_division_factor != self.time_division_remainder:
|
||||
num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder
|
||||
print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
|
||||
return height, width, num_frames
|
||||
|
||||
|
||||
def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1):
|
||||
# Transform a PIL.Image to torch.Tensor
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32))
|
||||
image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
|
||||
image = image * ((max_value - min_value) / 255) + min_value
|
||||
image = repeat(image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {}))
|
||||
return image
|
||||
|
||||
|
||||
def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1):
|
||||
# Transform a list of PIL.Image to torch.Tensor
|
||||
video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video]
|
||||
video = torch.stack(video, dim=pattern.index("T") // 2)
|
||||
return video
|
||||
|
||||
|
||||
def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1):
|
||||
# Transform a torch.Tensor to PIL.Image
|
||||
if pattern != "H W C":
|
||||
vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean")
|
||||
image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255)
|
||||
image = image.to(device="cpu", dtype=torch.uint8)
|
||||
image = Image.fromarray(image.numpy())
|
||||
return image
|
||||
|
||||
|
||||
def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1):
|
||||
# Transform a torch.Tensor to list of PIL.Image
|
||||
if pattern != "T H W C":
|
||||
vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean")
|
||||
video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output]
|
||||
return video
|
||||
|
||||
|
||||
def load_models_to_device(self, model_names=[]):
|
||||
if self.vram_management_enabled:
|
||||
# offload models
|
||||
for name, model in self.named_children():
|
||||
if name not in model_names:
|
||||
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
|
||||
for module in model.modules():
|
||||
if hasattr(module, "offload"):
|
||||
module.offload()
|
||||
else:
|
||||
model.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
# onload models
|
||||
for name, model in self.named_children():
|
||||
if name in model_names:
|
||||
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
|
||||
for module in model.modules():
|
||||
if hasattr(module, "onload"):
|
||||
module.onload()
|
||||
else:
|
||||
model.to(self.device)
|
||||
|
||||
|
||||
def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None):
|
||||
# Initialize Gaussian noise
|
||||
generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed)
|
||||
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype)
|
||||
noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
|
||||
return noise
|
||||
|
||||
|
||||
def enable_cpu_offload(self):
|
||||
warnings.warn("`enable_cpu_offload` will be deprecated. Please use `enable_vram_management`.")
|
||||
self.vram_management_enabled = True
|
||||
|
||||
|
||||
def get_vram(self):
|
||||
return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3)
|
||||
|
||||
|
||||
def freeze_except(self, model_names):
|
||||
for name, model in self.named_children():
|
||||
if name in model_names:
|
||||
model.train()
|
||||
model.requires_grad_(True)
|
||||
else:
|
||||
model.eval()
|
||||
model.requires_grad_(False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
path: Union[str, list[str]] = None
|
||||
model_id: str = None
|
||||
origin_file_pattern: Union[str, list[str]] = None
|
||||
download_resource: str = "ModelScope"
|
||||
offload_device: Optional[Union[str, torch.device]] = None
|
||||
offload_dtype: Optional[torch.dtype] = None
|
||||
|
||||
def download_if_necessary(self, local_model_path="./models", skip_download=False, use_usp=False):
|
||||
if self.path is None:
|
||||
# Check model_id and origin_file_pattern
|
||||
if self.model_id is None:
|
||||
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""")
|
||||
|
||||
# Skip if not in rank 0
|
||||
if use_usp:
|
||||
import torch.distributed as dist
|
||||
skip_download = dist.get_rank() != 0
|
||||
|
||||
# Check whether the origin path is a folder
|
||||
if self.origin_file_pattern is None or self.origin_file_pattern == "":
|
||||
self.origin_file_pattern = ""
|
||||
allow_file_pattern = None
|
||||
is_folder = True
|
||||
elif isinstance(self.origin_file_pattern, str) and self.origin_file_pattern.endswith("/"):
|
||||
allow_file_pattern = self.origin_file_pattern + "*"
|
||||
is_folder = True
|
||||
else:
|
||||
allow_file_pattern = self.origin_file_pattern
|
||||
is_folder = False
|
||||
|
||||
# Download
|
||||
if not skip_download:
|
||||
downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(local_model_path, self.model_id))
|
||||
snapshot_download(
|
||||
self.model_id,
|
||||
local_dir=os.path.join(local_model_path, self.model_id),
|
||||
allow_file_pattern=allow_file_pattern,
|
||||
ignore_file_pattern=downloaded_files,
|
||||
local_files_only=False
|
||||
)
|
||||
|
||||
# Let rank 1, 2, ... wait for rank 0
|
||||
if use_usp:
|
||||
import torch.distributed as dist
|
||||
dist.barrier(device_ids=[dist.get_rank()])
|
||||
|
||||
# Return downloaded files
|
||||
if is_folder:
|
||||
self.path = os.path.join(local_model_path, self.model_id, self.origin_file_pattern)
|
||||
else:
|
||||
self.path = glob.glob(os.path.join(local_model_path, self.model_id, self.origin_file_pattern))
|
||||
if isinstance(self.path, list) and len(self.path) == 1:
|
||||
self.path = self.path[0]
|
||||
|
||||
|
||||
class WanVideoPipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None):
|
||||
@@ -226,17 +39,21 @@ class WanVideoPipeline(BasePipeline):
|
||||
self.text_encoder: WanTextEncoder = None
|
||||
self.image_encoder: WanImageEncoder = None
|
||||
self.dit: WanModel = None
|
||||
self.dit2: WanModel = None
|
||||
self.vae: WanVideoVAE = None
|
||||
self.motion_controller: WanMotionControllerModel = None
|
||||
self.vace: VaceWanModel = None
|
||||
self.in_iteration_models = ("dit", "motion_controller", "vace")
|
||||
self.in_iteration_models_2 = ("dit2", "motion_controller", "vace")
|
||||
self.unit_runner = PipelineUnitRunner()
|
||||
self.units = [
|
||||
WanVideoUnit_ShapeChecker(),
|
||||
WanVideoUnit_NoiseInitializer(),
|
||||
WanVideoUnit_InputVideoEmbedder(),
|
||||
WanVideoUnit_PromptEmbedder(),
|
||||
WanVideoUnit_ImageEmbedder(),
|
||||
WanVideoUnit_ImageEmbedderVAE(),
|
||||
WanVideoUnit_ImageEmbedderCLIP(),
|
||||
WanVideoUnit_ImageEmbedderFused(),
|
||||
WanVideoUnit_FunControl(),
|
||||
WanVideoUnit_FunReference(),
|
||||
WanVideoUnit_FunCameraControl(),
|
||||
@@ -256,7 +73,9 @@ class WanVideoPipeline(BasePipeline):
|
||||
|
||||
|
||||
def training_loss(self, **inputs):
|
||||
timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
|
||||
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * self.scheduler.num_train_timesteps)
|
||||
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * self.scheduler.num_train_timesteps)
|
||||
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
|
||||
timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep)
|
||||
@@ -328,6 +147,37 @@ class WanVideoPipeline(BasePipeline):
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
if self.dit2 is not None:
|
||||
dtype = next(iter(self.dit2.parameters())).dtype
|
||||
device = "cpu" if vram_limit is not None else self.device
|
||||
enable_vram_management(
|
||||
self.dit2,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv3d: AutoWrappedModule,
|
||||
torch.nn.LayerNorm: WanAutoCastLayerNorm,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device=device,
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
max_num_param=num_persistent_param_in_dit,
|
||||
overflow_module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
if self.vae is not None:
|
||||
dtype = next(iter(self.vae.parameters())).dtype
|
||||
enable_vram_management(
|
||||
@@ -426,6 +276,10 @@ class WanVideoPipeline(BasePipeline):
|
||||
for block in self.dit.blocks:
|
||||
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
||||
self.dit.forward = types.MethodType(usp_dit_forward, self.dit)
|
||||
if self.dit2 is not None:
|
||||
for block in self.dit2.blocks:
|
||||
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
||||
self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2)
|
||||
self.sp_size = get_sequence_parallel_world_size()
|
||||
self.use_unified_sequence_parallel = True
|
||||
|
||||
@@ -436,8 +290,6 @@ class WanVideoPipeline(BasePipeline):
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
model_configs: list[ModelConfig] = [],
|
||||
tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"),
|
||||
local_model_path: str = "./models",
|
||||
skip_download: bool = False,
|
||||
redirect_common_files: bool = True,
|
||||
use_usp=False,
|
||||
):
|
||||
@@ -462,7 +314,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
# Download and load models
|
||||
model_manager = ModelManager()
|
||||
for model_config in model_configs:
|
||||
model_config.download_if_necessary(local_model_path, skip_download=skip_download, use_usp=use_usp)
|
||||
model_config.download_if_necessary(use_usp=use_usp)
|
||||
model_manager.load_model(
|
||||
model_config.path,
|
||||
device=model_config.offload_device or device,
|
||||
@@ -471,14 +323,23 @@ class WanVideoPipeline(BasePipeline):
|
||||
|
||||
# Load models
|
||||
pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder")
|
||||
pipe.dit = model_manager.fetch_model("wan_video_dit")
|
||||
dit = model_manager.fetch_model("wan_video_dit", index=2)
|
||||
if isinstance(dit, list):
|
||||
pipe.dit, pipe.dit2 = dit
|
||||
else:
|
||||
pipe.dit = dit
|
||||
pipe.vae = model_manager.fetch_model("wan_video_vae")
|
||||
pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
|
||||
pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
|
||||
pipe.vace = model_manager.fetch_model("wan_video_vace")
|
||||
|
||||
# Size division factor
|
||||
if pipe.vae is not None:
|
||||
pipe.height_division_factor = pipe.vae.upsampling_factor * 2
|
||||
pipe.width_division_factor = pipe.vae.upsampling_factor * 2
|
||||
|
||||
# Initialize tokenizer
|
||||
tokenizer_config.download_if_necessary(local_model_path, skip_download=skip_download)
|
||||
tokenizer_config.download_if_necessary(use_usp=use_usp)
|
||||
pipe.prompter.fetch_models(pipe.text_encoder)
|
||||
pipe.prompter.fetch_tokenizer(tokenizer_config.path)
|
||||
|
||||
@@ -522,6 +383,8 @@ class WanVideoPipeline(BasePipeline):
|
||||
# Classifier-free guidance
|
||||
cfg_scale: Optional[float] = 5.0,
|
||||
cfg_merge: Optional[bool] = False,
|
||||
# Boundary
|
||||
switch_DiT_boundary: Optional[float] = 0.875,
|
||||
# Scheduler
|
||||
num_inference_steps: Optional[int] = 50,
|
||||
sigma_shift: Optional[float] = 5.0,
|
||||
@@ -574,8 +437,14 @@ class WanVideoPipeline(BasePipeline):
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
# Switch DiT if necessary
|
||||
if timestep.item() < switch_DiT_boundary * self.scheduler.num_train_timesteps and self.dit2 is not None and not models["dit"] is self.dit2:
|
||||
self.load_models_to_device(self.in_iteration_models_2)
|
||||
models["dit"] = self.dit2
|
||||
|
||||
# Timestep
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
|
||||
# Inference
|
||||
noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep)
|
||||
if cfg_scale != 1.0:
|
||||
@@ -589,6 +458,8 @@ class WanVideoPipeline(BasePipeline):
|
||||
|
||||
# Scheduler
|
||||
inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"])
|
||||
if "first_frame_latents" in inputs_shared:
|
||||
inputs_shared["latents"][:, :, 0:1] = inputs_shared["first_frame_latents"]
|
||||
|
||||
# VACE (TODO: remove it)
|
||||
if vace_reference_image is not None:
|
||||
@@ -604,63 +475,6 @@ class WanVideoPipeline(BasePipeline):
|
||||
|
||||
|
||||
|
||||
class PipelineUnit:
|
||||
def __init__(
|
||||
self,
|
||||
seperate_cfg: bool = False,
|
||||
take_over: bool = False,
|
||||
input_params: tuple[str] = None,
|
||||
input_params_posi: dict[str, str] = None,
|
||||
input_params_nega: dict[str, str] = None,
|
||||
onload_model_names: tuple[str] = None
|
||||
):
|
||||
self.seperate_cfg = seperate_cfg
|
||||
self.take_over = take_over
|
||||
self.input_params = input_params
|
||||
self.input_params_posi = input_params_posi
|
||||
self.input_params_nega = input_params_nega
|
||||
self.onload_model_names = onload_model_names
|
||||
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, inputs: dict, positive=True, **kwargs) -> dict:
|
||||
raise NotImplementedError("`process` is not implemented.")
|
||||
|
||||
|
||||
|
||||
class PipelineUnitRunner:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, unit: PipelineUnit, pipe: WanVideoPipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict]:
|
||||
if unit.take_over:
|
||||
# Let the pipeline unit take over this function.
|
||||
inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega)
|
||||
elif unit.seperate_cfg:
|
||||
# Positive side
|
||||
processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()}
|
||||
if unit.input_params is not None:
|
||||
for name in unit.input_params:
|
||||
processor_inputs[name] = inputs_shared.get(name)
|
||||
processor_outputs = unit.process(pipe, **processor_inputs)
|
||||
inputs_posi.update(processor_outputs)
|
||||
# Negative side
|
||||
if inputs_shared["cfg_scale"] != 1:
|
||||
processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()}
|
||||
if unit.input_params is not None:
|
||||
for name in unit.input_params:
|
||||
processor_inputs[name] = inputs_shared.get(name)
|
||||
processor_outputs = unit.process(pipe, **processor_inputs)
|
||||
inputs_nega.update(processor_outputs)
|
||||
else:
|
||||
inputs_nega.update(processor_outputs)
|
||||
else:
|
||||
processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params}
|
||||
processor_outputs = unit.process(pipe, **processor_inputs)
|
||||
inputs_shared.update(processor_outputs)
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
|
||||
class WanVideoUnit_ShapeChecker(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(input_params=("height", "width", "num_frames"))
|
||||
@@ -679,7 +493,8 @@ class WanVideoUnit_NoiseInitializer(PipelineUnit):
|
||||
length = (num_frames - 1) // 4 + 1
|
||||
if vace_reference_image is not None:
|
||||
length += 1
|
||||
noise = pipe.generate_noise((1, 16, length, height//8, width//8), seed=seed, rand_device=rand_device)
|
||||
shape = (1, pipe.vae.model.z_dim, length, height // pipe.vae.upsampling_factor, width // pipe.vae.upsampling_factor)
|
||||
noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device)
|
||||
if vace_reference_image is not None:
|
||||
noise = torch.concat((noise[:, :, -1:], noise[:, :, :-1]), dim=2)
|
||||
return {"noise": noise}
|
||||
@@ -728,6 +543,9 @@ class WanVideoUnit_PromptEmbedder(PipelineUnit):
|
||||
|
||||
|
||||
class WanVideoUnit_ImageEmbedder(PipelineUnit):
|
||||
"""
|
||||
Deprecated
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
|
||||
@@ -735,7 +553,7 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit):
|
||||
)
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
|
||||
if input_image is None:
|
||||
if input_image is None or pipe.image_encoder is None:
|
||||
return {}
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
|
||||
@@ -763,13 +581,90 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit):
|
||||
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
return {"clip_feature": clip_context, "y": y}
|
||||
|
||||
|
||||
|
||||
class WanVideoUnit_ImageEmbedderCLIP(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_image", "end_image", "height", "width"),
|
||||
onload_model_names=("image_encoder",)
|
||||
)
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, input_image, end_image, height, width):
|
||||
if input_image is None or pipe.image_encoder is None or not pipe.dit.require_clip_embedding:
|
||||
return {}
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
|
||||
clip_context = pipe.image_encoder.encode_image([image])
|
||||
if end_image is not None:
|
||||
end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
|
||||
if pipe.dit.has_image_pos_emb:
|
||||
clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([end_image])], dim=1)
|
||||
clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
return {"clip_feature": clip_context}
|
||||
|
||||
|
||||
|
||||
class WanVideoUnit_ImageEmbedderVAE(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
|
||||
if input_image is None or not pipe.dit.require_vae_embedding:
|
||||
return {}
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
|
||||
msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
|
||||
msk[:, 1:] = 0
|
||||
if end_image is not None:
|
||||
end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
|
||||
vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
|
||||
msk[:, -1:] = 1
|
||||
else:
|
||||
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
|
||||
|
||||
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
|
||||
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
|
||||
msk = msk.transpose(1, 2)[0]
|
||||
|
||||
|
||||
y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
y = torch.concat([msk, y])
|
||||
y = y.unsqueeze(0)
|
||||
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
return {"y": y}
|
||||
|
||||
|
||||
|
||||
class WanVideoUnit_ImageEmbedderFused(PipelineUnit):
|
||||
"""
|
||||
Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_image", "latents", "height", "width", "tiled", "tile_size", "tile_stride"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, input_image, latents, height, width, tiled, tile_size, tile_stride):
|
||||
if input_image is None or not pipe.dit.fuse_vae_embedding_in_latents:
|
||||
return {}
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1)
|
||||
z = pipe.vae.encode([image], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
latents[:, :, 0: 1] = z
|
||||
return {"latents": latents, "fuse_vae_embedding_in_latents": True, "first_frame_latents": z}
|
||||
|
||||
|
||||
|
||||
class WanVideoUnit_FunControl(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y"),
|
||||
onload_model_names=("vae")
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y):
|
||||
@@ -793,7 +688,7 @@ class WanVideoUnit_FunReference(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("reference_image", "height", "width", "reference_image"),
|
||||
onload_model_names=("vae")
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, reference_image, height, width):
|
||||
@@ -812,7 +707,8 @@ class WanVideoUnit_FunReference(PipelineUnit):
|
||||
class WanVideoUnit_FunCameraControl(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image")
|
||||
input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, camera_control_speed, camera_control_origin, latents, input_image):
|
||||
@@ -835,6 +731,7 @@ class WanVideoUnit_FunCameraControl(PipelineUnit):
|
||||
|
||||
input_image = input_image.resize((width, height))
|
||||
input_latents = pipe.preprocess_video([input_image])
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
input_latents = pipe.vae.encode(input_latents, device=pipe.device)
|
||||
y = torch.zeros_like(latents).to(pipe.device)
|
||||
y[:, :, :1] = input_latents
|
||||
@@ -1014,10 +911,14 @@ class TemporalTiler_BCTHW:
|
||||
|
||||
def build_1d_mask(self, length, left_bound, right_bound, border_width):
|
||||
x = torch.ones((length,))
|
||||
if border_width == 0:
|
||||
return x
|
||||
|
||||
shift = 0.5
|
||||
if not left_bound:
|
||||
x[:border_width] = (torch.arange(border_width) + 1) / border_width
|
||||
x[:border_width] = (torch.arange(border_width) + shift) / border_width
|
||||
if not right_bound:
|
||||
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
|
||||
x[-border_width:] = torch.flip((torch.arange(border_width) + shift) / border_width, dims=(0,))
|
||||
return x
|
||||
|
||||
def build_mask(self, data, is_bound, border_width):
|
||||
@@ -1078,6 +979,7 @@ def model_fn_wan_video(
|
||||
use_gradient_checkpointing: bool = False,
|
||||
use_gradient_checkpointing_offload: bool = False,
|
||||
control_camera_latents_input = None,
|
||||
fuse_vae_embedding_in_latents: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
if sliding_window_size is not None and sliding_window_stride is not None:
|
||||
@@ -1111,9 +1013,20 @@ def model_fn_wan_video(
|
||||
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
||||
get_sequence_parallel_world_size,
|
||||
get_sp_group)
|
||||
|
||||
# Timestep
|
||||
if dit.seperated_timestep and fuse_vae_embedding_in_latents:
|
||||
timestep = torch.concat([
|
||||
torch.zeros((1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device),
|
||||
torch.ones((latents.shape[2] - 1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device) * timestep
|
||||
]).flatten()
|
||||
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0))
|
||||
t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim))
|
||||
else:
|
||||
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
|
||||
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
|
||||
|
||||
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
|
||||
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
|
||||
# Motion Controller
|
||||
if motion_bucket_id is not None and motion_controller is not None:
|
||||
t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
|
||||
context = dit.text_embedding(context)
|
||||
@@ -1124,15 +1037,16 @@ def model_fn_wan_video(
|
||||
x = torch.concat([x] * context.shape[0], dim=0)
|
||||
if timestep.shape[0] != context.shape[0]:
|
||||
timestep = torch.concat([timestep] * context.shape[0], dim=0)
|
||||
|
||||
if dit.has_image_input:
|
||||
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
|
||||
|
||||
# Image Embedding
|
||||
if y is not None and dit.require_vae_embedding:
|
||||
x = torch.cat([x, y], dim=1)
|
||||
if clip_feature is not None and dit.require_clip_embedding:
|
||||
clip_embdding = dit.img_emb(clip_feature)
|
||||
context = torch.cat([clip_embdding, context], dim=1)
|
||||
|
||||
# Add camera control
|
||||
x, (f, h, w) = dit.patchify(x, control_camera_latents_input)
|
||||
|
||||
|
||||
# Reference image
|
||||
if reference_latents is not None:
|
||||
|
||||
@@ -120,8 +120,12 @@ class ImageDataset(torch.utils.data.Dataset):
|
||||
data = self.data[data_id % len(self.data)].copy()
|
||||
for key in self.data_file_keys:
|
||||
if key in data:
|
||||
path = os.path.join(self.base_path, data[key])
|
||||
data[key] = self.load_data(path)
|
||||
if isinstance(data[key], list):
|
||||
path = [os.path.join(self.base_path, p) for p in data[key]]
|
||||
data[key] = [self.load_data(p) for p in path]
|
||||
else:
|
||||
path = os.path.join(self.base_path, data[key])
|
||||
data[key] = self.load_data(path)
|
||||
if data[key] is None:
|
||||
warnings.warn(f"cannot load file {data[key]}.")
|
||||
return None
|
||||
@@ -434,6 +438,8 @@ def wan_parser():
|
||||
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
|
||||
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
|
||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
|
||||
parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
|
||||
parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
261
diffsynth/utils/__init__.py
Normal file
261
diffsynth/utils/__init__.py
Normal file
@@ -0,0 +1,261 @@
|
||||
import torch, warnings, glob, os
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from einops import repeat, reduce
|
||||
from typing import Optional, Union
|
||||
from dataclasses import dataclass
|
||||
from modelscope import snapshot_download
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class BasePipeline(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device="cuda", torch_dtype=torch.float16,
|
||||
height_division_factor=64, width_division_factor=64,
|
||||
time_division_factor=None, time_division_remainder=None,
|
||||
):
|
||||
super().__init__()
|
||||
# The device and torch_dtype is used for the storage of intermediate variables, not models.
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
# The following parameters are used for shape check.
|
||||
self.height_division_factor = height_division_factor
|
||||
self.width_division_factor = width_division_factor
|
||||
self.time_division_factor = time_division_factor
|
||||
self.time_division_remainder = time_division_remainder
|
||||
self.vram_management_enabled = False
|
||||
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
|
||||
if device is not None:
|
||||
self.device = device
|
||||
if dtype is not None:
|
||||
self.torch_dtype = dtype
|
||||
super().to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
|
||||
def check_resize_height_width(self, height, width, num_frames=None):
|
||||
# Shape check
|
||||
if height % self.height_division_factor != 0:
|
||||
height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
|
||||
print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
|
||||
if width % self.width_division_factor != 0:
|
||||
width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
|
||||
print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
|
||||
if num_frames is None:
|
||||
return height, width
|
||||
else:
|
||||
if num_frames % self.time_division_factor != self.time_division_remainder:
|
||||
num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder
|
||||
print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
|
||||
return height, width, num_frames
|
||||
|
||||
|
||||
def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1):
|
||||
# Transform a PIL.Image to torch.Tensor
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32))
|
||||
image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
|
||||
image = image * ((max_value - min_value) / 255) + min_value
|
||||
image = repeat(image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {}))
|
||||
return image
|
||||
|
||||
|
||||
def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1):
|
||||
# Transform a list of PIL.Image to torch.Tensor
|
||||
video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video]
|
||||
video = torch.stack(video, dim=pattern.index("T") // 2)
|
||||
return video
|
||||
|
||||
|
||||
def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1):
|
||||
# Transform a torch.Tensor to PIL.Image
|
||||
if pattern != "H W C":
|
||||
vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean")
|
||||
image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255)
|
||||
image = image.to(device="cpu", dtype=torch.uint8)
|
||||
image = Image.fromarray(image.numpy())
|
||||
return image
|
||||
|
||||
|
||||
def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1):
|
||||
# Transform a torch.Tensor to list of PIL.Image
|
||||
if pattern != "T H W C":
|
||||
vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean")
|
||||
video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output]
|
||||
return video
|
||||
|
||||
|
||||
def load_models_to_device(self, model_names=[]):
|
||||
if self.vram_management_enabled:
|
||||
# offload models
|
||||
for name, model in self.named_children():
|
||||
if name not in model_names:
|
||||
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
|
||||
for module in model.modules():
|
||||
if hasattr(module, "offload"):
|
||||
module.offload()
|
||||
else:
|
||||
model.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
# onload models
|
||||
for name, model in self.named_children():
|
||||
if name in model_names:
|
||||
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
|
||||
for module in model.modules():
|
||||
if hasattr(module, "onload"):
|
||||
module.onload()
|
||||
else:
|
||||
model.to(self.device)
|
||||
|
||||
|
||||
def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None):
|
||||
# Initialize Gaussian noise
|
||||
generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed)
|
||||
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype)
|
||||
noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
|
||||
return noise
|
||||
|
||||
|
||||
def enable_cpu_offload(self):
|
||||
warnings.warn("`enable_cpu_offload` will be deprecated. Please use `enable_vram_management`.")
|
||||
self.vram_management_enabled = True
|
||||
|
||||
|
||||
def get_vram(self):
|
||||
return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3)
|
||||
|
||||
|
||||
def freeze_except(self, model_names):
|
||||
for name, model in self.named_children():
|
||||
if name in model_names:
|
||||
model.train()
|
||||
model.requires_grad_(True)
|
||||
else:
|
||||
model.eval()
|
||||
model.requires_grad_(False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
path: Union[str, list[str]] = None
|
||||
model_id: str = None
|
||||
origin_file_pattern: Union[str, list[str]] = None
|
||||
download_resource: str = "ModelScope"
|
||||
offload_device: Optional[Union[str, torch.device]] = None
|
||||
offload_dtype: Optional[torch.dtype] = None
|
||||
local_model_path: str = None
|
||||
skip_download: bool = False
|
||||
|
||||
def download_if_necessary(self, use_usp=False):
|
||||
if self.path is None:
|
||||
# Check model_id and origin_file_pattern
|
||||
if self.model_id is None:
|
||||
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""")
|
||||
|
||||
# Skip if not in rank 0
|
||||
if use_usp:
|
||||
import torch.distributed as dist
|
||||
skip_download = self.skip_download or dist.get_rank() != 0
|
||||
else:
|
||||
skip_download = self.skip_download
|
||||
|
||||
# Check whether the origin path is a folder
|
||||
if self.origin_file_pattern is None or self.origin_file_pattern == "":
|
||||
self.origin_file_pattern = ""
|
||||
allow_file_pattern = None
|
||||
is_folder = True
|
||||
elif isinstance(self.origin_file_pattern, str) and self.origin_file_pattern.endswith("/"):
|
||||
allow_file_pattern = self.origin_file_pattern + "*"
|
||||
is_folder = True
|
||||
else:
|
||||
allow_file_pattern = self.origin_file_pattern
|
||||
is_folder = False
|
||||
|
||||
# Download
|
||||
if not skip_download:
|
||||
if self.local_model_path is None:
|
||||
self.local_model_path = "./models"
|
||||
downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
|
||||
snapshot_download(
|
||||
self.model_id,
|
||||
local_dir=os.path.join(self.local_model_path, self.model_id),
|
||||
allow_file_pattern=allow_file_pattern,
|
||||
ignore_file_pattern=downloaded_files,
|
||||
local_files_only=False
|
||||
)
|
||||
|
||||
# Let rank 1, 2, ... wait for rank 0
|
||||
if use_usp:
|
||||
import torch.distributed as dist
|
||||
dist.barrier(device_ids=[dist.get_rank()])
|
||||
|
||||
# Return downloaded files
|
||||
if is_folder:
|
||||
self.path = os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern)
|
||||
else:
|
||||
self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))
|
||||
if isinstance(self.path, list) and len(self.path) == 1:
|
||||
self.path = self.path[0]
|
||||
|
||||
|
||||
|
||||
class PipelineUnit:
|
||||
def __init__(
|
||||
self,
|
||||
seperate_cfg: bool = False,
|
||||
take_over: bool = False,
|
||||
input_params: tuple[str] = None,
|
||||
input_params_posi: dict[str, str] = None,
|
||||
input_params_nega: dict[str, str] = None,
|
||||
onload_model_names: tuple[str] = None
|
||||
):
|
||||
self.seperate_cfg = seperate_cfg
|
||||
self.take_over = take_over
|
||||
self.input_params = input_params
|
||||
self.input_params_posi = input_params_posi
|
||||
self.input_params_nega = input_params_nega
|
||||
self.onload_model_names = onload_model_names
|
||||
|
||||
|
||||
def process(self, pipe: BasePipeline, inputs: dict, positive=True, **kwargs) -> dict:
|
||||
raise NotImplementedError("`process` is not implemented.")
|
||||
|
||||
|
||||
|
||||
class PipelineUnitRunner:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, unit: PipelineUnit, pipe: BasePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict]:
|
||||
if unit.take_over:
|
||||
# Let the pipeline unit take over this function.
|
||||
inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega)
|
||||
elif unit.seperate_cfg:
|
||||
# Positive side
|
||||
processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()}
|
||||
if unit.input_params is not None:
|
||||
for name in unit.input_params:
|
||||
processor_inputs[name] = inputs_shared.get(name)
|
||||
processor_outputs = unit.process(pipe, **processor_inputs)
|
||||
inputs_posi.update(processor_outputs)
|
||||
# Negative side
|
||||
if inputs_shared["cfg_scale"] != 1:
|
||||
processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()}
|
||||
if unit.input_params is not None:
|
||||
for name in unit.input_params:
|
||||
processor_inputs[name] = inputs_shared.get(name)
|
||||
processor_outputs = unit.process(pipe, **processor_inputs)
|
||||
inputs_nega.update(processor_outputs)
|
||||
else:
|
||||
inputs_nega.update(processor_outputs)
|
||||
else:
|
||||
processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params}
|
||||
processor_outputs = unit.process(pipe, **processor_inputs)
|
||||
inputs_shared.update(processor_outputs)
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
39
examples/CogVideoX/README.md
Normal file
39
examples/CogVideoX/README.md
Normal file
@@ -0,0 +1,39 @@
|
||||
# CogVideoX
|
||||
|
||||
### Example: Text-to-Video using CogVideoX-5B (Experimental)
|
||||
|
||||
See [cogvideo_text_to_video.py](cogvideo_text_to_video.py).
|
||||
|
||||
First, we generate a video using prompt "an astronaut riding a horse on Mars".
|
||||
|
||||
https://github.com/user-attachments/assets/4c91c1cd-e4a0-471a-bd8d-24d761262941
|
||||
|
||||
Then, we convert the astronaut to a robot.
|
||||
|
||||
https://github.com/user-attachments/assets/225a00a4-2bc8-4740-8e86-a64b460a29ec
|
||||
|
||||
Upscale the video using the model itself.
|
||||
|
||||
https://github.com/user-attachments/assets/c02cb30c-de60-473c-8242-32c67b3155ad
|
||||
|
||||
Make the video look smoother by interpolating frames.
|
||||
|
||||
https://github.com/user-attachments/assets/f0e465b4-45df-4435-ab10-7a084ca2b0a0
|
||||
|
||||
Here is another example.
|
||||
|
||||
First, we generate a video using prompt "a dog is running".
|
||||
|
||||
https://github.com/user-attachments/assets/e3696297-99f5-4d0c-a5ca-1d1566db85b4
|
||||
|
||||
Then, we add a blue collar to the dog.
|
||||
|
||||
https://github.com/user-attachments/assets/7ff22be7-4390-4d33-ae6c-53f6f056e18d
|
||||
|
||||
Upscale the video using the model itself.
|
||||
|
||||
https://github.com/user-attachments/assets/a909c32c-0b7d-495c-a53c-d23a99a3d3e9
|
||||
|
||||
Make the video look smoother by interpolating frames.
|
||||
|
||||
https://github.com/user-attachments/assets/ea37c150-97a0-4858-8003-0c2e5eef3331
|
||||
@@ -18,7 +18,7 @@ pip install -e .
|
||||
|
||||
## Quick Start
|
||||
|
||||
You can quickly load the FLUX.1-dev model and perform inference by running the following code:
|
||||
You can quickly load the [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev ) model and run inference by executing the code below.
|
||||
|
||||
```python
|
||||
import torch
|
||||
@@ -41,12 +41,21 @@ image.save("image.jpg")
|
||||
|
||||
## Model Overview
|
||||
|
||||
**Support for the new framework of the FLUX series models is under active development. Stay tuned!**
|
||||
|
||||
| Model ID | Additional Parameters | Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
|
||||
|-|-|-|-|-|-|-|
|
||||
|[black-forest-labs/FLUX.1-dev](https://modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](./model_inference/FLUX.1-dev.py)|[code](./model_training/full/FLUX.1-dev.sh)|[code](./model_training/validate_full/FLUX.1-dev.py)|[code](./model_training/lora/FLUX.1-dev.sh)|[code](./model_training/validate_lora/FLUX.1-dev.py)|
|
||||
|[black-forest-labs/FLUX.1-Kontext-dev](https://modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](./model_inference/FLUX.1-Kontext-dev.py)|[code](./model_training/full/FLUX.1-Kontext-dev.sh)|[code](./model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](./model_training/lora/FLUX.1-Kontext-dev.sh)|[code](./model_training/validate_lora/FLUX.1-Kontext-dev.py)|
|
||||
|Model ID|Extra Args|Inference|Low VRAM Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
|
||||
|-|-|-|-|-|-|-|-|
|
||||
|[FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](./model_inference/FLUX.1-dev.py)|[code](./model_inference_low_vram/FLUX.1-dev.py)|[code](./model_training/full/FLUX.1-dev.sh)|[code](./model_training/validate_full/FLUX.1-dev.py)|[code](./model_training/lora/FLUX.1-dev.sh)|[code](./model_training/validate_lora/FLUX.1-dev.py)|
|
||||
|[FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](./model_inference/FLUX.1-Kontext-dev.py)|[code](./model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](./model_training/full/FLUX.1-Kontext-dev.sh)|[code](./model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](./model_training/lora/FLUX.1-Kontext-dev.sh)|[code](./model_training/validate_lora/FLUX.1-Kontext-dev.py)|
|
||||
|[FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](./model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](./model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](./model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|
|
||||
|[FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](./model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](./model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](./model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)|
|
||||
|[FLUX.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](./model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)|
|
||||
|[FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](./model_inference/FLUX.1-dev-IP-Adapter.py)|[code](./model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](./model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](./model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](./model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](./model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)|
|
||||
|[FLUX.1-dev-InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](./model_inference/FLUX.1-dev-InfiniteYou.py)|[code](./model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](./model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](./model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](./model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](./model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|
|
||||
|[FLUX.1-dev-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](./model_inference/FLUX.1-dev-EliGen.py)|[code](./model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](./model_training/lora/FLUX.1-dev-EliGen.sh)|[code](./model_training/validate_lora/FLUX.1-dev-EliGen.py)|
|
||||
|[FLUX.1-dev-LoRA-Encoder](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](./model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](./model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](./model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](./model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|
|
||||
|[FLUX.1-dev-LoRA-Fusion-Preview](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](./model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|
|
||||
|[Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](./model_inference/Step1X-Edit.py)|[code](./model_inference_low_vram/Step1X-Edit.py)|[code](./model_training/full/Step1X-Edit.sh)|[code](./model_training/validate_full/Step1X-Edit.py)|[code](./model_training/lora/Step1X-Edit.sh)|[code](./model_training/validate_lora/Step1X-Edit.py)|
|
||||
|[FLEX.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](./model_inference/FLEX.2-preview.py)|[code](./model_inference_low_vram/FLEX.2-preview.py)|[code](./model_training/full/FLEX.2-preview.sh)|[code](./model_training/validate_full/FLEX.2-preview.py)|[code](./model_training/lora/FLEX.2-preview.sh)|[code](./model_training/validate_lora/FLEX.2-preview.py)|
|
||||
|[Nexus-Gen](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](./model_inference/Nexus-Gen-Editing.py)|[code](./model_inference_low_vram/Nexus-Gen-Editing.py)|[code](./model_training/full/Nexus-Gen.sh)|[code](./model_training/validate_full/Nexus-Gen.py)|[code](./model_training/lora/Nexus-Gen.sh)|[code](./model_training/validate_lora/Nexus-Gen.py)|
|
||||
|
||||
## Model Inference
|
||||
|
||||
@@ -54,11 +63,14 @@ The following sections will help you understand our features and write inference
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Loading Models</summary>
|
||||
<summary>Load Model</summary>
|
||||
|
||||
Models are loaded using `from_pretrained`:
|
||||
The model is loaded using `from_pretrained`:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
@@ -71,21 +83,21 @@ pipe = FluxImagePipeline.from_pretrained(
|
||||
)
|
||||
```
|
||||
|
||||
Here, `torch_dtype` and `device` refer to the computation precision and device, respectively. The `model_configs` can be configured in various ways to specify model paths:
|
||||
Here, `torch_dtype` and `device` set the computation precision and device. The `model_configs` can be used in different ways to specify model paths:
|
||||
|
||||
* Download the model from [ModelScope Community](https://modelscope.cn/) and load it. In this case, provide `model_id` and `origin_file_pattern`, for example:
|
||||
* Download the model from [ModelScope](https://modelscope.cn/ ) and load it. In this case, fill in `model_id` and `origin_file_pattern`, for example:
|
||||
|
||||
```python
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors")
|
||||
```
|
||||
|
||||
* Load the model from a local file path. In this case, provide the `path`, for example:
|
||||
* Load the model from a local file path. In this case, fill in `path`, for example:
|
||||
|
||||
```python
|
||||
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors")
|
||||
```
|
||||
|
||||
For models that consist of multiple files, use a list as follows:
|
||||
For a single model that loads from multiple files, use a list, for example:
|
||||
|
||||
```python
|
||||
ModelConfig(path=[
|
||||
@@ -95,10 +107,10 @@ ModelConfig(path=[
|
||||
])
|
||||
```
|
||||
|
||||
The `from_pretrained` method also provides additional parameters to control model loading behavior:
|
||||
The `ModelConfig` method also provides extra arguments to control model loading behavior:
|
||||
|
||||
* `local_model_path`: Path for saving downloaded models. The default is `"./models"`.
|
||||
* `skip_download`: Whether to skip downloading models. The default is `False`. If your network cannot access [ModelScope Community](https://modelscope.cn/), manually download the required files and set this to `True`.
|
||||
* `local_model_path`: Path to save downloaded models. Default is `"./models"`.
|
||||
* `skip_download`: Whether to skip downloading. Default is `False`. If your network cannot access [ModelScope](https://modelscope.cn/ ), download the required files manually and set this to `True`.
|
||||
|
||||
</details>
|
||||
|
||||
@@ -107,7 +119,7 @@ The `from_pretrained` method also provides additional parameters to control mode
|
||||
|
||||
<summary>VRAM Management</summary>
|
||||
|
||||
DiffSynth-Studio provides fine-grained VRAM management for FLUX models, enabling inference on devices with limited VRAM. You can enable offloading functionality via the following code, which moves certain modules to system memory on devices with limited GPU memory.
|
||||
DiffSynth-Studio provides fine-grained VRAM management for the FLUX model. This allows the model to run on devices with low VRAM. You can enable the offload feature using the code below. It moves some modules to CPU memory when GPU memory is limited.
|
||||
|
||||
```python
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
@@ -123,19 +135,52 @@ pipe = FluxImagePipeline.from_pretrained(
|
||||
pipe.enable_vram_management()
|
||||
```
|
||||
|
||||
The `enable_vram_management` function provides the following parameters to control VRAM usage:
|
||||
FP8 quantization is also supported:
|
||||
|
||||
* `vram_limit`: VRAM usage limit in GB. By default, it uses the remaining VRAM available on the device. Note that this is not an absolute limit; if the set VRAM is insufficient but more VRAM is actually available, the model will run with minimal VRAM consumption. Setting it to 0 achieves the theoretical minimum VRAM usage.
|
||||
* `vram_buffer`: VRAM buffer size in GB. The default is 0.5GB. Since some large neural network layers may consume extra VRAM during onload phases, a VRAM buffer is necessary. Ideally, the optimal value should match the VRAM occupied by the largest layer in the model.
|
||||
* `num_persistent_param_in_dit`: Number of persistent parameters in the DiT model (default: no limit). We plan to remove this parameter in the future, so please avoid relying on it.
|
||||
```python
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
```
|
||||
|
||||
You can use FP8 quantization and offload at the same time:
|
||||
|
||||
```python
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
```
|
||||
|
||||
After enabling VRAM management, the framework will automatically decide the VRAM strategy based on available GPU memory. For most FLUX models, inference can run with as little as 8GB of VRAM. The `enable_vram_management` function has the following parameters to manually control the VRAM strategy:
|
||||
|
||||
* `vram_limit`: VRAM usage limit in GB. By default, it uses all free VRAM on the device. Note that this is not an absolute limit. If the set VRAM is not enough but more VRAM is actually available, the model will run with minimal VRAM usage. Setting it to 0 achieves the theoretical minimum VRAM usage.
|
||||
* `vram_buffer`: VRAM buffer size in GB. Default is 0.5GB. A buffer is needed because larger neural network layers may use more VRAM than expected during loading. The optimal value is the VRAM used by the largest layer in the model.
|
||||
* `num_persistent_param_in_dit`: Number of parameters in the DiT model that stay in VRAM. Default is no limit. We plan to remove this parameter in the future. Do not rely on it.
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Inference Acceleration</summary>
|
||||
|
||||
* TeaCache: Acceleration technique [TeaCache](https://github.com/ali-vilab/TeaCache), please refer to the [sample code](./acceleration/teacache.py).
|
||||
* TeaCache: Acceleration technique [TeaCache](https://github.com/ali-vilab/TeaCache ). Please refer to the [example code](./acceleration/teacache.py).
|
||||
|
||||
</details>
|
||||
|
||||
@@ -143,75 +188,98 @@ The `enable_vram_management` function provides the following parameters to contr
|
||||
|
||||
<summary>Input Parameters</summary>
|
||||
|
||||
The pipeline accepts the following input parameters during inference:
|
||||
The pipeline supports the following input parameters during inference:
|
||||
|
||||
* `prompt`: Prompt describing what should appear in the image.
|
||||
* `negative_prompt`: Negative prompt describing what should **not** appear in the image. Default is `""`.
|
||||
* `cfg_scale`: Classifier-free guidance scale. Default is 1. It becomes effective when set to a value greater than 1.
|
||||
* `embedded_guidance`: Embedded guidance parameter for FLUX-dev. Default is 3.5.
|
||||
* `t5_sequence_length`: Sequence length of T5 text embeddings. Default is 512.
|
||||
* `input_image`: Input image used for image-to-image generation. This works together with `denoising_strength`.
|
||||
* `denoising_strength`: Denoising strength, ranging from 0 to 1. Default is 1. When close to 0, the generated image will be similar to the input image; when close to 1, the generated image will differ significantly from the input. Do not set this to a non-1 value if no `input_image` is provided.
|
||||
* `height`: Height of the generated image. Must be a multiple of 16.
|
||||
* `width`: Width of the generated image. Must be a multiple of 16.
|
||||
* `seed`: Random seed. Default is `None`, meaning completely random.
|
||||
* `rand_device`: Device for generating random Gaussian noise. Default is `"cpu"`. Setting it to `"cuda"` may lead to different results across GPUs.
|
||||
* `sigma_shift`: Parameter from Rectified Flow theory. Default is 3. A larger value increases the number of steps spent at the beginning of denoising and can improve image quality. However, it may cause inconsistencies between the generation process and training data.
|
||||
* `prompt`: Text prompt describing what should appear in the image.
|
||||
* `negative_prompt`: Negative prompt describing what should not appear in the image. Default is `""`.
|
||||
* `cfg_scale`: Parameter for classifier-free guidance. Default is 1. Takes effect when set to a value greater than 1.
|
||||
* `embedded_guidance`: Built-in guidance parameter for FLUX-dev. Default is 3.5.
|
||||
* `t5_sequence_length`: Sequence length of text embeddings from the T5 model. Default is 512.
|
||||
* `input_image`: Input image used for image-to-image generation. Used together with `denoising_strength`.
|
||||
* `denoising_strength`: Denoising strength, range from 0 to 1. Default is 1. When close to 0, the output image is similar to the input. When close to 1, the output differs more from the input. Do not set it to values other than 1 if `input_image` is not provided.
|
||||
* `height`: Image height. Must be a multiple of 16.
|
||||
* `width`: Image width. Must be a multiple of 16.
|
||||
* `seed`: Random seed. Default is `None`, meaning fully random.
|
||||
* `rand_device`: Device for generating random Gaussian noise. Default is `"cpu"`. Setting it to `"cuda"` may lead to different results on different GPUs.
|
||||
* `sigma_shift`: Parameter from Rectified Flow theory. Default is 3. A larger value means the model spends more steps at the start of denoising. Increasing this can improve image quality, but may cause differences between generated images and training data due to inconsistency with training.
|
||||
* `num_inference_steps`: Number of inference steps. Default is 30.
|
||||
* `kontext_images`: Input images for the Kontext model.
|
||||
* `controlnet_inputs`: Inputs for the ControlNet model.
|
||||
* `ipadapter_images`: Input images for the IP-Adapter model.
|
||||
* `ipadapter_scale`: Control strength of the IP-Adapter model.
|
||||
* `ipadapter_scale`: Control strength for the IP-Adapter model.
|
||||
* `eligen_entity_prompts`: Local prompts for the EliGen model.
|
||||
* `eligen_entity_masks`: Mask regions for local prompts in the EliGen model. Matches one-to-one with `eligen_entity_prompts`.
|
||||
* `eligen_enable_on_negative`: Whether to enable EliGen on the negative prompt side. Only works when `cfg_scale > 1`.
|
||||
* `eligen_enable_inpaint`: Whether to enable EliGen for local inpainting.
|
||||
* `infinityou_id_image`: Face image for the InfiniteYou model.
|
||||
* `infinityou_guidance`: Control strength for the InfiniteYou model.
|
||||
* `flex_inpaint_image`: Image for FLEX model's inpainting.
|
||||
* `flex_inpaint_mask`: Mask region for FLEX model's inpainting.
|
||||
* `flex_control_image`: Image for FLEX model's structural control.
|
||||
* `flex_control_strength`: Strength for FLEX model's structural control.
|
||||
* `flex_control_stop`: End point for FLEX model's structural control. 1 means enabled throughout, 0.5 means enabled in the first half, 0 means disabled.
|
||||
* `step1x_reference_image`: Input image for Step1x-Edit model's image editing.
|
||||
* `lora_encoder_inputs`: Inputs for LoRA encoder. Can be ModelConfig or local path.
|
||||
* `lora_encoder_scale`: Activation strength for LoRA encoder. Default is 1. Smaller values mean weaker LoRA activation.
|
||||
* `tea_cache_l1_thresh`: Threshold for TeaCache. Larger values mean faster speed but lower image quality. Note that after enabling TeaCache, inference speed is not uniform, so the remaining time shown in the progress bar will be inaccurate.
|
||||
* `tiled`: Whether to enable tiled VAE inference. Default is `False`. Setting to `True` reduces VRAM usage during VAE encoding/decoding, with slight error and slightly longer inference time.
|
||||
* `tile_size`: Tile size during VAE encoding/decoding. Default is 128. Only takes effect when `tiled=True`.
|
||||
* `tile_stride`: Tile stride during VAE encoding/decoding. Default is 64. Only takes effect when `tiled=True`. Must be less than or equal to `tile_size`.
|
||||
* `progress_bar_cmd`: Progress bar display. Default is `tqdm.tqdm`. Set to `lambda x:x` to disable the progress bar.
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
## Model Training
|
||||
|
||||
FLUX series models are trained using a unified script [`./model_training/train.py`](./model_training/train.py).
|
||||
Training for the FLUX series models is done using a unified script [`./model_training/train.py`](./model_training/train.py).
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Script Parameters</summary>
|
||||
|
||||
The script supports the following parameters:
|
||||
The script includes the following parameters:
|
||||
|
||||
* Dataset
|
||||
* `--dataset_base_path`: Root path to the dataset.
|
||||
* `--dataset_metadata_path`: Path to the metadata file of the dataset.
|
||||
* `--max_pixels`: Maximum pixel area, default is 1024*1024. When dynamic resolution is enabled, any image with a resolution larger than this value will be scaled down.。
|
||||
* `--height`: Height of images or videos. Leave `height` and `width` empty to enable dynamic resolution.
|
||||
* `--width`: Width of images or videos. Leave `height` and `width` empty to enable dynamic resolution.
|
||||
* `--data_file_keys`: Keys in metadata for data files. Comma-separated.
|
||||
* `--dataset_base_path`: Root path of the dataset.
|
||||
* `--dataset_metadata_path`: Path to the dataset metadata file.
|
||||
* `--max_pixels`: Maximum pixel area. Default is 1024*1024. When dynamic resolution is enabled, any image with resolution higher than this will be downscaled.
|
||||
* `--height`: Height of the image or video. Leave `height` and `width` empty to enable dynamic resolution.
|
||||
* `--width`: Width of the image or video. Leave `height` and `width` empty to enable dynamic resolution.
|
||||
* `--data_file_keys`: Data file keys in the metadata. Separate with commas.
|
||||
* `--dataset_repeat`: Number of times the dataset repeats per epoch.
|
||||
* Models
|
||||
* `--model_paths`: Paths to load models. JSON format.
|
||||
* `--model_id_with_origin_paths`: Model IDs with original paths, e.g., black-forest-labs/FLUX.1-dev:flux1-dev.safetensors. Comma-separated.
|
||||
* Model
|
||||
* `--model_paths`: Paths to load models. In JSON format.
|
||||
* `--model_id_with_origin_paths`: Model ID with original paths, e.g., black-forest-labs/FLUX.1-dev:flux1-dev.safetensors. Separate with commas.
|
||||
* Training
|
||||
* `--learning_rate`: Learning rate.
|
||||
* `--num_epochs`: Number of training epochs.
|
||||
* `--output_path`: Output path for saving checkpoints.
|
||||
* `--remove_prefix_in_ckpt`: Remove prefix in checkpoint filenames.
|
||||
* `--num_epochs`: Number of epochs.
|
||||
* `--output_path`: Save path.
|
||||
* `--remove_prefix_in_ckpt`: Remove prefix in checkpoint.
|
||||
* Trainable Modules
|
||||
* `--trainable_models`: Models that can be trained, e.g., dit, vae, text_encoder.
|
||||
* `--lora_base_model`: Which base model to apply LoRA on.
|
||||
* `--lora_target_modules`: Which layers to apply LoRA on.
|
||||
* `--lora_base_model`: Which model to add LoRA to.
|
||||
* `--lora_target_modules`: Which layers to add LoRA to.
|
||||
* `--lora_rank`: Rank of LoRA.
|
||||
* Extra Inputs
|
||||
* `--extra_inputs`: Additional model inputs. Comma-separated.
|
||||
* Extra Model Inputs
|
||||
* `--extra_inputs`: Extra model inputs, separated by commas.
|
||||
* VRAM Management
|
||||
* `--use_gradient_checkpointing`: Whether to use gradient checkpointing.
|
||||
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
|
||||
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
|
||||
* `--gradient_accumulation_steps`: Number of steps for gradient accumulation.
|
||||
* Miscellaneous
|
||||
* `--align_to_opensource_format`: Whether to align the FLUX DiT LoRA format with the open-source version. Only applicable to LoRA training for FLUX.1-dev and FLUX.1-Kontext-dev.
|
||||
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
|
||||
* Others
|
||||
* `--align_to_opensource_format`: Whether to align the FLUX DiT LoRA format with the open-source version. Only works for LoRA training.
|
||||
|
||||
In addition, the training framework is built on [`accelerate`](https://huggingface.co/docs/accelerate/index ). Run `accelerate config` before training to set GPU-related parameters. For some training scripts (e.g., full model training), we provide suggested `accelerate` config files. You can find them in the corresponding training scripts.
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Step 1: Prepare Dataset</summary>
|
||||
|
||||
The dataset contains a series of files. We recommend organizing your dataset files as follows:
|
||||
A dataset contains a series of files. We suggest organizing your dataset like this:
|
||||
|
||||
```
|
||||
data/example_image_dataset/
|
||||
@@ -220,7 +288,7 @@ data/example_image_dataset/
|
||||
└── image2.jpg
|
||||
```
|
||||
|
||||
Here, `image1.jpg`, `image2.jpg` are training image data, and `metadata.csv` is the metadata list, for example:
|
||||
Here, `image1.jpg` and `image2.jpg` are training images, and `metadata.csv` is the metadata list, for example:
|
||||
|
||||
```
|
||||
image,prompt
|
||||
@@ -228,7 +296,7 @@ image1.jpg,"a cat is sleeping"
|
||||
image2.jpg,"a dog is running"
|
||||
```
|
||||
|
||||
We have built a sample image dataset to help you test more conveniently. You can download this dataset using the following command:
|
||||
We have built a sample image dataset to help you test. You can download it with the following command:
|
||||
|
||||
```shell
|
||||
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
|
||||
@@ -236,26 +304,27 @@ modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir
|
||||
|
||||
The dataset supports multiple image formats: `"jpg", "jpeg", "png", "webp"`.
|
||||
|
||||
The image resolution can be controlled via script parameters `--height` and `--width`. When both `--height` and `--width` are left empty, dynamic resolution will be enabled, allowing training with the actual width and height of each image in the dataset.
|
||||
Image size can be controlled by script arguments `--height` and `--width`. When `--height` and `--width` are left empty, dynamic resolution is enabled. The model will train using each image's actual width and height from the dataset.
|
||||
|
||||
**We strongly recommend using fixed-resolution training, as there may be load-balancing issues in multi-GPU training with dynamic resolution.**
|
||||
**We strongly recommend using fixed resolution for training, because there can be load balancing issues in multi-GPU training.**
|
||||
|
||||
When the model requires additional inputs—for instance, `kontext_images` required by the controllable model [`black-forest-labs/FLUX.1-Kontext-dev`](https://modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)—please add corresponding columns in the dataset, for example:
|
||||
When the model needs extra inputs, for example, `kontext_images` required by controllable models like [`black-forest-labs/FLUX.1-Kontext-dev`](https://modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev ), add the corresponding column to your dataset, for example:
|
||||
|
||||
```
|
||||
image,prompt,kontext_images
|
||||
image1.jpg,"a cat is sleeping",image1_reference.jpg
|
||||
```
|
||||
|
||||
If additional inputs include image files, you need to specify the column names to parse using the `--data_file_keys` parameter. You can add more column names accordingly, e.g., `--data_file_keys "image,kontext_images"`.
|
||||
If an extra input includes image files, you must specify the column name in the `--data_file_keys` argument. Add column names as needed, for example `--data_file_keys "image,kontext_images"`, and also enable `--extra_inputs "kontext_images"`.
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Step 2: Load Model</summary>
|
||||
|
||||
Similar to the model loading logic during inference, you can directly configure the model to be loaded using its model ID. For example, during inference we load the model with the following configuration:
|
||||
Similar to model loading during inference, you can configure which models to load directly using model IDs. For example, during inference we load the model with this setting:
|
||||
|
||||
```python
|
||||
model_configs=[
|
||||
@@ -266,13 +335,13 @@ model_configs=[
|
||||
]
|
||||
```
|
||||
|
||||
Then during training, simply provide the following parameter to load the corresponding model:
|
||||
Then, during training, use the following parameter to load the same models:
|
||||
|
||||
```shell
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors"
|
||||
```
|
||||
|
||||
If you prefer to load the model from local files, as in the inference example:
|
||||
If you want to load models from local files, for example, during inference:
|
||||
|
||||
```python
|
||||
model_configs=[
|
||||
@@ -283,7 +352,7 @@ model_configs=[
|
||||
]
|
||||
```
|
||||
|
||||
Then during training, set it up as follows:
|
||||
Then during training, set it as:
|
||||
|
||||
```shell
|
||||
--model_paths '[
|
||||
@@ -296,23 +365,25 @@ Then during training, set it up as follows:
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Step 3: Configure Trainable Modules</summary>
|
||||
|
||||
The training framework supports both full-model training and LoRA-based fine-tuning. Below are some examples:
|
||||
|
||||
* Full training of the DiT module: `--trainable_models dit`
|
||||
* Training a LoRA model on the DiT module: `--lora_base_model dit --lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" --lora_rank 32`
|
||||
|
||||
Additionally, since the training script loads multiple modules (text encoder, DiT, VAE), you need to remove prefixes when saving the model files. For example, when performing full DiT training or LoRA training on the DiT module, please set `--remove_prefix_in_ckpt pipe.dit.`
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Step 4: Launch the Training Script</summary>
|
||||
<summary>Step 3: Set Trainable Modules</summary>
|
||||
|
||||
We have written specific training commands for each model. Please refer to the table at the beginning of this document for details.
|
||||
The training framework supports training base models or LoRA models. Here are some examples:
|
||||
|
||||
* Full training of the DiT part: `--trainable_models dit`
|
||||
* Training a LoRA model on the DiT part: `--lora_base_model dit --lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" --lora_rank 32`
|
||||
|
||||
Also, because the training script loads multiple modules (text encoder, dit, vae), you need to remove prefixes when saving model files. For example, when fully training the DiT part or training a LoRA model on the DiT part, set `--remove_prefix_in_ckpt pipe.dit.`
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Step 4: Start Training</summary>
|
||||
|
||||
We have written training commands for each model. Please refer to the table at the beginning of this document.
|
||||
|
||||
</details>
|
||||
|
||||
@@ -18,7 +18,7 @@ pip install -e .
|
||||
|
||||
## 快速开始
|
||||
|
||||
通过运行以下代码可以快速加载 FLUX.1-dev 模型并进行推理。
|
||||
通过运行以下代码可以快速加载 [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) 模型并进行推理。
|
||||
|
||||
```python
|
||||
import torch
|
||||
@@ -41,12 +41,21 @@ image.save("image.jpg")
|
||||
|
||||
## 模型总览
|
||||
|
||||
**FLUX 系列模型的全新框架支持正在开发中,敬请期待!**
|
||||
|
||||
|模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[black-forest-labs/FLUX.1-dev](https://modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](./model_inference/FLUX.1-dev.py)|[code](./model_training/full/FLUX.1-dev.sh)|[code](./model_training/validate_full/FLUX.1-dev.py)|[code](./model_training/lora/FLUX.1-dev.sh)|[code](./model_training/validate_lora/FLUX.1-dev.py)|
|
||||
|[black-forest-labs/FLUX.1-Kontext-dev](https://modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](./model_inference/FLUX.1-Kontext-dev.py)|[code](./model_training/full/FLUX.1-Kontext-dev.sh)|[code](./model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](./model_training/lora/FLUX.1-Kontext-dev.sh)|[code](./model_training/validate_lora/FLUX.1-Kontext-dev.py)|
|
||||
|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|-|
|
||||
|[FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](./model_inference/FLUX.1-dev.py)|[code](./model_inference_low_vram/FLUX.1-dev.py)|[code](./model_training/full/FLUX.1-dev.sh)|[code](./model_training/validate_full/FLUX.1-dev.py)|[code](./model_training/lora/FLUX.1-dev.sh)|[code](./model_training/validate_lora/FLUX.1-dev.py)|
|
||||
|[FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](./model_inference/FLUX.1-Kontext-dev.py)|[code](./model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](./model_training/full/FLUX.1-Kontext-dev.sh)|[code](./model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](./model_training/lora/FLUX.1-Kontext-dev.sh)|[code](./model_training/validate_lora/FLUX.1-Kontext-dev.py)|
|
||||
|[FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](./model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](./model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](./model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|
|
||||
|[FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](./model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](./model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](./model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)|
|
||||
|[FLUX.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](./model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)|
|
||||
|[FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](./model_inference/FLUX.1-dev-IP-Adapter.py)|[code](./model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](./model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](./model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](./model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](./model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)|
|
||||
|[FLUX.1-dev-InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](./model_inference/FLUX.1-dev-InfiniteYou.py)|[code](./model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](./model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](./model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](./model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](./model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|
|
||||
|[FLUX.1-dev-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](./model_inference/FLUX.1-dev-EliGen.py)|[code](./model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](./model_training/lora/FLUX.1-dev-EliGen.sh)|[code](./model_training/validate_lora/FLUX.1-dev-EliGen.py)|
|
||||
|[FLUX.1-dev-LoRA-Encoder](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](./model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](./model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](./model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](./model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|
|
||||
|[FLUX.1-dev-LoRA-Fusion-Preview](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](./model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|
|
||||
|[Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](./model_inference/Step1X-Edit.py)|[code](./model_inference_low_vram/Step1X-Edit.py)|[code](./model_training/full/Step1X-Edit.sh)|[code](./model_training/validate_full/Step1X-Edit.py)|[code](./model_training/lora/Step1X-Edit.sh)|[code](./model_training/validate_lora/Step1X-Edit.py)|
|
||||
|[FLEX.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](./model_inference/FLEX.2-preview.py)|[code](./model_inference_low_vram/FLEX.2-preview.py)|[code](./model_training/full/FLEX.2-preview.sh)|[code](./model_training/validate_full/FLEX.2-preview.py)|[code](./model_training/lora/FLEX.2-preview.sh)|[code](./model_training/validate_lora/FLEX.2-preview.py)|
|
||||
|[Nexus-Gen](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](./model_inference/Nexus-Gen-Editing.py)|[code](./model_inference_low_vram/Nexus-Gen-Editing.py)|[code](./model_training/full/Nexus-Gen.sh)|[code](./model_training/validate_full/Nexus-Gen.py)|[code](./model_training/lora/Nexus-Gen.sh)|[code](./model_training/validate_lora/Nexus-Gen.py)|
|
||||
|
||||
## 模型推理
|
||||
|
||||
@@ -59,6 +68,9 @@ image.save("image.jpg")
|
||||
模型通过 `from_pretrained` 加载:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
@@ -95,7 +107,7 @@ ModelConfig(path=[
|
||||
])
|
||||
```
|
||||
|
||||
`from_pretrained` 还提供了额外的参数用于控制模型加载时的行为:
|
||||
`ModelConfig` 还提供了额外的参数用于控制模型加载时的行为:
|
||||
|
||||
* `local_model_path`: 用于保存下载模型的路径,默认值为 `"./models"`。
|
||||
* `skip_download`: 是否跳过下载,默认值为 `False`。当您的网络无法访问[魔搭社区](https://modelscope.cn/)时,请手动下载必要的文件,并将其设置为 `True`。
|
||||
@@ -123,9 +135,41 @@ pipe = FluxImagePipeline.from_pretrained(
|
||||
pipe.enable_vram_management()
|
||||
```
|
||||
|
||||
`enable_vram_management` 函数提供了以下参数,用于控制显存使用情况:
|
||||
FP8 量化功能也是支持的:
|
||||
|
||||
* `vram_limit`: 显存占用量(GB),默认占用设备上的剩余显存。注意这不是一个绝对限制,当设置的显存不足以支持模型进行推理,但实际可用显存足够时,将会以最小化显存占用的形式进行推理。将其设置为0时,将会实现理论最小显存占用。
|
||||
```python
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
```
|
||||
|
||||
FP8 量化和 offload 可同时开启:
|
||||
|
||||
```python
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
```
|
||||
|
||||
开启显存管理后,框架会自动根据设备上的剩余显存确定显存管理策略。对于大多数 FLUX 系列模型,最低 8GB 显存即可进行推理。`enable_vram_management` 函数提供了以下参数,用于手动控制显存管理策略:
|
||||
|
||||
* `vram_limit`: 显存占用量限制(GB),默认占用设备上的剩余显存。注意这不是一个绝对限制,当设置的显存不足以支持模型进行推理,但实际可用显存足够时,将会以最小化显存占用的形式进行推理。将其设置为0时,将会实现理论最小显存占用。
|
||||
* `vram_buffer`: 显存缓冲区大小(GB),默认为 0.5GB。由于部分较大的神经网络层在 onload 阶段会不可控地占用更多显存,因此一个显存缓冲区是必要的,理论上的最优值为模型中最大的层所占的显存。
|
||||
* `num_persistent_param_in_dit`: DiT 模型中常驻显存的参数数量(个),默认为无限制。我们将会在未来删除这个参数,请不要依赖这个参数。
|
||||
|
||||
@@ -163,6 +207,25 @@ Pipeline 在推理阶段能够接收以下输入参数:
|
||||
* `controlnet_inputs`: ControlNet 模型的输入。
|
||||
* `ipadapter_images`: IP-Adapter 模型的输入图像。
|
||||
* `ipadapter_scale`: IP-Adapter 模型的控制强度。
|
||||
* `eligen_entity_prompts`: EliGen 模型的图像局部提示词。
|
||||
* `eligen_entity_masks`: EliGen 模型的局部提示词控制区域,与 `eligen_entity_prompts` 一一对应。
|
||||
* `eligen_enable_on_negative`: 是否在负向提示词一侧启用 EliGen,仅在 `cfg_scale > 1` 时生效。
|
||||
* `eligen_enable_inpaint`: 是否启用 EliGen 局部重绘。
|
||||
* `infinityou_id_image`: InfiniteYou 模型的人脸图像。
|
||||
* `infinityou_guidance`: InfiniteYou 模型的控制强度。
|
||||
* `flex_inpaint_image`: FLEX 模型用于局部重绘的图像。
|
||||
* `flex_inpaint_mask`: FLEX 模型用于局部重绘的区域。
|
||||
* `flex_control_image`: FLEX 模型用于结构控制的图像。
|
||||
* `flex_control_strength`: FLEX 模型用于结构控制的强度。
|
||||
* `flex_control_stop`: FLEX 模型结构控制的结束点,1表示全程启用,0.5表示在前半段启用,0表示不启用。
|
||||
* `step1x_reference_image`: Step1x-Edit 模型用于图像编辑的输入图像。
|
||||
* `lora_encoder_inputs`: LoRA 编码器的输入,格式为 ModelConfig 或本地路径。
|
||||
* `lora_encoder_scale`: LoRA 编码器的激活强度,默认值为1,数值越小,LoRA 激活越弱。
|
||||
* `tea_cache_l1_thresh`: TeaCache 的阈值,数值越大,速度越快,画面质量越差。请注意,开启 TeaCache 后推理速度并非均匀,因此进度条上显示的剩余时间将会变得不准确。
|
||||
* `tiled`: 是否启用 VAE 分块推理,默认为 `False`。设置为 `True` 时可显著减少 VAE 编解码阶段的显存占用,会产生少许误差,以及少量推理时间延长。
|
||||
* `tile_size`: VAE 编解码阶段的分块大小,默认为 128,仅在 `tiled=True` 时生效。
|
||||
* `tile_stride`: VAE 编解码阶段的分块步长,默认为 64,仅在 `tiled=True` 时生效,需保证其数值小于或等于 `tile_size`。
|
||||
* `progress_bar_cmd`: 进度条,默认为 `tqdm.tqdm`。可通过设置为 `lambda x:x` 来屏蔽进度条。
|
||||
|
||||
</details>
|
||||
|
||||
@@ -190,7 +253,7 @@ FLUX 系列模型训练通过统一的 [`./model_training/train.py`](./model_tra
|
||||
* `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 black-forest-labs/FLUX.1-dev:flux1-dev.safetensors。用逗号分隔。
|
||||
* 训练
|
||||
* `--learning_rate`: 学习率。
|
||||
* `--num_epochs`: 轮数(Epoch)数量。
|
||||
* `--num_epochs`: 轮数(Epoch)。
|
||||
* `--output_path`: 保存路径。
|
||||
* `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。
|
||||
* 可训练模块
|
||||
@@ -205,7 +268,7 @@ FLUX 系列模型训练通过统一的 [`./model_training/train.py`](./model_tra
|
||||
* `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。
|
||||
* `--gradient_accumulation_steps`: 梯度累积步数。
|
||||
* 其他
|
||||
* `--align_to_opensource_format`: 是否将 FLUX DiT LoRA 的格式与开源版本对齐,仅对 FLUX.1-dev 和 FLUX.1-Kontext-dev 的 LoRA 训练生效。
|
||||
* `--align_to_opensource_format`: 是否将 FLUX DiT LoRA 的格式与开源版本对齐,仅对 LoRA 训练生效。
|
||||
|
||||
|
||||
此外,训练框架基于 [`accelerate`](https://huggingface.co/docs/accelerate/index) 构建,在开始训练前运行 `accelerate config` 可配置 GPU 的相关参数。对于部分模型训练(例如模型的全量训练)脚本,我们提供了建议的 `accelerate` 配置文件,可在对应的训练脚本中查看。
|
||||
|
||||
19
examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py
Normal file
19
examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
ModelConfig(model_id="DiffSynth-Studio/AttriCtrl-FLUX.1-Dev", origin_file_pattern="models/brightness.safetensors")
|
||||
],
|
||||
)
|
||||
|
||||
for i in [0.1, 0.3, 0.5, 0.7, 0.9]:
|
||||
image = pipe(prompt="a cat on the beach", seed=2, value_controller_inputs=[i])
|
||||
image.save(f"value_control_{i}.jpg")
|
||||
40
examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py
Normal file
40
examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors"),
|
||||
],
|
||||
)
|
||||
pipe.enable_lora_magic()
|
||||
|
||||
lora = ModelConfig(model_id="VoidOc/flux_animal_forest1", origin_file_pattern="20.safetensors")
|
||||
pipe.load_lora(pipe.dit, lora, hotload=True) # Use `pipe.clear_lora()` to drop the loaded LoRA.
|
||||
|
||||
# Empty prompt can automatically activate LoRA capabilities.
|
||||
image = pipe(prompt="", seed=0, lora_encoder_inputs=lora)
|
||||
image.save("image_1.jpg")
|
||||
|
||||
image = pipe(prompt="", seed=0)
|
||||
image.save("image_1_origin.jpg")
|
||||
|
||||
# Prompt without trigger words can also activate LoRA capabilities.
|
||||
image = pipe(prompt="a car", seed=0, lora_encoder_inputs=lora)
|
||||
image.save("image_2.jpg")
|
||||
|
||||
image = pipe(prompt="a car", seed=0,)
|
||||
image.save("image_2_origin.jpg")
|
||||
|
||||
# Adjust the activation intensity through the scale parameter.
|
||||
image = pipe(prompt="a cat", seed=0, lora_encoder_inputs=lora, lora_encoder_scale=1.0)
|
||||
image.save("image_3.jpg")
|
||||
|
||||
image = pipe(prompt="a cat", seed=0, lora_encoder_inputs=lora, lora_encoder_scale=0.5)
|
||||
image.save("image_3_scale.jpg")
|
||||
29
examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py
Normal file
29
examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev", origin_file_pattern="model.safetensors"),
|
||||
],
|
||||
)
|
||||
pipe.enable_lora_magic()
|
||||
|
||||
pipe.load_lora(
|
||||
pipe.dit,
|
||||
ModelConfig(model_id="cancel13/cxsk", origin_file_pattern="30.safetensors"),
|
||||
hotload=True,
|
||||
)
|
||||
pipe.load_lora(
|
||||
pipe.dit,
|
||||
ModelConfig(model_id="DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1", origin_file_pattern="merged_lora.safetensors"),
|
||||
hotload=True,
|
||||
)
|
||||
image = pipe(prompt="a cat", seed=0)
|
||||
image.save("image_fused.jpg")
|
||||
37
examples/flux/model_inference/Nexus-Gen-Editing.py
Normal file
37
examples/flux/model_inference/Nexus-Gen-Editing.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import importlib
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
if importlib.util.find_spec("transformers") is None:
|
||||
raise ImportError("You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.")
|
||||
else:
|
||||
import transformers
|
||||
assert transformers.__version__ == "4.49.0", "Nexus-GenV2 requires transformers==4.49.0, please install it with `pip install transformers==4.49.0`."
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors"),
|
||||
ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
nexus_gen_processor_config=ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor/"),
|
||||
)
|
||||
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/nexusgen/cat.jpg")
|
||||
ref_image = Image.open("data/examples/nexusgen/cat.jpg").convert("RGB")
|
||||
prompt = "Add a crown."
|
||||
image = pipe(
|
||||
prompt=prompt, negative_prompt="",
|
||||
seed=42, cfg_scale=2.0, num_inference_steps=50,
|
||||
nexus_gen_reference_image=ref_image,
|
||||
height=512, width=512,
|
||||
)
|
||||
image.save("cat_crown.jpg")
|
||||
32
examples/flux/model_inference/Nexus-Gen-Generation.py
Normal file
32
examples/flux/model_inference/Nexus-Gen-Generation.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import importlib
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
|
||||
if importlib.util.find_spec("transformers") is None:
|
||||
raise ImportError("You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.")
|
||||
else:
|
||||
import transformers
|
||||
assert transformers.__version__ == "4.49.0", "Nexus-GenV2 requires transformers==4.49.0, please install it with `pip install transformers==4.49.0`."
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors"),
|
||||
ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
nexus_gen_processor_config=ModelConfig("DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor"),
|
||||
)
|
||||
|
||||
prompt = "一只可爱的猫咪"
|
||||
image = pipe(
|
||||
prompt=prompt, negative_prompt="",
|
||||
seed=0, cfg_scale=3, num_inference_steps=50,
|
||||
height=1024, width=1024,
|
||||
)
|
||||
image.save("cat.jpg")
|
||||
51
examples/flux/model_inference_low_vram/FLEX.2-preview.py
Normal file
51
examples/flux/model_inference_low_vram/FLEX.2-preview.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from diffsynth.controlnets.processors import Annotator
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
image = pipe(
|
||||
prompt="portrait of a beautiful Asian girl, long hair, red t-shirt, sunshine, beach",
|
||||
num_inference_steps=50, embedded_guidance=3.5,
|
||||
seed=0
|
||||
)
|
||||
image.save(f"image_1.jpg")
|
||||
|
||||
mask = np.zeros((1024, 1024, 3), dtype=np.uint8)
|
||||
mask[200:400, 400:700] = 255
|
||||
mask = Image.fromarray(mask)
|
||||
mask.save(f"image_mask.jpg")
|
||||
|
||||
inpaint_image = image
|
||||
|
||||
image = pipe(
|
||||
prompt="portrait of a beautiful Asian girl with sunglasses, long hair, red t-shirt, sunshine, beach",
|
||||
num_inference_steps=50, embedded_guidance=3.5,
|
||||
flex_inpaint_image=inpaint_image, flex_inpaint_mask=mask,
|
||||
seed=4
|
||||
)
|
||||
image.save(f"image_2_new.jpg")
|
||||
|
||||
control_image = Annotator("canny")(image)
|
||||
control_image.save("image_control.jpg")
|
||||
|
||||
image = pipe(
|
||||
prompt="portrait of a beautiful Asian girl with sunglasses, long hair, yellow t-shirt, sunshine, beach",
|
||||
num_inference_steps=50, embedded_guidance=3.5,
|
||||
flex_control_image=control_image,
|
||||
seed=4
|
||||
)
|
||||
image.save(f"image_3_new.jpg")
|
||||
55
examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py
Normal file
55
examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from PIL import Image
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-Kontext-dev", origin_file_pattern="flux1-kontext-dev.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
image_1 = pipe(
|
||||
prompt="a beautiful Asian long-haired female college student.",
|
||||
embedded_guidance=2.5,
|
||||
seed=1,
|
||||
)
|
||||
image_1.save("image_1.jpg")
|
||||
|
||||
image_2 = pipe(
|
||||
prompt="transform the style to anime style.",
|
||||
kontext_images=image_1,
|
||||
embedded_guidance=2.5,
|
||||
seed=2,
|
||||
)
|
||||
image_2.save("image_2.jpg")
|
||||
|
||||
image_3 = pipe(
|
||||
prompt="let her smile.",
|
||||
kontext_images=image_1,
|
||||
embedded_guidance=2.5,
|
||||
seed=3,
|
||||
)
|
||||
image_3.save("image_3.jpg")
|
||||
|
||||
image_4 = pipe(
|
||||
prompt="let the girl play basketball.",
|
||||
kontext_images=image_1,
|
||||
embedded_guidance=2.5,
|
||||
seed=4,
|
||||
)
|
||||
image_4.save("image_4.jpg")
|
||||
|
||||
image_5 = pipe(
|
||||
prompt="move the girl to a park, let her sit on a chair.",
|
||||
kontext_images=image_1,
|
||||
embedded_guidance=2.5,
|
||||
seed=5,
|
||||
)
|
||||
image_5.save("image_5.jpg")
|
||||
@@ -0,0 +1,20 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="DiffSynth-Studio/AttriCtrl-FLUX.1-Dev", origin_file_pattern="models/brightness.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn)
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
for i in [0.1, 0.3, 0.5, 0.7, 0.9]:
|
||||
image = pipe(prompt="a cat on the beach", seed=2, value_controller_inputs=[i])
|
||||
image.save(f"value_control_{i}.jpg")
|
||||
@@ -0,0 +1,38 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", origin_file_pattern="diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
image_1 = pipe(
|
||||
prompt="a cat sitting on a chair",
|
||||
height=1024, width=1024,
|
||||
seed=8, rand_device="cuda",
|
||||
)
|
||||
image_1.save("image_1.jpg")
|
||||
|
||||
mask = np.zeros((1024, 1024, 3), dtype=np.uint8)
|
||||
mask[100:350, 350: -300] = 255
|
||||
mask = Image.fromarray(mask)
|
||||
mask.save("mask.jpg")
|
||||
|
||||
image_2 = pipe(
|
||||
prompt="a cat sitting on a chair, wearing sunglasses",
|
||||
controlnet_inputs=[ControlNetInput(image=image_1, inpaint_mask=mask, scale=0.9)],
|
||||
height=1024, width=1024,
|
||||
seed=9, rand_device="cuda",
|
||||
)
|
||||
image_2.save("image_2.jpg")
|
||||
@@ -0,0 +1,41 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
|
||||
from diffsynth.controlnets.processors import Annotator
|
||||
from diffsynth import download_models
|
||||
|
||||
|
||||
|
||||
download_models(["Annotators:Depth"])
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="InstantX/FLUX.1-dev-Controlnet-Union-alpha", origin_file_pattern="diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
image_1 = pipe(
|
||||
prompt="a beautiful Asian girl, full body, red dress, summer",
|
||||
height=1024, width=1024,
|
||||
seed=6, rand_device="cuda",
|
||||
)
|
||||
image_1.save("image_1.jpg")
|
||||
|
||||
image_canny = Annotator("canny")(image_1)
|
||||
image_depth = Annotator("depth")(image_1)
|
||||
|
||||
image_2 = pipe(
|
||||
prompt="a beautiful Asian girl, full body, red dress, winter",
|
||||
controlnet_inputs=[
|
||||
ControlNetInput(image=image_canny, scale=0.3, processor_id="canny"),
|
||||
ControlNetInput(image=image_depth, scale=0.3, processor_id="depth"),
|
||||
],
|
||||
height=1024, width=1024,
|
||||
seed=7, rand_device="cuda",
|
||||
)
|
||||
image_2.save("image_2.jpg")
|
||||
@@ -0,0 +1,34 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="jasperai/Flux.1-dev-Controlnet-Upscaler", origin_file_pattern="diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
image_1 = pipe(
|
||||
prompt="a photo of a cat, highly detailed",
|
||||
height=768, width=768,
|
||||
seed=0, rand_device="cuda",
|
||||
)
|
||||
image_1.save("image_1.jpg")
|
||||
|
||||
image_1 = image_1.resize((2048, 2048))
|
||||
image_2 = pipe(
|
||||
prompt="a photo of a cat, highly detailed",
|
||||
controlnet_inputs=[ControlNetInput(image=image_1, scale=0.7)],
|
||||
input_image=image_1,
|
||||
denoising_strength=0.99,
|
||||
height=2048, width=2048, tiled=True,
|
||||
seed=1, rand_device="cuda",
|
||||
)
|
||||
image_2.save("image_2.jpg")
|
||||
148
examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py
Normal file
148
examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import random
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from diffsynth import download_customized_models
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False):
|
||||
# Create a blank image for overlays
|
||||
overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
|
||||
|
||||
colors = [
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
]
|
||||
# Generate random colors for each mask
|
||||
if use_random_colors:
|
||||
colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
|
||||
|
||||
# Font settings
|
||||
try:
|
||||
font = ImageFont.truetype("arial", font_size) # Adjust as needed
|
||||
except IOError:
|
||||
font = ImageFont.load_default(font_size)
|
||||
|
||||
# Overlay each mask onto the overlay image
|
||||
for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
|
||||
# Convert mask to RGBA mode
|
||||
mask_rgba = mask.convert('RGBA')
|
||||
mask_data = mask_rgba.getdata()
|
||||
new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
|
||||
mask_rgba.putdata(new_data)
|
||||
|
||||
# Draw the mask prompt text on the mask
|
||||
draw = ImageDraw.Draw(mask_rgba)
|
||||
mask_bbox = mask.getbbox() # Get the bounding box of the mask
|
||||
text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position
|
||||
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
|
||||
|
||||
# Alpha composite the overlay with this mask
|
||||
overlay = Image.alpha_composite(overlay, mask_rgba)
|
||||
|
||||
# Composite the overlay onto the original image
|
||||
result = Image.alpha_composite(image.convert('RGBA'), overlay)
|
||||
|
||||
# Save or display the resulting image
|
||||
result.save(output_path)
|
||||
|
||||
return result
|
||||
|
||||
def example(pipe, seeds, example_id, global_prompt, entity_prompts):
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/entity_control/example_{example_id}/*.png")
|
||||
masks = [Image.open(f"./data/examples/eligen/entity_control/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
|
||||
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
|
||||
for seed in seeds:
|
||||
# generate image
|
||||
image = pipe(
|
||||
prompt=global_prompt,
|
||||
cfg_scale=3.0,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=50,
|
||||
embedded_guidance=3.5,
|
||||
seed=seed,
|
||||
height=1024,
|
||||
width=1024,
|
||||
eligen_entity_prompts=entity_prompts,
|
||||
eligen_entity_masks=masks,
|
||||
)
|
||||
image.save(f"eligen_example_{example_id}_{seed}.png")
|
||||
visualize_masks(image, masks, entity_prompts, f"eligen_example_{example_id}_mask_{seed}.png")
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
download_from_modelscope = True
|
||||
if download_from_modelscope:
|
||||
model_id = "DiffSynth-Studio/Eligen"
|
||||
downloading_priority = ["ModelScope"]
|
||||
else:
|
||||
model_id = "modelscope/EliGen"
|
||||
downloading_priority = ["HuggingFace"]
|
||||
EliGen_path = download_customized_models(
|
||||
model_id=model_id,
|
||||
origin_file_path="model_bf16.safetensors",
|
||||
local_dir="models/lora/entity_control",
|
||||
downloading_priority=downloading_priority)[0]
|
||||
pipe.load_lora(pipe.dit, EliGen_path, alpha=1)
|
||||
|
||||
# example 1
|
||||
global_prompt = "A breathtaking beauty of Raja Ampat by the late-night moonlight , one beautiful woman from behind wearing a pale blue long dress with soft glow, sitting at the top of a cliff looking towards the beach,pastell light colors, a group of small distant birds flying in far sky, a boat sailing on the sea, best quality, realistic, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, maximalist style, photorealistic, concept art, sharp focus, harmony, serenity, tranquility, soft pastell colors,ambient occlusion, cozy ambient lighting, masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning, view from above\n"
|
||||
entity_prompts = ["cliff", "sea", "moon", "sailing boat", "a seated beautiful woman", "pale blue long dress with soft glow"]
|
||||
example(pipe, [0], 1, global_prompt, entity_prompts)
|
||||
|
||||
# example 2
|
||||
global_prompt = "samurai girl wearing a kimono, she's holding a sword glowing with red flame, her long hair is flowing in the wind, she is looking at a small bird perched on the back of her hand. ultra realist style. maximum image detail. maximum realistic render."
|
||||
entity_prompts = ["flowing hair", "sword glowing with red flame", "A cute bird", "blue belt"]
|
||||
example(pipe, [0], 2, global_prompt, entity_prompts)
|
||||
|
||||
# example 3
|
||||
global_prompt = "Image of a neverending staircase up to a mysterious palace in the sky, The ancient palace stood majestically atop a mist-shrouded mountain, sunrise, two traditional monk walk in the stair looking at the sunrise, fog,see-through, best quality, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, photorealistic, concept art, harmony, serenity, tranquility, ambient occlusion, halation, cozy ambient lighting, dynamic lighting,masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning,"
|
||||
entity_prompts = ["ancient palace", "stone staircase with railings", "a traditional monk", "a traditional monk"]
|
||||
example(pipe, [27], 3, global_prompt, entity_prompts)
|
||||
|
||||
# example 4
|
||||
global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'"
|
||||
entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"]
|
||||
example(pipe, [21], 4, global_prompt, entity_prompts)
|
||||
|
||||
# example 5
|
||||
global_prompt = "A captivating, dramatic scene in a painting that exudes mystery and foreboding. A white sky, swirling blue clouds, and a crescent yellow moon illuminate a solitary woman standing near the water's edge. Her long dress flows in the wind, silhouetted against the eerie glow. The water mirrors the fiery sky and moonlight, amplifying the uneasy atmosphere."
|
||||
entity_prompts = ["crescent yellow moon", "a solitary woman", "water", "swirling blue clouds"]
|
||||
example(pipe, [0], 5, global_prompt, entity_prompts)
|
||||
|
||||
# example 6
|
||||
global_prompt = "Snow White and the 6 Dwarfs."
|
||||
entity_prompts = ["Dwarf 1", "Dwarf 2", "Dwarf 3", "Snow White", "Dwarf 4", "Dwarf 5", "Dwarf 6"]
|
||||
example(pipe, [8], 6, global_prompt, entity_prompts)
|
||||
|
||||
# example 7, same prompt with different seeds
|
||||
seeds = range(5, 9)
|
||||
global_prompt = "A beautiful woman wearing white dress, holding a mirror, with a warm light background;"
|
||||
entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"]
|
||||
example(pipe, seeds, 7, global_prompt, entity_prompts)
|
||||
@@ -0,0 +1,25 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="google/siglip-so400m-patch14-384", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
origin_prompt = "a rabbit in a garden, colorful flowers"
|
||||
image = pipe(prompt=origin_prompt, height=1280, width=960, seed=42)
|
||||
image.save("style image.jpg")
|
||||
|
||||
image = pipe(prompt="A piggy", height=1280, width=960, seed=42,
|
||||
ipadapter_images=[image], ipadapter_scale=0.7)
|
||||
image.save("A piggy.jpg")
|
||||
@@ -0,0 +1,60 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
|
||||
from modelscope import dataset_snapshot_download
|
||||
from modelscope import snapshot_download
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
snapshot_download(
|
||||
"ByteDance/InfiniteYou",
|
||||
allow_file_pattern="supports/insightface/models/antelopev2/*",
|
||||
local_dir="models/ByteDance/InfiniteYou",
|
||||
)
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/image_proj_model.bin", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=f"data/examples/infiniteyou/*",
|
||||
)
|
||||
|
||||
height, width = 1024, 1024
|
||||
controlnet_image = Image.fromarray(np.zeros([height, width, 3]).astype(np.uint8))
|
||||
controlnet_inputs = [ControlNetInput(image=controlnet_image, scale=1.0, processor_id="None")]
|
||||
|
||||
prompt = "A man, portrait, cinematic"
|
||||
id_image = "data/examples/infiniteyou/man.jpg"
|
||||
id_image = Image.open(id_image).convert('RGB')
|
||||
image = pipe(
|
||||
prompt=prompt, seed=1,
|
||||
infinityou_id_image=id_image, infinityou_guidance=1.0,
|
||||
controlnet_inputs=controlnet_inputs,
|
||||
num_inference_steps=50, embedded_guidance=3.5,
|
||||
height=height, width=width,
|
||||
)
|
||||
image.save("man.jpg")
|
||||
|
||||
prompt = "A woman, portrait, cinematic"
|
||||
id_image = "data/examples/infiniteyou/woman.jpg"
|
||||
id_image = Image.open(id_image).convert('RGB')
|
||||
image = pipe(
|
||||
prompt=prompt, seed=1,
|
||||
infinityou_id_image=id_image, infinityou_guidance=1.0,
|
||||
controlnet_inputs=controlnet_inputs,
|
||||
num_inference_steps=50, embedded_guidance=3.5,
|
||||
height=height, width=width,
|
||||
)
|
||||
image.save("woman.jpg")
|
||||
@@ -0,0 +1,41 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
pipe.enable_lora_magic()
|
||||
|
||||
lora = ModelConfig(model_id="VoidOc/flux_animal_forest1", origin_file_pattern="20.safetensors")
|
||||
pipe.load_lora(pipe.dit, lora, hotload=True) # Use `pipe.clear_lora()` to drop the loaded LoRA.
|
||||
|
||||
# Empty prompt can automatically activate LoRA capabilities.
|
||||
image = pipe(prompt="", seed=0, lora_encoder_inputs=lora)
|
||||
image.save("image_1.jpg")
|
||||
|
||||
image = pipe(prompt="", seed=0)
|
||||
image.save("image_1_origin.jpg")
|
||||
|
||||
# Prompt without trigger words can also activate LoRA capabilities.
|
||||
image = pipe(prompt="a car", seed=0, lora_encoder_inputs=lora)
|
||||
image.save("image_2.jpg")
|
||||
|
||||
image = pipe(prompt="a car", seed=0,)
|
||||
image.save("image_2_origin.jpg")
|
||||
|
||||
# Adjust the activation intensity through the scale parameter.
|
||||
image = pipe(prompt="a cat", seed=0, lora_encoder_inputs=lora, lora_encoder_scale=1.0)
|
||||
image.save("image_3.jpg")
|
||||
|
||||
image = pipe(prompt="a cat", seed=0, lora_encoder_inputs=lora, lora_encoder_scale=0.5)
|
||||
image.save("image_3_scale.jpg")
|
||||
@@ -6,11 +6,11 @@ pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
ModelConfig(model_id="DiffSynth-Studio/FLUX.1-dev-LoRAFusion", origin_file_pattern="model.safetensors")
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="DiffSynth-Studio/FLUX.1-dev-LoRAFusion", origin_file_pattern="model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn)
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
27
examples/flux/model_inference_low_vram/FLUX.1-dev.py
Normal file
27
examples/flux/model_inference_low_vram/FLUX.1-dev.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
|
||||
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
|
||||
|
||||
image = pipe(prompt=prompt, seed=0)
|
||||
image.save("flux.jpg")
|
||||
|
||||
image = pipe(
|
||||
prompt=prompt, negative_prompt=negative_prompt,
|
||||
seed=0, cfg_scale=2, num_inference_steps=50,
|
||||
)
|
||||
image.save("flux_cfg.jpg")
|
||||
38
examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py
Normal file
38
examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import importlib
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
if importlib.util.find_spec("transformers") is None:
|
||||
raise ImportError("You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.")
|
||||
else:
|
||||
import transformers
|
||||
assert transformers.__version__ == "4.49.0", "Nexus-GenV2 requires transformers==4.49.0, please install it with `pip install transformers==4.49.0`."
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin", offload_device="cpu"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu"),
|
||||
],
|
||||
nexus_gen_processor_config=ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor/"),
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/nexusgen/cat.jpg")
|
||||
ref_image = Image.open("data/examples/nexusgen/cat.jpg").convert("RGB")
|
||||
prompt = "Add a crown."
|
||||
image = pipe(
|
||||
prompt=prompt, negative_prompt="",
|
||||
seed=42, cfg_scale=2.0, num_inference_steps=50,
|
||||
nexus_gen_reference_image=ref_image,
|
||||
height=512, width=512,
|
||||
)
|
||||
image.save("cat_crown.jpg")
|
||||
@@ -0,0 +1,33 @@
|
||||
import importlib
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
|
||||
if importlib.util.find_spec("transformers") is None:
|
||||
raise ImportError("You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.")
|
||||
else:
|
||||
import transformers
|
||||
assert transformers.__version__ == "4.49.0", "Nexus-GenV2 requires transformers==4.49.0, please install it with `pip install transformers==4.49.0`."
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin", offload_device="cpu"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu"),
|
||||
],
|
||||
nexus_gen_processor_config=ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor/"),
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
prompt = "一只可爱的猫咪"
|
||||
image = pipe(
|
||||
prompt=prompt, negative_prompt="",
|
||||
seed=0, cfg_scale=3, num_inference_steps=50,
|
||||
height=1024, width=1024,
|
||||
)
|
||||
image.save("cat.jpg")
|
||||
33
examples/flux/model_inference_low_vram/Step1X-Edit.py
Normal file
33
examples/flux/model_inference_low_vram/Step1X-Edit.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen2.5-VL-7B-Instruct", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="vae.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
image = Image.fromarray(np.zeros((1248, 832, 3), dtype=np.uint8) + 255)
|
||||
image = pipe(
|
||||
prompt="draw red flowers in Chinese ink painting style",
|
||||
step1x_reference_image=image,
|
||||
width=832, height=1248, cfg_scale=6,
|
||||
seed=1, rand_device='cuda'
|
||||
)
|
||||
image.save("image_1.jpg")
|
||||
|
||||
image = pipe(
|
||||
prompt="add more flowers in Chinese ink painting style",
|
||||
step1x_reference_image=image,
|
||||
width=832, height=1248, cfg_scale=6,
|
||||
seed=2, rand_device='cuda'
|
||||
)
|
||||
image.save("image_2.jpg")
|
||||
12
examples/flux/model_training/full/FLEX.2-preview.sh
Normal file
12
examples/flux/model_training/full/FLEX.2-preview.sh
Normal file
@@ -0,0 +1,12 @@
|
||||
accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 200 \
|
||||
--model_id_with_origin_paths "ostris/Flex.2-preview:Flex.2-preview.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 1 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLEX.2-preview_full" \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing
|
||||
14
examples/flux/model_training/full/FLUX.1-dev-AttriCtrl.sh
Normal file
14
examples/flux/model_training/full/FLUX.1-dev-AttriCtrl.sh
Normal file
@@ -0,0 +1,14 @@
|
||||
accelerate launch examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_attrictrl.csv \
|
||||
--data_file_keys "image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors,DiffSynth-Studio/AttriCtrl-FLUX.1-Dev:models/brightness.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 1 \
|
||||
--remove_prefix_in_ckpt "pipe.value_controller.encoders.0." \
|
||||
--output_path "./models/train/FLUX.1-dev-AttriCtrl_full" \
|
||||
--trainable_models "value_controller" \
|
||||
--extra_inputs "value_controller_inputs" \
|
||||
--use_gradient_checkpointing
|
||||
@@ -0,0 +1,14 @@
|
||||
accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_controlnet_inpaint.csv \
|
||||
--data_file_keys "image,controlnet_image,controlnet_inpaint_mask" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 400 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors,alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta:diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 1 \
|
||||
--remove_prefix_in_ckpt "pipe.controlnet.models.0." \
|
||||
--output_path "./models/train/FLUX.1-dev-Controlnet-Inpainting-Beta_full" \
|
||||
--trainable_models "controlnet" \
|
||||
--extra_inputs "controlnet_image,controlnet_inpaint_mask" \
|
||||
--use_gradient_checkpointing
|
||||
@@ -0,0 +1,14 @@
|
||||
accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_controlnet_canny.csv \
|
||||
--data_file_keys "image,controlnet_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 400 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors,InstantX/FLUX.1-dev-Controlnet-Union-alpha:diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 1 \
|
||||
--remove_prefix_in_ckpt "pipe.controlnet.models.0." \
|
||||
--output_path "./models/train/FLUX.1-dev-Controlnet-Union-alpha_full" \
|
||||
--trainable_models "controlnet" \
|
||||
--extra_inputs "controlnet_image,controlnet_processor_id" \
|
||||
--use_gradient_checkpointing
|
||||
@@ -0,0 +1,14 @@
|
||||
accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_controlnet_upscale.csv \
|
||||
--data_file_keys "image,controlnet_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 400 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors,jasperai/Flux.1-dev-Controlnet-Upscaler:diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 1 \
|
||||
--remove_prefix_in_ckpt "pipe.controlnet.models.0." \
|
||||
--output_path "./models/train/FLUX.1-dev-Controlnet-Upscaler_full" \
|
||||
--trainable_models "controlnet" \
|
||||
--extra_inputs "controlnet_image" \
|
||||
--use_gradient_checkpointing
|
||||
14
examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh
Normal file
14
examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh
Normal file
@@ -0,0 +1,14 @@
|
||||
accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_infiniteyou.csv \
|
||||
--data_file_keys "image,controlnet_image,infinityou_id_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 400 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors,ByteDance/InfiniteYou:infu_flux_v1.0/aes_stage2/image_proj_model.bin,ByteDance/InfiniteYou:infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 1 \
|
||||
--remove_prefix_in_ckpt "pipe." \
|
||||
--output_path "./models/train/FLUX.1-dev-InfiniteYou_full" \
|
||||
--trainable_models "controlnet,image_proj_model" \
|
||||
--extra_inputs "controlnet_image,infinityou_id_image,infinityou_guidance" \
|
||||
--use_gradient_checkpointing
|
||||
14
examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh
Normal file
14
examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh
Normal file
@@ -0,0 +1,14 @@
|
||||
accelerate launch examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_lora_encoder.csv \
|
||||
--data_file_keys "image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors,DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev:model.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 1 \
|
||||
--remove_prefix_in_ckpt "pipe.lora_encoder." \
|
||||
--output_path "./models/train/FLUX.1-dev-LoRA-Encoder_full" \
|
||||
--trainable_models "lora_encoder" \
|
||||
--extra_inputs "lora_encoder_inputs" \
|
||||
--use_gradient_checkpointing
|
||||
14
examples/flux/model_training/full/Nexus-Gen.sh
Normal file
14
examples/flux/model_training/full/Nexus-Gen.sh
Normal file
@@ -0,0 +1,14 @@
|
||||
accelerate launch --config_file examples/flux/model_training/full/accelerate_config_zero2offload.yaml examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_nexusgen_edit.csv \
|
||||
--data_file_keys "image,nexus_gen_reference_image" \
|
||||
--max_pixels 262144 \
|
||||
--dataset_repeat 400 \
|
||||
--model_id_with_origin_paths "DiffSynth-Studio/Nexus-GenV2:model*.safetensors,DiffSynth-Studio/Nexus-GenV2:edit_decoder.bin,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 1 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.1-NexusGen-Edit_full" \
|
||||
--trainable_models "dit" \
|
||||
--extra_inputs "nexus_gen_reference_image" \
|
||||
--use_gradient_checkpointing_offload
|
||||
14
examples/flux/model_training/full/Step1X-Edit.sh
Normal file
14
examples/flux/model_training/full/Step1X-Edit.sh
Normal file
@@ -0,0 +1,14 @@
|
||||
accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_step1x.csv \
|
||||
--data_file_keys "image,step1x_reference_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 400 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen2.5-VL-7B-Instruct:,stepfun-ai/Step1X-Edit:step1x-edit-i1258.safetensors,stepfun-ai/Step1X-Edit:vae.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 1 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Step1X-Edit_full" \
|
||||
--trainable_models "dit" \
|
||||
--extra_inputs "step1x_reference_image" \
|
||||
--use_gradient_checkpointing_offload
|
||||
@@ -0,0 +1,22 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
deepspeed_config:
|
||||
gradient_accumulation_steps: 1
|
||||
offload_optimizer_device: 'cpu'
|
||||
offload_param_device: 'cpu'
|
||||
zero3_init_flag: false
|
||||
zero_stage: 2
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
enable_cpu_affinity: false
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
15
examples/flux/model_training/lora/FLEX.2-preview.sh
Normal file
15
examples/flux/model_training/lora/FLEX.2-preview.sh
Normal file
@@ -0,0 +1,15 @@
|
||||
accelerate launch examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "ostris/Flex.2-preview:Flex.2-preview.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLEX.2-preview_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \
|
||||
--lora_rank 32 \
|
||||
--align_to_opensource_format \
|
||||
--use_gradient_checkpointing
|
||||
17
examples/flux/model_training/lora/FLUX.1-dev-AttriCtrl.sh
Normal file
17
examples/flux/model_training/lora/FLUX.1-dev-AttriCtrl.sh
Normal file
@@ -0,0 +1,17 @@
|
||||
accelerate launch examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_attrictrl.csv \
|
||||
--data_file_keys "image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors,DiffSynth-Studio/AttriCtrl-FLUX.1-Dev:models/brightness.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.1-dev-AttriCtrl_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \
|
||||
--lora_rank 32 \
|
||||
--extra_inputs "value_controller_inputs" \
|
||||
--align_to_opensource_format \
|
||||
--use_gradient_checkpointing
|
||||
@@ -0,0 +1,17 @@
|
||||
accelerate launch examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_controlnet_inpaint.csv \
|
||||
--data_file_keys "image,controlnet_image,controlnet_inpaint_mask" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors,alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta:diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.1-dev-Controlnet-Inpainting-Beta_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \
|
||||
--lora_rank 32 \
|
||||
--extra_inputs "controlnet_image,controlnet_inpaint_mask" \
|
||||
--align_to_opensource_format \
|
||||
--use_gradient_checkpointing
|
||||
@@ -0,0 +1,17 @@
|
||||
accelerate launch examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_controlnet_canny.csv \
|
||||
--data_file_keys "image,controlnet_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors,InstantX/FLUX.1-dev-Controlnet-Union-alpha:diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.1-dev-Controlnet-Union-alpha_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \
|
||||
--lora_rank 32 \
|
||||
--extra_inputs "controlnet_image,controlnet_processor_id" \
|
||||
--align_to_opensource_format \
|
||||
--use_gradient_checkpointing
|
||||
@@ -0,0 +1,17 @@
|
||||
accelerate launch examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_controlnet_upscale.csv \
|
||||
--data_file_keys "image,controlnet_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors,jasperai/Flux.1-dev-Controlnet-Upscaler:diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.1-dev-Controlnet-Upscaler_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \
|
||||
--lora_rank 32 \
|
||||
--extra_inputs "controlnet_image" \
|
||||
--align_to_opensource_format \
|
||||
--use_gradient_checkpointing
|
||||
17
examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh
Normal file
17
examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh
Normal file
@@ -0,0 +1,17 @@
|
||||
accelerate launch examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_eligen.json \
|
||||
--data_file_keys "image,eligen_entity_masks" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.1-dev-EliGen_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \
|
||||
--lora_rank 32 \
|
||||
--align_to_opensource_format \
|
||||
--extra_inputs "eligen_entity_masks,eligen_entity_prompts" \
|
||||
--use_gradient_checkpointing
|
||||
17
examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh
Normal file
17
examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh
Normal file
@@ -0,0 +1,17 @@
|
||||
accelerate launch examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_ipadapter.csv \
|
||||
--data_file_keys "image,ipadapter_images" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors,InstantX/FLUX.1-dev-IP-Adapter:ip-adapter.bin,google/siglip-so400m-patch14-384:" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.1-dev-IP-Adapter_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \
|
||||
--lora_rank 32 \
|
||||
--extra_inputs "ipadapter_images" \
|
||||
--align_to_opensource_format \
|
||||
--use_gradient_checkpointing
|
||||
17
examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh
Normal file
17
examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh
Normal file
@@ -0,0 +1,17 @@
|
||||
accelerate launch examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_infiniteyou.csv \
|
||||
--data_file_keys "image,controlnet_image,infinityou_id_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors,ByteDance/InfiniteYou:infu_flux_v1.0/aes_stage2/image_proj_model.bin,ByteDance/InfiniteYou:infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.1-dev-InfiniteYou_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \
|
||||
--lora_rank 32 \
|
||||
--extra_inputs "controlnet_image,infinityou_id_image,infinityou_guidance" \
|
||||
--align_to_opensource_format \
|
||||
--use_gradient_checkpointing
|
||||
17
examples/flux/model_training/lora/Nexus-Gen.sh
Normal file
17
examples/flux/model_training/lora/Nexus-Gen.sh
Normal file
@@ -0,0 +1,17 @@
|
||||
accelerate launch examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_nexusgen_edit.csv \
|
||||
--data_file_keys "image,nexus_gen_reference_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 400 \
|
||||
--model_id_with_origin_paths "DiffSynth-Studio/Nexus-GenV2:model*.safetensors,DiffSynth-Studio/Nexus-GenV2:edit_decoder.bin,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.1-NexusGen-Edit_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \
|
||||
--lora_rank 32 \
|
||||
--align_to_opensource_format \
|
||||
--extra_inputs "nexus_gen_reference_image" \
|
||||
--use_gradient_checkpointing
|
||||
17
examples/flux/model_training/lora/Step1X-Edit.sh
Normal file
17
examples/flux/model_training/lora/Step1X-Edit.sh
Normal file
@@ -0,0 +1,17 @@
|
||||
accelerate launch examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_step1x.csv \
|
||||
--data_file_keys "image,step1x_reference_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen2.5-VL-7B-Instruct:,stepfun-ai/Step1X-Edit:step1x-edit-i1258.safetensors,stepfun-ai/Step1X-Edit:vae.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Step1X-Edit_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \
|
||||
--lora_rank 32 \
|
||||
--extra_inputs "step1x_reference_image" \
|
||||
--align_to_opensource_format \
|
||||
--use_gradient_checkpointing
|
||||
@@ -1,5 +1,5 @@
|
||||
import torch, os, json
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
|
||||
from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, flux_parser
|
||||
from diffsynth.models.lora import FluxLoRAConverter
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
@@ -51,7 +51,7 @@ class FluxTrainingModule(DiffusionTrainingModule):
|
||||
def forward_preprocess(self, data):
|
||||
# CFG-sensitive parameters
|
||||
inputs_posi = {"prompt": data["prompt"]}
|
||||
inputs_nega = {}
|
||||
inputs_nega = {"negative_prompt": ""}
|
||||
|
||||
# CFG-unsensitive parameters
|
||||
inputs_shared = {
|
||||
@@ -72,8 +72,14 @@ class FluxTrainingModule(DiffusionTrainingModule):
|
||||
}
|
||||
|
||||
# Extra inputs
|
||||
controlnet_input = {}
|
||||
for extra_input in self.extra_inputs:
|
||||
inputs_shared[extra_input] = data[extra_input]
|
||||
if extra_input.startswith("controlnet_"):
|
||||
controlnet_input[extra_input.replace("controlnet_", "")] = data[extra_input]
|
||||
else:
|
||||
inputs_shared[extra_input] = data[extra_input]
|
||||
if len(controlnet_input) > 0:
|
||||
inputs_shared["controlnet_inputs"] = [ControlNetInput(**controlnet_input)]
|
||||
|
||||
# Pipeline units will automatically process the input parameters.
|
||||
for unit in self.pipe.units:
|
||||
@@ -100,6 +106,7 @@ if __name__ == "__main__":
|
||||
lora_base_model=args.lora_base_model,
|
||||
lora_target_modules=args.lora_target_modules,
|
||||
lora_rank=args.lora_rank,
|
||||
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||
extra_inputs=args.extra_inputs,
|
||||
)
|
||||
|
||||
20
examples/flux/model_training/validate_full/FLEX.2-preview.py
Normal file
20
examples/flux/model_training/validate_full/FLEX.2-preview.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from diffsynth import load_state_dict
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
state_dict = load_state_dict("models/train/FLEX.2-preview_full/epoch-0.safetensors")
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
|
||||
image = pipe(prompt="dog,white and brown dog, sitting on wall, under pink flowers", seed=0)
|
||||
image.save("image_FLEX.2-preview_full.jpg")
|
||||
@@ -0,0 +1,21 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from diffsynth import load_state_dict
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
ModelConfig(model_id="DiffSynth-Studio/AttriCtrl-FLUX.1-Dev", origin_file_pattern="models/brightness.safetensors")
|
||||
],
|
||||
)
|
||||
state_dict = load_state_dict("models/train/FLUX.1-dev-AttriCtrl_full/epoch-0.safetensors")
|
||||
pipe.value_controller.encoders[0].load_state_dict(state_dict)
|
||||
|
||||
image = pipe(prompt="a cat", seed=0, value_controller_inputs=0.1, rand_device="cuda")
|
||||
image.save("image_FLUX.1-dev-AttriCtrl_full.jpg")
|
||||
@@ -0,0 +1,31 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
|
||||
from diffsynth import load_state_dict
|
||||
from PIL import Image
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
ModelConfig(model_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", origin_file_pattern="diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
)
|
||||
state_dict = load_state_dict("models/train/FLUX.1-dev-Controlnet-Inpainting-Beta_full/epoch-0.safetensors")
|
||||
pipe.controlnet.models[0].load_state_dict(state_dict)
|
||||
|
||||
image = pipe(
|
||||
prompt="a cat sitting on a chair, wearing sunglasses",
|
||||
controlnet_inputs=[ControlNetInput(
|
||||
image=Image.open("data/example_image_dataset/inpaint/image_1.jpg"),
|
||||
inpaint_mask=Image.open("data/example_image_dataset/inpaint/mask.jpg"),
|
||||
scale=0.9
|
||||
)],
|
||||
height=1024, width=1024,
|
||||
seed=0, rand_device="cuda",
|
||||
)
|
||||
image.save("image_FLUX.1-dev-Controlnet-Inpainting-Beta_full.jpg")
|
||||
@@ -0,0 +1,31 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
|
||||
from diffsynth import load_state_dict
|
||||
from PIL import Image
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
ModelConfig(model_id="InstantX/FLUX.1-dev-Controlnet-Union-alpha", origin_file_pattern="diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
)
|
||||
state_dict = load_state_dict("models/train/FLUX.1-dev-Controlnet-Union-alpha_full/epoch-0.safetensors")
|
||||
pipe.controlnet.models[0].load_state_dict(state_dict)
|
||||
|
||||
image = pipe(
|
||||
prompt="a dog",
|
||||
controlnet_inputs=[ControlNetInput(
|
||||
image=Image.open("data/example_image_dataset/canny/image_1.jpg"),
|
||||
scale=0.9,
|
||||
processor_id="canny",
|
||||
)],
|
||||
height=768, width=768,
|
||||
seed=0, rand_device="cuda",
|
||||
)
|
||||
image.save("image_FLUX.1-dev-Controlnet-Union-alpha_full.jpg")
|
||||
@@ -0,0 +1,30 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
|
||||
from diffsynth import load_state_dict
|
||||
from PIL import Image
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
ModelConfig(model_id="jasperai/Flux.1-dev-Controlnet-Upscaler", origin_file_pattern="diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
)
|
||||
state_dict = load_state_dict("models/train/FLUX.1-dev-Controlnet-Upscaler_full/epoch-0.safetensors")
|
||||
pipe.controlnet.models[0].load_state_dict(state_dict)
|
||||
|
||||
image = pipe(
|
||||
prompt="a dog",
|
||||
controlnet_inputs=[ControlNetInput(
|
||||
image=Image.open("data/example_image_dataset/upscale/image_1.jpg"),
|
||||
scale=0.9
|
||||
)],
|
||||
height=768, width=768,
|
||||
seed=0, rand_device="cuda",
|
||||
)
|
||||
image.save("image_FLUX.1-dev-Controlnet-Upscaler_full.jpg")
|
||||
@@ -0,0 +1,33 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
|
||||
from diffsynth import load_state_dict
|
||||
from PIL import Image
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/image_proj_model.bin"),
|
||||
ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors"),
|
||||
],
|
||||
)
|
||||
state_dict = load_state_dict("models/train/FLUX.1-dev-InfiniteYou_full/epoch-0.safetensors")
|
||||
state_dict_projector = {i.replace("image_proj_model.", ""): state_dict[i] for i in state_dict if i.startswith("image_proj_model.")}
|
||||
pipe.image_proj_model.load_state_dict(state_dict_projector)
|
||||
state_dict_controlnet = {i.replace("controlnet.models.0.", ""): state_dict[i] for i in state_dict if i.startswith("controlnet.models.0.")}
|
||||
pipe.controlnet.models[0].load_state_dict(state_dict_controlnet)
|
||||
|
||||
image = pipe(
|
||||
prompt="a man with a red hat",
|
||||
controlnet_inputs=[ControlNetInput(
|
||||
image=Image.open("data/example_image_dataset/infiniteyou/image_1.jpg"),
|
||||
)],
|
||||
height=1024, width=1024,
|
||||
seed=0, rand_device="cuda",
|
||||
)
|
||||
image.save("image_FLUX.1-dev-InfiniteYou_full.jpg")
|
||||
@@ -0,0 +1,25 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from diffsynth import load_state_dict
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors"),
|
||||
],
|
||||
)
|
||||
pipe.enable_lora_magic()
|
||||
state_dict = load_state_dict("models/train/FLUX.1-dev-LoRA-Encoder_full/epoch-0.safetensors")
|
||||
pipe.lora_encoder.load_state_dict(state_dict)
|
||||
|
||||
lora = ModelConfig(model_id="VoidOc/flux_animal_forest1", origin_file_pattern="20.safetensors")
|
||||
pipe.load_lora(pipe.dit, lora, hotload=True) # Use `pipe.clear_lora()` to drop the loaded LoRA.
|
||||
|
||||
image = pipe(prompt="", seed=0, lora_encoder_inputs=lora)
|
||||
image.save("image_FLUX.1-dev-LoRA-Encoder_full.jpg")
|
||||
28
examples/flux/model_training/validate_full/Nexus-Gen.py
Normal file
28
examples/flux/model_training/validate_full/Nexus-Gen.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from diffsynth import load_state_dict
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors"),
|
||||
ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
state_dict = load_state_dict("models/train/FLUX.1-NexusGen-Edit_full/epoch-0.safetensors")
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
|
||||
ref_image = Image.open("data/example_image_dataset/nexus_gen/image_1.png").convert("RGB")
|
||||
prompt = "Add a pair of sunglasses."
|
||||
image = pipe(
|
||||
prompt=prompt, negative_prompt="",
|
||||
seed=42, cfg_scale=2.0, num_inference_steps=50,
|
||||
nexus_gen_reference_image=ref_image,
|
||||
height=512, width=512,
|
||||
)
|
||||
image.save("NexusGen-Edit_full.jpg")
|
||||
25
examples/flux/model_training/validate_full/Step1X-Edit.py
Normal file
25
examples/flux/model_training/validate_full/Step1X-Edit.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from diffsynth import load_state_dict
|
||||
from PIL import Image
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen2.5-VL-7B-Instruct"),
|
||||
ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
|
||||
ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="vae.safetensors"),
|
||||
],
|
||||
)
|
||||
state_dict = load_state_dict("models/train/Step1X-Edit_full/epoch-0.safetensors")
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
|
||||
image = pipe(
|
||||
prompt="Make the dog turn its head around.",
|
||||
step1x_reference_image=Image.open("data/example_image_dataset/2.jpg").resize((768, 768)),
|
||||
height=768, width=768, cfg_scale=6,
|
||||
seed=0
|
||||
)
|
||||
image.save("image_Step1X-Edit_full.jpg")
|
||||
18
examples/flux/model_training/validate_lora/FLEX.2-preview.py
Normal file
18
examples/flux/model_training/validate_lora/FLEX.2-preview.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/FLEX.2-preview_lora/epoch-4.safetensors", alpha=1)
|
||||
|
||||
image = pipe(prompt="dog,white and brown dog, sitting on wall, under pink flowers", seed=0)
|
||||
image.save("image_FLEX.2-preview_lora.jpg")
|
||||
@@ -0,0 +1,19 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
ModelConfig(model_id="DiffSynth-Studio/AttriCtrl-FLUX.1-Dev", origin_file_pattern="models/brightness.safetensors")
|
||||
],
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/FLUX.1-dev-AttriCtrl_lora/epoch-3.safetensors", alpha=1)
|
||||
|
||||
image = pipe(prompt="a cat", seed=0, value_controller_inputs=0.1, rand_device="cuda")
|
||||
image.save("image_FLUX.1-dev-AttriCtrl_lora.jpg")
|
||||
@@ -0,0 +1,29 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
|
||||
from PIL import Image
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
ModelConfig(model_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", origin_file_pattern="diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/FLUX.1-dev-Controlnet-Inpainting-Beta_lora/epoch-4.safetensors", alpha=1)
|
||||
|
||||
image = pipe(
|
||||
prompt="a cat sitting on a chair, wearing sunglasses",
|
||||
controlnet_inputs=[ControlNetInput(
|
||||
image=Image.open("data/example_image_dataset/inpaint/image_1.jpg"),
|
||||
inpaint_mask=Image.open("data/example_image_dataset/inpaint/mask.jpg"),
|
||||
scale=0.9
|
||||
)],
|
||||
height=1024, width=1024,
|
||||
seed=0, rand_device="cuda",
|
||||
)
|
||||
image.save("image_FLUX.1-dev-Controlnet-Inpainting-Beta_lora.jpg")
|
||||
@@ -0,0 +1,29 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
|
||||
from PIL import Image
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
ModelConfig(model_id="InstantX/FLUX.1-dev-Controlnet-Union-alpha", origin_file_pattern="diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/FLUX.1-dev-Controlnet-Union-alpha_lora/epoch-4.safetensors", alpha=1)
|
||||
|
||||
image = pipe(
|
||||
prompt="a dog",
|
||||
controlnet_inputs=[ControlNetInput(
|
||||
image=Image.open("data/example_image_dataset/canny/image_1.jpg"),
|
||||
scale=0.9,
|
||||
processor_id="canny",
|
||||
)],
|
||||
height=768, width=768,
|
||||
seed=0, rand_device="cuda",
|
||||
)
|
||||
image.save("image_FLUX.1-dev-Controlnet-Union-alpha_lora.jpg")
|
||||
@@ -0,0 +1,28 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
|
||||
from PIL import Image
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
ModelConfig(model_id="jasperai/Flux.1-dev-Controlnet-Upscaler", origin_file_pattern="diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/FLUX.1-dev-Controlnet-Upscaler_lora/epoch-4.safetensors", alpha=1)
|
||||
|
||||
image = pipe(
|
||||
prompt="a dog",
|
||||
controlnet_inputs=[ControlNetInput(
|
||||
image=Image.open("data/example_image_dataset/upscale/image_1.jpg"),
|
||||
scale=0.9
|
||||
)],
|
||||
height=768, width=768,
|
||||
seed=0, rand_device="cuda",
|
||||
)
|
||||
image.save("image_FLUX.1-dev-Controlnet-Upscaler_lora.jpg")
|
||||
@@ -0,0 +1,33 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
pipe.load_lora(pipe.dit, "models/train/FLUX.1-dev-EliGen_lora/epoch-4.safetensors", alpha=1)
|
||||
|
||||
entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"]
|
||||
global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'"
|
||||
masks = [Image.open(f"data/example_image_dataset/eligen/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
|
||||
# generate image
|
||||
image = pipe(
|
||||
prompt=global_prompt,
|
||||
cfg_scale=1.0,
|
||||
num_inference_steps=50,
|
||||
embedded_guidance=3.5,
|
||||
seed=42,
|
||||
height=1024,
|
||||
width=1024,
|
||||
eligen_entity_prompts=entity_prompts,
|
||||
eligen_entity_masks=masks,
|
||||
)
|
||||
image.save(f"EliGen_lora.png")
|
||||
@@ -0,0 +1,26 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from PIL import Image
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin"),
|
||||
ModelConfig(model_id="google/siglip-so400m-patch14-384"),
|
||||
],
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/FLUX.1-dev-IP-Adapter_lora/epoch-4.safetensors", alpha=1)
|
||||
|
||||
image = pipe(
|
||||
prompt="dog,white and brown dog, sitting on wall, under pink flowers",
|
||||
ipadapter_images=Image.open("data/example_image_dataset/1.jpg"),
|
||||
height=768, width=768,
|
||||
seed=0
|
||||
)
|
||||
image.save("image_FLUX.1-dev-IP-Adapter_lora.jpg")
|
||||
@@ -0,0 +1,28 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
|
||||
from PIL import Image
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/image_proj_model.bin"),
|
||||
ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors"),
|
||||
],
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/FLUX.1-dev-InfiniteYou_lora/epoch-4.safetensors", alpha=1)
|
||||
|
||||
image = pipe(
|
||||
prompt="a man with a red hat",
|
||||
controlnet_inputs=[ControlNetInput(
|
||||
image=Image.open("data/example_image_dataset/infiniteyou/image_1.jpg"),
|
||||
)],
|
||||
height=1024, width=1024,
|
||||
seed=0, rand_device="cuda",
|
||||
)
|
||||
image.save("image_FLUX.1-dev-InfiniteYou_lora.jpg")
|
||||
26
examples/flux/model_training/validate_lora/Nexus-Gen.py
Normal file
26
examples/flux/model_training/validate_lora/Nexus-Gen.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors"),
|
||||
ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/FLUX.1-NexusGen-Edit_lora/epoch-4.safetensors", alpha=1)
|
||||
|
||||
ref_image = Image.open("data/example_image_dataset/nexus_gen/image_1.png").convert("RGB")
|
||||
prompt = "Add a pair of sunglasses."
|
||||
image = pipe(
|
||||
prompt=prompt, negative_prompt="",
|
||||
seed=42, cfg_scale=1.0, num_inference_steps=50,
|
||||
nexus_gen_reference_image=ref_image,
|
||||
height=512, width=512,
|
||||
)
|
||||
image.save("NexusGen-Edit_lora.jpg")
|
||||
23
examples/flux/model_training/validate_lora/Step1X-Edit.py
Normal file
23
examples/flux/model_training/validate_lora/Step1X-Edit.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from PIL import Image
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen2.5-VL-7B-Instruct"),
|
||||
ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
|
||||
ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="vae.safetensors"),
|
||||
],
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/Step1X-Edit_lora/epoch-4.safetensors", alpha=1)
|
||||
|
||||
image = pipe(
|
||||
prompt="Make the dog turn its head around.",
|
||||
step1x_reference_image=Image.open("data/example_image_dataset/2.jpg").resize((768, 768)),
|
||||
height=768, width=768, cfg_scale=6,
|
||||
seed=0
|
||||
)
|
||||
image.save("image_Step1X-Edit_lora.jpg")
|
||||
@@ -1,45 +1,5 @@
|
||||
# Text to Video
|
||||
|
||||
In DiffSynth Studio, we can use some video models to generate videos.
|
||||
|
||||
### Example: Text-to-Video using CogVideoX-5B (Experimental)
|
||||
|
||||
See [cogvideo_text_to_video.py](cogvideo_text_to_video.py).
|
||||
|
||||
First, we generate a video using prompt "an astronaut riding a horse on Mars".
|
||||
|
||||
https://github.com/user-attachments/assets/4c91c1cd-e4a0-471a-bd8d-24d761262941
|
||||
|
||||
Then, we convert the astronaut to a robot.
|
||||
|
||||
https://github.com/user-attachments/assets/225a00a4-2bc8-4740-8e86-a64b460a29ec
|
||||
|
||||
Upscale the video using the model itself.
|
||||
|
||||
https://github.com/user-attachments/assets/c02cb30c-de60-473c-8242-32c67b3155ad
|
||||
|
||||
Make the video look smoother by interpolating frames.
|
||||
|
||||
https://github.com/user-attachments/assets/f0e465b4-45df-4435-ab10-7a084ca2b0a0
|
||||
|
||||
Here is another example.
|
||||
|
||||
First, we generate a video using prompt "a dog is running".
|
||||
|
||||
https://github.com/user-attachments/assets/e3696297-99f5-4d0c-a5ca-1d1566db85b4
|
||||
|
||||
Then, we add a blue collar to the dog.
|
||||
|
||||
https://github.com/user-attachments/assets/7ff22be7-4390-4d33-ae6c-53f6f056e18d
|
||||
|
||||
Upscale the video using the model itself.
|
||||
|
||||
https://github.com/user-attachments/assets/a909c32c-0b7d-495c-a53c-d23a99a3d3e9
|
||||
|
||||
Make the video look smoother by interpolating frames.
|
||||
|
||||
https://github.com/user-attachments/assets/ea37c150-97a0-4858-8003-0c2e5eef3331
|
||||
|
||||
### Example: Text-to-Video using AnimateDiff
|
||||
|
||||
Generate a video using a Stable Diffusion model and an AnimateDiff model. We can break the limitation of number of frames! See [sd_text_to_video.py](./sd_text_to_video.py).
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# Wan 2.1
|
||||
# Wan
|
||||
|
||||
[切换到中文](./README_zh.md)
|
||||
|
||||
Wan 2.1 is a collection of video synthesis models open-sourced by Alibaba.
|
||||
Wan is a collection of video synthesis models open-sourced by Alibaba.
|
||||
|
||||
**DiffSynth-Studio has adopted a new inference and training framework. To use the previous version, please click [here](https://github.com/modelscope/DiffSynth-Studio/tree/3edf3583b1f08944cee837b94d9f84d669c2729c).**
|
||||
|
||||
@@ -18,6 +18,8 @@ pip install -e .
|
||||
|
||||
## Quick Start
|
||||
|
||||
You can quickly load the [Wan-AI/Wan2.1-T2V-1.3B](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) model and run inference by executing the code below.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth import save_video
|
||||
@@ -46,6 +48,9 @@ save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
|
||||
| Model ID | Extra Parameters | Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|
||||
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|
||||
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|
||||
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](./model_inference/Wan2.1-T2V-1.3B.py)|[code](./model_training/full/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](./model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|
||||
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](./model_inference/Wan2.1-T2V-14B.py)|[code](./model_training/full/Wan2.1-T2V-14B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-14B.py)|[code](./model_training/lora/Wan2.1-T2V-14B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-14B.py)|
|
||||
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](./model_inference/Wan2.1-I2V-14B-480P.py)|[code](./model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](./model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](./model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](./model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
|
||||
@@ -78,6 +83,9 @@ The following sections will help you understand our functionalities and write in
|
||||
The model is loaded using `from_pretrained`:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
@@ -116,11 +124,14 @@ ModelConfig(path=[
|
||||
])
|
||||
```
|
||||
|
||||
The `from_pretrained` function also provides additional parameters to control the behavior during model loading:
|
||||
The `ModelConfig` function provides additional parameters to control the behavior during model loading:
|
||||
|
||||
* `tokenizer_config`: Path to the tokenizer of the Wan model. Default value is `ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*")`.
|
||||
* `local_model_path`: Path where downloaded models are saved. Default value is `"./models"`.
|
||||
* `skip_download`: Whether to skip downloading models. Default value is `False`. When your network cannot access [ModelScope](https://modelscope.cn/), manually download the necessary files and set this to `True`.
|
||||
|
||||
The `from_pretrained` function provides additional parameters to control the behavior during model loading:
|
||||
|
||||
* `tokenizer_config`: Path to the tokenizer of the Wan model. Default value is `ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*")`.
|
||||
* `redirect_common_files`: Whether to redirect duplicate model files. Default value is `True`. Since the Wan series models include multiple base models, some modules like text encoder are shared across these models. To avoid redundant downloads, we redirect the model paths.
|
||||
* `use_usp`: Whether to enable Unified Sequence Parallel. Default value is `False`. Used for multi-GPU parallel inference.
|
||||
|
||||
@@ -177,11 +188,11 @@ pipe.enable_vram_management()
|
||||
|
||||
FP8 quantization significantly reduces VRAM usage but does not accelerate computations. Some models may experience issues such as blurry, torn, or distorted outputs due to insufficient precision when using FP8 quantization. Use FP8 quantization with caution.
|
||||
|
||||
The `enable_vram_management` function provides the following parameters to control VRAM usage:
|
||||
After enabling VRAM management, the framework will automatically decide the VRAM strategy based on available GPU memory. The `enable_vram_management` function has the following parameters to manually control the VRAM strategy:
|
||||
|
||||
* `vram_limit`: VRAM usage limit (in GB). By default, it uses all available VRAM on the device. Note that this is not an absolute limit; if the specified VRAM is insufficient but more VRAM is actually available, inference will proceed using the minimum required VRAM.
|
||||
* `vram_buffer`: Size of the VRAM buffer (in GB). Default is 0.5GB. Since certain large neural network layers may consume more VRAM unpredictably during their execution phase, a VRAM buffer is necessary. Ideally, this should match the maximum VRAM consumed by any single layer in the model.
|
||||
* `num_persistent_param_in_dit`: Number of persistent parameters in DiT models. By default, there is no limit. We plan to remove this parameter in the future, so please avoid relying on it.
|
||||
* `vram_limit`: VRAM usage limit in GB. By default, it uses all free VRAM on the device. Note that this is not an absolute limit. If the set VRAM is not enough but more VRAM is actually available, the model will run with minimal VRAM usage. Setting it to 0 achieves the theoretical minimum VRAM usage.
|
||||
* `vram_buffer`: VRAM buffer size in GB. Default is 0.5GB. A buffer is needed because larger neural network layers may use more VRAM than expected during loading. The optimal value is the VRAM used by the largest layer in the model.
|
||||
* `num_persistent_param_in_dit`: Number of parameters in the DiT model that stay in VRAM. Default is no limit. We plan to remove this parameter in the future. Do not rely on it.
|
||||
|
||||
</details>
|
||||
|
||||
@@ -236,6 +247,7 @@ The pipeline accepts the following input parameters during inference:
|
||||
* `num_frames`: Number of frames, default is 81. Must be a multiple of 4 plus 1; if not, it will be rounded up, minimum is 1.
|
||||
* `cfg_scale`: Classifier-free guidance scale, default is 5. Higher values increase adherence to the prompt but may cause visual artifacts.
|
||||
* `cfg_merge`: Whether to merge both sides of classifier-free guidance for unified inference. Default is `False`. This parameter currently only works for basic text-to-video and image-to-video models.
|
||||
* `switch_DiT_boundary`: The time point for switching between DiT models. Default value is 0.875. This parameter only takes effect for mixed models with multiple DiTs, for example, [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B).
|
||||
* `num_inference_steps`: Number of inference steps, default is 50.
|
||||
* `sigma_shift`: Parameter from Rectified Flow theory, default is 5. Higher values make the model stay longer at the initial denoising stage. Increasing this may improve video quality but may also cause inconsistency between generated videos and training data due to deviation from training behavior.
|
||||
* `motion_bucket_id`: Motion intensity, range [0, 100], applicable to motion control modules such as [`DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1`](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1). Larger values indicate more intense motion.
|
||||
@@ -271,6 +283,8 @@ The script includes the following parameters:
|
||||
* Models
|
||||
* `--model_paths`: Paths to load models. In JSON format.
|
||||
* `--model_id_with_origin_paths`: Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.
|
||||
* `--max_timestep_boundary`: Maximum value of the timestep interval, ranging from 0 to 1. Default is 1. This needs to be manually set only when training mixed models with multiple DiTs, for example, [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B).
|
||||
* `--min_timestep_boundary`: Minimum value of the timestep interval, ranging from 0 to 1. Default is 1. This needs to be manually set only when training mixed models with multiple DiTs, for example, [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B).
|
||||
* Training
|
||||
* `--learning_rate`: Learning rate.
|
||||
* `--num_epochs`: Number of epochs.
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# 通义万相 2.1(Wan 2.1)
|
||||
# 通义万相(Wan)
|
||||
|
||||
[Switch to English](./README.md)
|
||||
|
||||
Wan 2.1 是由阿里巴巴通义实验室开源的一系列视频生成模型。
|
||||
Wan 是由阿里巴巴通义实验室开源的一系列视频生成模型。
|
||||
|
||||
**DiffSynth-Studio 启用了新的推理和训练框架,如需使用旧版本,请点击[这里](https://github.com/modelscope/DiffSynth-Studio/tree/3edf3583b1f08944cee837b94d9f84d669c2729c)。**
|
||||
|
||||
@@ -18,6 +18,8 @@ pip install -e .
|
||||
|
||||
## 快速开始
|
||||
|
||||
通过运行以下代码可以快速加载 [Wan-AI/Wan2.1-T2V-1.3B](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) 模型并进行推理
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth import save_video
|
||||
@@ -46,6 +48,9 @@ save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
|
||||
|模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|
||||
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|
||||
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|
||||
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](./model_inference/Wan2.1-T2V-1.3B.py)|[code](./model_training/full/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](./model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|
||||
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](./model_inference/Wan2.1-T2V-14B.py)|[code](./model_training/full/Wan2.1-T2V-14B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-14B.py)|[code](./model_training/lora/Wan2.1-T2V-14B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-14B.py)|
|
||||
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](./model_inference/Wan2.1-I2V-14B-480P.py)|[code](./model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](./model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](./model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](./model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
|
||||
@@ -70,7 +75,6 @@ save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
|
||||
以下部分将会帮助您理解我们的功能并编写推理代码。
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>加载模型</summary>
|
||||
@@ -78,6 +82,9 @@ save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
模型通过 `from_pretrained` 加载:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
@@ -116,11 +123,14 @@ ModelConfig(path=[
|
||||
])
|
||||
```
|
||||
|
||||
`from_pretrained` 还提供了额外的参数用于控制模型加载时的行为:
|
||||
`ModelConfig` 提供了额外的参数用于控制模型加载时的行为:
|
||||
|
||||
* `tokenizer_config`: Wan 模型的 tokenizer 路径,默认值为 `ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*")`。
|
||||
* `local_model_path`: 用于保存下载模型的路径,默认值为 `"./models"`。
|
||||
* `skip_download`: 是否跳过下载,默认值为 `False`。当您的网络无法访问[魔搭社区](https://modelscope.cn/)时,请手动下载必要的文件,并将其设置为 `True`。
|
||||
|
||||
`from_pretrained` 提供了额外的参数用于控制模型加载时的行为:
|
||||
|
||||
* `tokenizer_config`: Wan 模型的 tokenizer 路径,默认值为 `ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*")`。
|
||||
* `redirect_common_files`: 是否重定向重复模型文件,默认值为 `True`。由于 Wan 系列模型包括多个基础模型,每个基础模型的 text encoder 等模块都是相同的,为避免重复下载,我们会对模型路径进行重定向。
|
||||
* `use_usp`: 是否启用 Unified Sequence Parallel,默认值为 `False`。用于多 GPU 并行推理。
|
||||
|
||||
@@ -178,9 +188,9 @@ pipe.enable_vram_management()
|
||||
|
||||
FP8 量化能够大幅度减少显存占用,但不会加速,部分模型在 FP8 量化下会出现精度不足导致的画面模糊、撕裂、失真问题,请谨慎使用 FP8 量化。
|
||||
|
||||
`enable_vram_management` 函数提供了以下参数,用于控制显存使用情况:
|
||||
开启显存管理后,框架会自动根据设备上的剩余显存确定显存管理策略。`enable_vram_management` 函数提供了以下参数,用于手动控制显存管理策略:
|
||||
|
||||
* `vram_limit`: 显存占用量(GB),默认占用设备上的剩余显存。注意这不是一个绝对限制,当设置的显存不足以支持模型进行推理,但实际可用显存足够时,将会以最小化显存占用的形式进行推理。
|
||||
* `vram_limit`: 显存占用量限制(GB),默认占用设备上的剩余显存。注意这不是一个绝对限制,当设置的显存不足以支持模型进行推理,但实际可用显存足够时,将会以最小化显存占用的形式进行推理。将其设置为0时,将会实现理论最小显存占用。
|
||||
* `vram_buffer`: 显存缓冲区大小(GB),默认为 0.5GB。由于部分较大的神经网络层在 onload 阶段会不可控地占用更多显存,因此一个显存缓冲区是必要的,理论上的最优值为模型中最大的层所占的显存。
|
||||
* `num_persistent_param_in_dit`: DiT 模型中常驻显存的参数数量(个),默认为无限制。我们将会在未来删除这个参数,请不要依赖这个参数。
|
||||
|
||||
@@ -238,6 +248,7 @@ Pipeline 在推理阶段能够接收以下输入参数:
|
||||
* `num_frames`: 帧数,默认为 81。需设置为 4 的倍数 + 1,不满足时向上取整,最小值为 1。
|
||||
* `cfg_scale`: Classifier-free guidance 机制的数值,默认为 5。数值越大,提示词的控制效果越强,但画面崩坏的概率越大。
|
||||
* `cfg_merge`: 是否合并 Classifier-free guidance 的两侧进行统一推理,默认为 `False`。该参数目前仅在基础的文生视频和图生视频模型上生效。
|
||||
* `switch_DiT_boundary`: 切换 DiT 模型的时间点,默认值为 0.875,仅对多 DiT 的混合模型生效,例如 [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)。
|
||||
* `num_inference_steps`: 推理次数,默认值为 50。
|
||||
* `sigma_shift`: Rectified Flow 理论中的参数,默认为 5。数值越大,模型在去噪的开始阶段停留的步骤数越多,可适当调大这个参数来提高画面质量,但会因生成过程与训练过程不一致导致生成的视频内容与训练数据存在差异。
|
||||
* `motion_bucket_id`: 运动幅度,范围为 [0, 100]。适用于速度控制模块,例如 [`DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1`](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1),数值越大,运动幅度越大。
|
||||
@@ -274,9 +285,11 @@ Wan 系列模型训练通过统一的 [`./model_training/train.py`](./model_trai
|
||||
* 模型
|
||||
* `--model_paths`: 要加载的模型路径。JSON 格式。
|
||||
* `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors。用逗号分隔。
|
||||
* `--max_timestep_boundary`: Timestep 区间最大值,范围为 0~1,默认为 1,仅在多 DiT 的混合模型训练中需要手动设置,例如 [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)。
|
||||
* `--min_timestep_boundary`: Timestep 区间最小值,范围为 0~1,默认为 1,仅在多 DiT 的混合模型训练中需要手动设置,例如 [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)。
|
||||
* 训练
|
||||
* `--learning_rate`: 学习率。
|
||||
* `--num_epochs`: 轮数(Epoch)数量。
|
||||
* `--num_epochs`: 轮数(Epoch)。
|
||||
* `--output_path`: 保存路径。
|
||||
* `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。
|
||||
* 可训练模块
|
||||
|
||||
32
examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py
Normal file
32
examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth import save_video
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=["data/examples/wan/cat_fightning.jpg"]
|
||||
)
|
||||
input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480))
|
||||
|
||||
video = pipe(
|
||||
prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
input_image=input_image,
|
||||
)
|
||||
save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
24
examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py
Normal file
24
examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import torch
|
||||
from diffsynth import save_video
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
# Text-to-video
|
||||
video = pipe(
|
||||
prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
)
|
||||
save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
43
examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py
Normal file
43
examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth import save_video
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
# Text-to-video
|
||||
video = pipe(
|
||||
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
height=704, width=1248,
|
||||
num_frames=121,
|
||||
)
|
||||
save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
|
||||
# Image-to-video
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=["data/examples/wan/cat_fightning.jpg"]
|
||||
)
|
||||
input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((1248, 704))
|
||||
video = pipe(
|
||||
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
height=704, width=1248,
|
||||
input_image=input_image,
|
||||
num_frames=121,
|
||||
)
|
||||
save_video(video, "video2.mp4", fps=15, quality=5)
|
||||
35
examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh
Normal file
35
examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh
Normal file
@@ -0,0 +1,35 @@
|
||||
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata.csv \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--num_frames 49 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.2-I2V-A14B_high_noise_full" \
|
||||
--trainable_models "dit" \
|
||||
--extra_inputs "input_image" \
|
||||
--use_gradient_checkpointing_offload \
|
||||
--max_timestep_boundary 1 \
|
||||
--min_timestep_boundary 0.875
|
||||
|
||||
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata.csv \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--num_frames 49 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.2-I2V-A14B_low_noise_full" \
|
||||
--trainable_models "dit" \
|
||||
--extra_inputs "input_image" \
|
||||
--use_gradient_checkpointing_offload \
|
||||
--max_timestep_boundary 0.875 \
|
||||
--min_timestep_boundary 0
|
||||
31
examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh
Normal file
31
examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh
Normal file
@@ -0,0 +1,31 @@
|
||||
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata.csv \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--num_frames 49 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.2-T2V-A14B_high_noise_full" \
|
||||
--trainable_models "dit" \
|
||||
--max_timestep_boundary 1 \
|
||||
--min_timestep_boundary 0.875
|
||||
|
||||
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata.csv \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--num_frames 49 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.2-T2V-A14B_low_noise_full" \
|
||||
--trainable_models "dit" \
|
||||
--max_timestep_boundary 0.875 \
|
||||
--min_timestep_boundary 0
|
||||
14
examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh
Normal file
14
examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh
Normal file
@@ -0,0 +1,14 @@
|
||||
accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata.csv \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--num_frames 49 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "Wan-AI/Wan2.2-TI2V-5B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-TI2V-5B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-TI2V-5B:Wan2.2_VAE.pth" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.2-TI2V-5B_full" \
|
||||
--trainable_models "dit" \
|
||||
--extra_inputs "input_image"
|
||||
37
examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh
Normal file
37
examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh
Normal file
@@ -0,0 +1,37 @@
|
||||
accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata.csv \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--num_frames 49 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.2-I2V-A14B_high_noise_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--extra_inputs "input_image" \
|
||||
--max_timestep_boundary 1 \
|
||||
--min_timestep_boundary 0.875
|
||||
|
||||
accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata.csv \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--num_frames 49 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.2-I2V-A14B_low_noise_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--extra_inputs "input_image" \
|
||||
--max_timestep_boundary 0.875 \
|
||||
--min_timestep_boundary 0
|
||||
36
examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh
Normal file
36
examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh
Normal file
@@ -0,0 +1,36 @@
|
||||
accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata.csv \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--num_frames 49 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.2-T2V-A14B_high_noise_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--max_timestep_boundary 1 \
|
||||
--min_timestep_boundary 0.875
|
||||
|
||||
|
||||
accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata.csv \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--num_frames 49 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.2-T2V-A14B_low_noise_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--max_timestep_boundary 0.875 \
|
||||
--min_timestep_boundary 0
|
||||
16
examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh
Normal file
16
examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh
Normal file
@@ -0,0 +1,16 @@
|
||||
accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata.csv \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--num_frames 49 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "Wan-AI/Wan2.2-TI2V-5B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-TI2V-5B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-TI2V-5B:Wan2.2_VAE.pth" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.2-TI2V-5B_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--extra_inputs "input_image"
|
||||
@@ -14,6 +14,8 @@ class WanTrainingModule(DiffusionTrainingModule):
|
||||
use_gradient_checkpointing=True,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
extra_inputs=None,
|
||||
max_timestep_boundary=1.0,
|
||||
min_timestep_boundary=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
# Load models
|
||||
@@ -45,6 +47,8 @@ class WanTrainingModule(DiffusionTrainingModule):
|
||||
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
||||
self.max_timestep_boundary = max_timestep_boundary
|
||||
self.min_timestep_boundary = min_timestep_boundary
|
||||
|
||||
|
||||
def forward_preprocess(self, data):
|
||||
@@ -69,6 +73,8 @@ class WanTrainingModule(DiffusionTrainingModule):
|
||||
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
|
||||
"cfg_merge": False,
|
||||
"vace_scale": 1,
|
||||
"max_timestep_boundary": self.max_timestep_boundary,
|
||||
"min_timestep_boundary": self.min_timestep_boundary,
|
||||
}
|
||||
|
||||
# Extra inputs
|
||||
@@ -77,6 +83,8 @@ class WanTrainingModule(DiffusionTrainingModule):
|
||||
inputs_shared["input_image"] = data["video"][0]
|
||||
elif extra_input == "end_image":
|
||||
inputs_shared["end_image"] = data["video"][-1]
|
||||
elif extra_input == "reference_image" or extra_input == "vace_reference_image":
|
||||
inputs_shared[extra_input] = data[extra_input][0]
|
||||
else:
|
||||
inputs_shared[extra_input] = data[extra_input]
|
||||
|
||||
@@ -106,6 +114,8 @@ if __name__ == "__main__":
|
||||
lora_rank=args.lora_rank,
|
||||
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||
extra_inputs=args.extra_inputs,
|
||||
max_timestep_boundary=args.max_timestep_boundary,
|
||||
min_timestep_boundary=args.min_timestep_boundary,
|
||||
)
|
||||
model_logger = ModelLogger(
|
||||
args.output_path,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user