mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 14:58:12 +00:00
Compare commits
2 Commits
qwen-image
...
wan-lora-f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fdeae36cfb | ||
|
|
a2a720267e |
BIN
.github/workflows/logo.gif
vendored
BIN
.github/workflows/logo.gif
vendored
Binary file not shown.
|
Before Width: | Height: | Size: 146 KiB |
2
.github/workflows/publish.yaml
vendored
2
.github/workflows/publish.yaml
vendored
@@ -20,7 +20,7 @@ jobs:
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: Install wheel
|
||||
run: pip install wheel==0.44.0 && pip install -r requirements.txt
|
||||
run: pip install wheel && pip install -r requirements.txt
|
||||
- name: Build DiffSynth
|
||||
run: python setup.py sdist bdist_wheel
|
||||
- name: Publish package to PyPI
|
||||
|
||||
552
README.md
552
README.md
@@ -1,394 +1,51 @@
|
||||
# DiffSynth-Studio
|
||||
|
||||
<a href="https://github.com/modelscope/DiffSynth-Studio"><img src=".github/workflows/logo.gif" title="Logo" style="max-width:100%;" width="55" /></a> <a href="https://trendshift.io/repositories/10946" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10946" alt="modelscope%2FDiffSynth-Studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a></p>
|
||||
|
||||
# DiffSynth Studio
|
||||
[](https://pypi.org/project/DiffSynth/)
|
||||
[](https://github.com/modelscope/DiffSynth-Studio/blob/master/LICENSE)
|
||||
[](https://github.com/modelscope/DiffSynth-Studio/issues)
|
||||
[](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
|
||||
[](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
|
||||
[](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
|
||||
|
||||
[切换到中文](./README_zh.md)
|
||||
<p align="center">
|
||||
<a href="https://trendshift.io/repositories/10946" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10946" alt="modelscope%2FDiffSynth-Studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
</p>
|
||||
|
||||
Document: https://diffsynth-studio.readthedocs.io/zh-cn/latest/index.html
|
||||
|
||||
## Introduction
|
||||
|
||||
Welcome to the magic world of Diffusion models! DiffSynth-Studio is an open-source Diffusion model engine developed and maintained by [ModelScope](https://www.modelscope.cn/) team. We aim to foster technical innovation through framework development, bring together the power of the open-source community, and explore the limits of generative models!
|
||||
|
||||
DiffSynth currently includes two open-source projects:
|
||||
* [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): Focused on aggressive technical exploration, for academia, providing support for more cutting-edge model capabilities.
|
||||
* [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): Focused on stable model deployment, for industry, offering higher computing performance and more stable features.
|
||||
|
||||
[DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) and [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) are the core projects behind ModelScope [AIGC zone](https://modelscope.cn/aigc/home), offering powerful AI content generation abilities. Come and try our carefully designed features and start your AI creation journey!
|
||||
|
||||
## Installation
|
||||
|
||||
Install from source (recommended):
|
||||
|
||||
```
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>Other installation methods</summary>
|
||||
|
||||
Install from PyPI (version updates may be delayed; for latest features, install from source)
|
||||
|
||||
```
|
||||
pip install diffsynth
|
||||
```
|
||||
|
||||
If you meet problems during installation, they might be caused by upstream dependencies. Please check the docs of these packages:
|
||||
|
||||
* [torch](https://pytorch.org/get-started/locally/)
|
||||
* [sentencepiece](https://github.com/google/sentencepiece)
|
||||
* [cmake](https://cmake.org)
|
||||
* [cupy](https://docs.cupy.dev/en/stable/install.html)
|
||||
|
||||
</details>
|
||||
|
||||
## Basic Framework
|
||||
|
||||
DiffSynth-Studio redesigns the inference and training pipelines for mainstream Diffusion models (including FLUX, Wan, etc.), enabling efficient memory management and flexible model training.
|
||||
|
||||
### Qwen-Image Series (🔥New Model)
|
||||
|
||||
Details: [./examples/qwen_image/](./examples/qwen_image/)
|
||||
|
||||

|
||||
|
||||
<details>
|
||||
|
||||
<summary>Quick Start</summary>
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "A detailed portrait of a girl underwater, wearing a blue flowing dress, hair gently floating, clear light and shadow, surrounded by bubbles, calm expression, fine details, dreamy and beautiful."
|
||||
image = pipe(prompt, seed=0, num_inference_steps=40)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Model Overview</summary>
|
||||
|
||||
|Model ID|Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
|
||||
|-|-|-|-|-|-|
|
||||
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./examples/qwen_image/model_inference/Qwen-Image.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
</details>
|
||||
|
||||
### FLUX Series
|
||||
|
||||
Detail page: [./examples/flux/](./examples/flux/)
|
||||
|
||||

|
||||
|
||||
<details>
|
||||
|
||||
<summary>Quick Start</summary>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
image = pipe(prompt="a cat", seed=0)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Model Overview</summary>
|
||||
|
||||
| Model ID | Extra Parameters | Inference | Low VRAM Inference | Full Training | Validate After Full Training | LoRA Training | Validate After LoRA Training |
|
||||
|-|-|-|-|-|-|-|-|
|
||||
|[FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](./examples/flux/model_inference/FLUX.1-dev.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](./examples/flux/model_training/full/FLUX.1-dev.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev.py)|
|
||||
|[FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](./examples/flux/model_inference/FLUX.1-Krea-dev.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](./examples/flux/model_training/full/FLUX.1-Krea-dev.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py)|[code](./examples/flux/model_training/lora/FLUX.1-Krea-dev.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py)|
|
||||
|[FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](./examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](./examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](./examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)|
|
||||
|[FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|
|
||||
|[FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)|
|
||||
|[FLUX.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)|
|
||||
|[FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](./examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)|
|
||||
|[FLUX.1-dev-InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|
|
||||
|[FLUX.1-dev-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](./examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](./examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)|
|
||||
|[FLUX.1-dev-LoRA-Encoder](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](./examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|
|
||||
|[FLUX.1-dev-LoRA-Fusion-Preview](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](./examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|
|
||||
|[Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](./examples/flux/model_inference/Step1X-Edit.py)|[code](./examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](./examples/flux/model_training/full/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](./examples/flux/model_training/lora/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_lora/Step1X-Edit.py)|
|
||||
|[FLEX.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](./examples/flux/model_inference/FLEX.2-preview.py)|[code](./examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](./examples/flux/model_training/full/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](./examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_lora/FLEX.2-preview.py)|
|
||||
|[Nexus-Gen](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](./examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](./examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](./examples/flux/model_training/full/Nexus-Gen.sh)|[code](./examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](./examples/flux/model_training/lora/Nexus-Gen.sh)|[code](./examples/flux/model_training/validate_lora/Nexus-Gen.py)|
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
### Wan Series
|
||||
|
||||
Detail page: [./examples/wanvideo/](./examples/wanvideo/)
|
||||
|
||||
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Quick Start</summary>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth import save_video
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
video = pipe(
|
||||
prompt="A documentary photography style scene: a lively puppy rapidly running on green grass. The puppy has brown-yellow fur, upright ears, and looks focused and joyful. Sunlight shines on its body, making the fur appear soft and shiny. The background is an open field with occasional wildflowers, and faint blue sky and clouds in the distance. Strong sense of perspective captures the motion of the puppy and the vitality of the surrounding grass. Mid-shot side-moving view.",
|
||||
negative_prompt="Bright colors, overexposed, static, blurry details, subtitles, style, artwork, image, still, overall gray, worst quality, low quality, JPEG compression artifacts, ugly, deformed, extra fingers, poorly drawn hands, poorly drawn face, malformed limbs, fused fingers, still frame, messy background, three legs, crowded background people, walking backwards",
|
||||
seed=0, tiled=True,
|
||||
)
|
||||
save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Model Overview</summary>
|
||||
|
||||
| Model ID | Extra Parameters | Inference | Full Training | Validate After Full Training | LoRA Training | Validate After LoRA Training |
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|
||||
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|
||||
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|
||||
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](./examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|
||||
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](./examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|
|
||||
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
|
||||
|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|
|
||||
|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|
|
||||
|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./examples/wanvideo/examples/wanmodel_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|
|
||||
|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|
|
||||
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|
||||
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|
|
||||
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|
||||
|
||||
</details>
|
||||
|
||||
### More Models
|
||||
|
||||
|
||||
|
||||
<details>
|
||||
<summary>Image Generation Models</summary>
|
||||
|
||||
Detail page: [./examples/image_synthesis/](./examples/image_synthesis/)
|
||||
|
||||
|FLUX|Stable Diffusion 3|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
|Kolors|Hunyuan-DiT|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
|Stable Diffusion|Stable Diffusion XL|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
<details>
|
||||
<summary>Video Generation Models</summary>
|
||||
|
||||
- HunyuanVideo: [./examples/HunyuanVideo/](./examples/HunyuanVideo/)
|
||||
|
||||
https://github.com/user-attachments/assets/48dd24bb-0cc6-40d2-88c3-10feed3267e9
|
||||
|
||||
- StepVideo: [./examples/stepvideo/](./examples/stepvideo/)
|
||||
|
||||
https://github.com/user-attachments/assets/5954fdaa-a3cf-45a3-bd35-886e3cc4581b
|
||||
|
||||
- CogVideoX: [./examples/CogVideoX/](./examples/CogVideoX/)
|
||||
|
||||
https://github.com/user-attachments/assets/26b044c1-4a60-44a4-842f-627ff289d006
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
<details>
|
||||
<summary>Image Quality Assessment Models</summary>
|
||||
|
||||
We have integrated a series of image quality assessment models. These models can be used for evaluating image generation models, alignment training, and similar tasks.
|
||||
|
||||
Detail page: [./examples/image_quality_metric/](./examples/image_quality_metric/)
|
||||
|
||||
* [ImageReward](https://github.com/THUDM/ImageReward)
|
||||
* [Aesthetic](https://github.com/christophschuhmann/improved-aesthetic-predictor)
|
||||
* [PickScore](https://github.com/yuvalkirstain/pickscore)
|
||||
* [CLIP](https://github.com/openai/CLIP)
|
||||
* [HPSv2](https://github.com/tgxs002/HPSv2)
|
||||
* [HPSv2.1](https://github.com/tgxs002/HPSv2)
|
||||
* [MPS](https://github.com/Kwai-Kolors/MPS)
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
## Innovative Achievements
|
||||
|
||||
DiffSynth-Studio is not just an engineering model framework, but also a platform for incubating innovative results.
|
||||
|
||||
<details>
|
||||
<summary>Nexus-Gen: Unified Architecture for Image Understanding, Generation, and Editing</summary>
|
||||
|
||||
- Detail page: https://github.com/modelscope/Nexus-Gen
|
||||
- Paper: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
|
||||
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
|
||||
- Dataset: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
|
||||
- Online Demo: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
|
||||
|
||||

|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>ArtAug: Aesthetic Enhancement for Image Generation Models</summary>
|
||||
|
||||
- Detail page: [./examples/ArtAug/](./examples/ArtAug/)
|
||||
- Paper: [ArtAug: Enhancing Text-to-Image Generation through Synthesis-Understanding Interaction](https://arxiv.org/abs/2412.12888)
|
||||
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
|
||||
- Online Demo: [ModelScope AIGC Tab](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0)
|
||||
|
||||
|FLUX.1-dev|FLUX.1-dev + ArtAug LoRA|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>EliGen: Precise Image Region Control</summary>
|
||||
|
||||
- Detail page: [./examples/EntityControl/](./examples/EntityControl/)
|
||||
- Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
|
||||
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
|
||||
- Online Demo: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
|
||||
- Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
|
||||
|
||||
|Entity Control Mask|Generated Image|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>ExVideo: Extended Training for Video Generation Models</summary>
|
||||
|
||||
- Project Page: [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
|
||||
- Paper: [ExVideo: Extending Video Diffusion Models via Parameter-Efficient Post-Tuning](https://arxiv.org/abs/2406.14130)
|
||||
- Code Example: [./examples/ExVideo/](./examples/ExVideo/)
|
||||
- Model: [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
|
||||
|
||||
https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Diffutoon: High-Resolution Anime-Style Video Rendering</summary>
|
||||
|
||||
- Project Page: [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
|
||||
- Paper: [Diffutoon: High-Resolution Editable Toon Shading via Diffusion Models](https://arxiv.org/abs/2401.16224)
|
||||
- Code Example: [./examples/Diffutoon/](./examples/Diffutoon/)
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-4709-be5e-b39af82404dd
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>DiffSynth: The Initial Version of This Project</summary>
|
||||
|
||||
- Project Page: [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/)
|
||||
- Paper: [DiffSynth: Latent In-Iteration Deflickering for Realistic Video Synthesis](https://arxiv.org/abs/2308.03463)
|
||||
- Code Example: [./examples/diffsynth/](./examples/diffsynth/)
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
## Update History
|
||||
- **August 7, 2025** We open-sourced the entity control LoRA of Qwen-Image, [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen). Qwen-Image-EliGen is able to achieve entity-level controlled text-to-image generation. See the [paper](https://arxiv.org/abs/2501.01097) for technical details. Training dataset: [EliGenTrainSet](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet).
|
||||
|
||||
- **August 5, 2025** We open-sourced the distilled acceleration model of Qwen-Image, [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full), achieving approximately 5x speedup.
|
||||
|
||||
- **August 4, 2025** 🔥 Qwen-Image is now open source. Welcome the new member to the image generation model family!
|
||||
|
||||
- **August 1, 2025** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) with a focus on aesthetic photography is comprehensively supported, including low-GPU-memory layer-by-layer offload, LoRA training and full training. See [./examples/flux/](./examples/flux/).
|
||||
|
||||
- **July 28, 2025** With the open-sourcing of Wan 2.2, we immediately provided comprehensive support, including low-GPU-memory layer-by-layer offload, FP8 quantization, sequence parallelism, LoRA training, full training. See [./examples/wanvideo/](./examples/wanvideo/).
|
||||
|
||||
- **July 11, 2025** We propose Nexus-Gen, a unified model that synergizes the language reasoning capabilities of LLMs with the image synthesis power of diffusion models. This framework enables seamless image understanding, generation, and editing tasks.
|
||||
- Paper: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
|
||||
- Github Repo: https://github.com/modelscope/Nexus-Gen
|
||||
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
|
||||
- Training Dataset: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
|
||||
- Online Demo: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
|
||||
|
||||
<details>
|
||||
<summary>More</summary>
|
||||
|
||||
- **June 15, 2025** ModelScope's official evaluation framework, [EvalScope](https://github.com/modelscope/evalscope), now supports text-to-image generation evaluation. Try it with the [Best Practices](https://evalscope.readthedocs.io/zh-cn/latest/best_practice/t2i_eval.html) guide.
|
||||
|
||||
- **March 25, 2025** Our new open-source project, [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine), is now open-sourced! Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality.
|
||||
|
||||
Welcome to the magic world of Diffusion models!
|
||||
|
||||
DiffSynth consists of two open-source projects:
|
||||
* [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): Focused on aggressive technological exploration. Targeted at academia. Provides more cutting-edge technical support and novel inference capabilities.
|
||||
* [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality.
|
||||
|
||||
DiffSynth-Studio is an open-source project aimed at exploring innovations in AIGC technology. We have integrated numerous open-source Diffusion models, including FLUX and Wan, among others. Through this open-source project, we hope to connect models within the open-source community and explore new technologies based on diffusion models.
|
||||
|
||||
Until now, DiffSynth-Studio has supported the following models:
|
||||
|
||||
* [Wan-Video](https://github.com/Wan-Video/Wan2.1)
|
||||
* [StepVideo](https://github.com/stepfun-ai/Step-Video-T2V)
|
||||
* [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [HunyuanVideo-I2V]()
|
||||
* [CogVideoX](https://huggingface.co/THUDM/CogVideoX-5b)
|
||||
* [FLUX](https://huggingface.co/black-forest-labs/FLUX.1-dev)
|
||||
* [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
|
||||
* [Kolors](https://huggingface.co/Kwai-Kolors/Kolors)
|
||||
* [Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
|
||||
* [Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt)
|
||||
* [Hunyuan-DiT](https://github.com/Tencent/HunyuanDiT)
|
||||
* [RIFE](https://github.com/hzwer/ECCV2022-RIFE)
|
||||
* [ESRGAN](https://github.com/xinntao/ESRGAN)
|
||||
* [Ip-Adapter](https://github.com/tencent-ailab/IP-Adapter)
|
||||
* [AnimateDiff](https://github.com/guoyww/animatediff/)
|
||||
* [ControlNet](https://github.com/lllyasviel/ControlNet)
|
||||
* [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
||||
* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
|
||||
|
||||
## News
|
||||
- **March 31, 2025** We support InfiniteYou, an identity preserving method for FLUX. Please refer to [./examples/InfiniteYou/](./examples/InfiniteYou/) for more details.
|
||||
|
||||
- **March 25, 2025** 🔥🔥🔥 Our new open-source project, [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine), is now open-sourced! Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality.
|
||||
|
||||
- **March 13, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of HunyuanVideo open-sourced by Tencent. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.
|
||||
|
||||
- **February 25, 2025** We support Wan-Video, a collection of SOTA video synthesis models open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).
|
||||
@@ -464,4 +121,135 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44
|
||||
- The source codes are released in [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth).
|
||||
- The technical report (ECML PKDD 2024) is released on [arXiv](https://arxiv.org/abs/2308.03463).
|
||||
|
||||
</details>
|
||||
|
||||
## Installation
|
||||
|
||||
Install from source code (recommended):
|
||||
|
||||
```
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Or install from pypi (There is a delay in the update. If you want to experience the latest features, please do not use this installation method.):
|
||||
|
||||
```
|
||||
pip install diffsynth
|
||||
```
|
||||
|
||||
If you encounter issues during installation, it may be caused by the packages we depend on. Please refer to the documentation of the package that caused the problem.
|
||||
|
||||
* [torch](https://pytorch.org/get-started/locally/)
|
||||
* [sentencepiece](https://github.com/google/sentencepiece)
|
||||
* [cmake](https://cmake.org)
|
||||
* [cupy](https://docs.cupy.dev/en/stable/install.html)
|
||||
|
||||
## Usage (in Python code)
|
||||
|
||||
The Python examples are in [`examples`](./examples/). We provide an overview here.
|
||||
|
||||
### Download Models
|
||||
|
||||
Download the pre-set models. Model IDs can be found in [config file](/diffsynth/configs/model_config.py).
|
||||
|
||||
```python
|
||||
from diffsynth import download_models
|
||||
|
||||
download_models(["FLUX.1-dev", "Kolors"])
|
||||
```
|
||||
|
||||
Download your own models.
|
||||
|
||||
```python
|
||||
from diffsynth.models.downloader import download_from_huggingface, download_from_modelscope
|
||||
|
||||
# From Modelscope (recommended)
|
||||
download_from_modelscope("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.fp16.bin", "models/kolors/Kolors/vae")
|
||||
# From Huggingface
|
||||
download_from_huggingface("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.fp16.safetensors", "models/kolors/Kolors/vae")
|
||||
```
|
||||
|
||||
### Video Synthesis
|
||||
|
||||
#### Text-to-video using CogVideoX-5B
|
||||
|
||||
CogVideoX-5B is released by ZhiPu. We provide an improved pipeline, supporting text-to-video, video editing, self-upscaling and video interpolation. [`examples/video_synthesis`](./examples/video_synthesis/)
|
||||
|
||||
The video on the left is generated using the original text-to-video pipeline, while the video on the right is the result after editing and frame interpolation.
|
||||
|
||||
https://github.com/user-attachments/assets/26b044c1-4a60-44a4-842f-627ff289d006
|
||||
|
||||
#### Long Video Synthesis
|
||||
|
||||
We trained extended video synthesis models, which can generate 128 frames. [`examples/ExVideo`](./examples/ExVideo/)
|
||||
|
||||
https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
|
||||
|
||||
https://github.com/user-attachments/assets/321ee04b-8c17-479e-8a95-8cbcf21f8d7e
|
||||
|
||||
#### Toon Shading
|
||||
|
||||
Render realistic videos in a flatten style and enable video editing features. [`examples/Diffutoon`](./examples/Diffutoon/)
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-4709-be5e-b39af82404dd
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/20528af5-5100-474a-8cdc-440b9efdd86c
|
||||
|
||||
#### Video Stylization
|
||||
|
||||
Video stylization without video models. [`examples/diffsynth`](./examples/diffsynth/)
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
|
||||
|
||||
### Image Synthesis
|
||||
|
||||
Generate high-resolution images, by breaking the limitation of diffusion models! [`examples/image_synthesis`](./examples/image_synthesis/).
|
||||
|
||||
LoRA fine-tuning is supported in [`examples/train`](./examples/train/).
|
||||
|
||||
|FLUX|Stable Diffusion 3|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
|Kolors|Hunyuan-DiT|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
|Stable Diffusion|Stable Diffusion XL|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
## Usage (in WebUI)
|
||||
|
||||
Create stunning images using the painter, with assistance from AI!
|
||||
|
||||
https://github.com/user-attachments/assets/95265d21-cdd6-4125-a7cb-9fbcf6ceb7b0
|
||||
|
||||
**This video is not rendered in real-time.**
|
||||
|
||||
Before launching the WebUI, please download models to the folder `./models`. See [here](#download-models).
|
||||
|
||||
* `Gradio` version
|
||||
|
||||
```
|
||||
pip install gradio
|
||||
```
|
||||
|
||||
```
|
||||
python apps/gradio/DiffSynth_Studio.py
|
||||
```
|
||||
|
||||

|
||||
|
||||
* `Streamlit` version
|
||||
|
||||
```
|
||||
pip install streamlit streamlit-drawable-canvas
|
||||
```
|
||||
|
||||
```
|
||||
python -m streamlit run apps/streamlit/DiffSynth_Studio.py
|
||||
```
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/93085557-73f3-4eee-a205-9829591ef954
|
||||
|
||||
484
README_zh.md
484
README_zh.md
@@ -1,484 +0,0 @@
|
||||
# DiffSynth-Studio
|
||||
|
||||
<a href="https://github.com/modelscope/DiffSynth-Studio"><img src=".github/workflows/logo.gif" title="Logo" style="max-width:100%;" width="55" /></a> <a href="https://trendshift.io/repositories/10946" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10946" alt="modelscope%2FDiffSynth-Studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a></p>
|
||||
|
||||
[](https://pypi.org/project/DiffSynth/)
|
||||
[](https://github.com/modelscope/DiffSynth-Studio/blob/master/LICENSE)
|
||||
[](https://github.com/modelscope/DiffSynth-Studio/issues)
|
||||
[](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
|
||||
[](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
|
||||
|
||||
[Switch to English](./README.md)
|
||||
|
||||
## 简介
|
||||
|
||||
欢迎来到 Diffusion 模型的魔法世界!DiffSynth-Studio 是由[魔搭社区](https://www.modelscope.cn/)团队开发和维护的开源 Diffusion 模型引擎。我们期望以框架建设孵化技术创新,凝聚开源社区的力量,探索生成式模型技术的边界!
|
||||
|
||||
DiffSynth 目前包括两个开源项目:
|
||||
* [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): 聚焦于激进的技术探索,面向学术界,提供更前沿的模型能力支持。
|
||||
* [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): 聚焦于稳定的模型部署,面向工业界,提供更高的计算性能与更稳定的功能。
|
||||
|
||||
[DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) 与 [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) 作为魔搭社区 [AIGC 专区](https://modelscope.cn/aigc/home) 的核心技术支撑,提供了强大的AI生成内容能力。欢迎体验我们精心打造的产品化功能,开启您的AI创作之旅!
|
||||
|
||||
## 安装
|
||||
|
||||
从源码安装(推荐):
|
||||
|
||||
```
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>其他安装方式</summary>
|
||||
|
||||
从 pypi 安装(存在版本更新延迟,如需使用最新功能,请从源码安装)
|
||||
|
||||
```
|
||||
pip install diffsynth
|
||||
```
|
||||
|
||||
如果在安装过程中遇到问题,可能是由上游依赖包导致的,请参考这些包的文档:
|
||||
|
||||
* [torch](https://pytorch.org/get-started/locally/)
|
||||
* [sentencepiece](https://github.com/google/sentencepiece)
|
||||
* [cmake](https://cmake.org)
|
||||
* [cupy](https://docs.cupy.dev/en/stable/install.html)
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
## 基础框架
|
||||
|
||||
DiffSynth-Studio 为主流 Diffusion 模型(包括 FLUX、Wan 等)重新设计了推理和训练流水线,能够实现高效的显存管理、灵活的模型训练。
|
||||
|
||||
### Qwen-Image 系列 (🔥新模型)
|
||||
|
||||
详细页面:[./examples/qwen_image/](./examples/qwen_image/)
|
||||
|
||||

|
||||
|
||||
<details>
|
||||
|
||||
<summary>快速开始</summary>
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(prompt, seed=0, num_inference_steps=40)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>模型总览</summary>
|
||||
|
||||
|模型 ID|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|
|
||||
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./examples/qwen_image/model_inference/Qwen-Image.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|
||||
</details>
|
||||
|
||||
### FLUX 系列
|
||||
|
||||
详细页面:[./examples/flux/](./examples/flux/)
|
||||
|
||||

|
||||
|
||||
<details>
|
||||
|
||||
<summary>快速开始</summary>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
image = pipe(prompt="a cat", seed=0)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>模型总览</summary>
|
||||
|
||||
|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|-|
|
||||
|[FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](./examples/flux/model_inference/FLUX.1-dev.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](./examples/flux/model_training/full/FLUX.1-dev.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev.py)|
|
||||
|[FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](./examples/flux/model_inference/FLUX.1-Krea-dev.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](./examples/flux/model_training/full/FLUX.1-Krea-dev.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py)|[code](./examples/flux/model_training/lora/FLUX.1-Krea-dev.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py)|
|
||||
|[FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](./examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](./examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](./examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)|
|
||||
|[FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|
|
||||
|[FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)|
|
||||
|[FLUX.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)|
|
||||
|[FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](./examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)|
|
||||
|[FLUX.1-dev-InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|
|
||||
|[FLUX.1-dev-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](./examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](./examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)|
|
||||
|[FLUX.1-dev-LoRA-Encoder](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](./examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|
|
||||
|[FLUX.1-dev-LoRA-Fusion-Preview](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](./examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|
|
||||
|[Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](./examples/flux/model_inference/Step1X-Edit.py)|[code](./examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](./examples/flux/model_training/full/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](./examples/flux/model_training/lora/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_lora/Step1X-Edit.py)|
|
||||
|[FLEX.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](./examples/flux/model_inference/FLEX.2-preview.py)|[code](./examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](./examples/flux/model_training/full/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](./examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_lora/FLEX.2-preview.py)|
|
||||
|[Nexus-Gen](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](./examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](./examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](./examples/flux/model_training/full/Nexus-Gen.sh)|[code](./examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](./examples/flux/model_training/lora/Nexus-Gen.sh)|[code](./examples/flux/model_training/validate_lora/Nexus-Gen.py)|
|
||||
|
||||
</details>
|
||||
|
||||
### Wan 系列
|
||||
|
||||
详细页面:[./examples/wanvideo/](./examples/wanvideo/)
|
||||
|
||||
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
|
||||
|
||||
<details>
|
||||
|
||||
<summary>快速开始</summary>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth import save_video
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
video = pipe(
|
||||
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
)
|
||||
save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>模型总览</summary>
|
||||
|
||||
|模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|
||||
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|
||||
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|
||||
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](./examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|
||||
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](./examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|
|
||||
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
|
||||
|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|
|
||||
|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|
|
||||
|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./examples/wanvideo/examples/wanmodel_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|
|
||||
|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|
|
||||
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|
||||
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|
|
||||
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
### 更多模型
|
||||
|
||||
|
||||
|
||||
<details>
|
||||
<summary>图像生成模型</summary>
|
||||
|
||||
详细页面:[./examples/image_synthesis/](./examples/image_synthesis/)
|
||||
|
||||
|FLUX|Stable Diffusion 3|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
|Kolors|Hunyuan-DiT|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
|Stable Diffusion|Stable Diffusion XL|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
<details>
|
||||
<summary>视频生成模型</summary>
|
||||
|
||||
- HunyuanVideo:[./examples/HunyuanVideo/](./examples/HunyuanVideo/)
|
||||
|
||||
https://github.com/user-attachments/assets/48dd24bb-0cc6-40d2-88c3-10feed3267e9
|
||||
|
||||
- StepVideo:[./examples/stepvideo/](./examples/stepvideo/)
|
||||
|
||||
https://github.com/user-attachments/assets/5954fdaa-a3cf-45a3-bd35-886e3cc4581b
|
||||
|
||||
- CogVideoX:[./examples/CogVideoX/](./examples/CogVideoX/)
|
||||
|
||||
https://github.com/user-attachments/assets/26b044c1-4a60-44a4-842f-627ff289d006
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
<details>
|
||||
<summary>图像质量评估模型</summary>
|
||||
|
||||
我们集成了一系列图像质量评估模型,这些模型可以用于图像生成模型的评测、对齐训练等场景中。
|
||||
|
||||
详细页面:[./examples/image_quality_metric/](./examples/image_quality_metric/)
|
||||
|
||||
* [ImageReward](https://github.com/THUDM/ImageReward)
|
||||
* [Aesthetic](https://github.com/christophschuhmann/improved-aesthetic-predictor)
|
||||
* [PickScore](https://github.com/yuvalkirstain/pickscore)
|
||||
* [CLIP](https://github.com/openai/CLIP)
|
||||
* [HPSv2](https://github.com/tgxs002/HPSv2)
|
||||
* [HPSv2.1](https://github.com/tgxs002/HPSv2)
|
||||
* [MPS](https://github.com/Kwai-Kolors/MPS)
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
## 创新成果
|
||||
|
||||
DiffSynth-Studio 不仅仅是一个工程化的模型框架,更是创新成果的孵化器。
|
||||
|
||||
<details>
|
||||
<summary>Nexus-Gen: 统一架构的图像理解、生成、编辑</summary>
|
||||
|
||||
- 详细页面:https://github.com/modelscope/Nexus-Gen
|
||||
- 论文:[Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
|
||||
- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
|
||||
- 数据集:[ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
|
||||
- 在线体验:[ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
|
||||
|
||||

|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
<details>
|
||||
<summary>ArtAug: 图像生成模型的美学提升</summary>
|
||||
|
||||
- 详细页面:[./examples/ArtAug/](./examples/ArtAug/)
|
||||
- 论文:[ArtAug: Enhancing Text-to-Image Generation through Synthesis-Understanding Interaction](https://arxiv.org/abs/2412.12888)
|
||||
- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
|
||||
- 在线体验:[ModelScope AIGC Tab](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0)
|
||||
|
||||
|FLUX.1-dev|FLUX.1-dev + ArtAug LoRA|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>EliGen: 精准的图像分区控制</summary>
|
||||
|
||||
- 详细页面:[./examples/EntityControl/](./examples/EntityControl/)
|
||||
- 论文:[EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
|
||||
- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
|
||||
- 在线体验:[ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
|
||||
- 数据集:[EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
|
||||
|
||||
|实体控制区域|生成图像|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>ExVideo: 视频生成模型的扩展训练</summary>
|
||||
|
||||
- 项目页面:[Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
|
||||
- 论文:[ExVideo: Extending Video Diffusion Models via Parameter-Efficient Post-Tuning](https://arxiv.org/abs/2406.14130)
|
||||
- 代码样例:[./examples/ExVideo/](./examples/ExVideo/)
|
||||
- 模型:[ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
|
||||
|
||||
https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Diffutoon: 高分辨率动漫风格视频渲染</summary>
|
||||
|
||||
- 项目页面:[Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
|
||||
- 论文:[Diffutoon: High-Resolution Editable Toon Shading via Diffusion Models](https://arxiv.org/abs/2401.16224)
|
||||
- 代码样例:[./examples/Diffutoon/](./examples/Diffutoon/)
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-4709-be5e-b39af82404dd
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>DiffSynth: 本项目的初代版本</summary>
|
||||
|
||||
- 项目页面:[Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/)
|
||||
- 论文:[DiffSynth: Latent In-Iteration Deflickering for Realistic Video Synthesis](https://arxiv.org/abs/2308.03463)
|
||||
- 代码样例:[./examples/diffsynth/](./examples/diffsynth/)
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
## 更新历史
|
||||
- **2025年8月7日** 我们开源了 Qwen-Image 的实体控制 LoRA 模型 [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)。Qwen-Image-EliGen 能够实现实体级可控的文生图。技术细节请参见[论文](https://arxiv.org/abs/2501.01097)。训练数据集:[EliGenTrainSet](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)。
|
||||
|
||||
- **2025年8月5日** 我们开源了 Qwen-Image 的蒸馏加速模型 [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full),实现了约 5 倍加速。
|
||||
|
||||
- **2025年8月4日** 🔥 Qwen-Image 开源,欢迎图像生成模型家族新成员!
|
||||
|
||||
- **2025年8月1日** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) 开源,这是一个专注于美学摄影的文生图模型。我们第一时间提供了全方位支持,包括低显存逐层 offload、LoRA 训练、全量训练。详细信息请参考 [./examples/flux/](./examples/flux/)。
|
||||
|
||||
- **2025年7月28日** Wan 2.2 开源,我们第一时间提供了全方位支持,包括低显存逐层 offload、FP8 量化、序列并行、LoRA 训练、全量训练。详细信息请参考 [./examples/wanvideo/](./examples/wanvideo/)。
|
||||
|
||||
- **2025年7月11日** 我们提出 Nexus-Gen,一个将大语言模型(LLM)的语言推理能力与扩散模型的图像生成能力相结合的统一框架。该框架支持无缝的图像理解、生成和编辑任务。
|
||||
- 论文: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
|
||||
- Github 仓库: https://github.com/modelscope/Nexus-Gen
|
||||
- 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
|
||||
- 训练数据集: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
|
||||
- 在线体验: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
|
||||
|
||||
<details>
|
||||
<summary>更多</summary>
|
||||
|
||||
- **2025年6月15日** ModelScope 官方评测框架 [EvalScope](https://github.com/modelscope/evalscope) 现已支持文生图生成评测。请参考[最佳实践](https://evalscope.readthedocs.io/zh-cn/latest/best_practice/t2i_eval.html)指南进行尝试。
|
||||
|
||||
- **2025年3月25日** 我们的新开源项目 [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) 现已开源!专注于稳定的模型部署,面向工业界,提供更好的工程支持、更高的计算性能和更稳定的功能。
|
||||
|
||||
- **2025年3月31日** 我们支持 InfiniteYou,一种用于 FLUX 的人脸特征保留方法。更多细节请参考 [./examples/InfiniteYou/](./examples/InfiniteYou/)。
|
||||
|
||||
- **2025年3月13日** 我们支持 HunyuanVideo-I2V,即腾讯开源的 HunyuanVideo 的图像到视频生成版本。更多细节请参考 [./examples/HunyuanVideo/](./examples/HunyuanVideo/)。
|
||||
|
||||
- **2025年2月25日** 我们支持 Wan-Video,这是阿里巴巴开源的一系列最先进的视频合成模型。详见 [./examples/wanvideo/](./examples/wanvideo/)。
|
||||
|
||||
- **2025年2月17日** 我们支持 [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary)!先进的视频合成模型!详见 [./examples/stepvideo](./examples/stepvideo/)。
|
||||
|
||||
- **2024年12月31日** 我们提出 EliGen,一种用于精确实体级别控制的文本到图像生成的新框架,并辅以修复融合管道,将其能力扩展到图像修复任务。EliGen 可以无缝集成现有的社区模型,如 IP-Adapter 和 In-Context LoRA,提升其通用性。更多详情,请见 [./examples/EntityControl](./examples/EntityControl/)。
|
||||
- 论文: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
|
||||
- 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
|
||||
- 在线体验: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
|
||||
- 训练数据集: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
|
||||
|
||||
- **2024年12月19日** 我们为 HunyuanVideo 实现了高级显存管理,使得在 24GB 显存下可以生成分辨率为 129x720x1280 的视频,或在仅 6GB 显存下生成分辨率为 129x512x384 的视频。更多细节请参考 [./examples/HunyuanVideo/](./examples/HunyuanVideo/)。
|
||||
|
||||
- **2024年12月18日** 我们提出 ArtAug,一种通过合成-理解交互来改进文生图模型的方法。我们以 LoRA 格式为 FLUX.1-dev 训练了一个 ArtAug 增强模块。该模型将 Qwen2-VL-72B 的美学理解融入 FLUX.1-dev,从而提升了生成图像的质量。
|
||||
- 论文: https://arxiv.org/abs/2412.12888
|
||||
- 示例: https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/ArtAug
|
||||
- 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
|
||||
- 演示: [ModelScope](https://modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0), HuggingFace (即将上线)
|
||||
|
||||
- **2024年10月25日** 我们提供了广泛的 FLUX ControlNet 支持。该项目支持许多不同的 ControlNet 模型,并且可以自由组合,即使它们的结构不同。此外,ControlNet 模型兼容高分辨率优化和分区控制技术,能够实现非常强大的可控图像生成。详见 [`./examples/ControlNet/`](./examples/ControlNet/)。
|
||||
|
||||
- **2024年10月8日** 我们发布了基于 CogVideoX-5B 和 ExVideo 的扩展 LoRA。您可以从 [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) 或 [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) 下载此模型。
|
||||
|
||||
- **2024年8月22日** 本项目现已支持 CogVideoX-5B。详见 [此处](/examples/video_synthesis/)。我们为这个文生视频模型提供了几个有趣的功能,包括:
|
||||
- 文本到视频
|
||||
- 视频编辑
|
||||
- 自我超分
|
||||
- 视频插帧
|
||||
|
||||
- **2024年8月22日** 我们实现了一个有趣的画笔功能,支持所有文生图模型。现在,您可以在 AI 的辅助下使用画笔创作惊艳的图像了!
|
||||
- 在我们的 [WebUI](#usage-in-webui) 中使用它。
|
||||
|
||||
- **2024年8月21日** DiffSynth-Studio 现已支持 FLUX。
|
||||
- 启用 CFG 和高分辨率修复以提升视觉质量。详见 [此处](/examples/image_synthesis/README.md)
|
||||
- LoRA、ControlNet 和其他附加模型将很快推出。
|
||||
|
||||
- **2024年6月21日** 我们提出 ExVideo,一种旨在增强视频生成模型能力的后训练微调技术。我们将 Stable Video Diffusion 进行了扩展,实现了长达 128 帧的长视频生成。
|
||||
- [项目页面](https://ecnu-cilab.github.io/ExVideoProjectPage/)
|
||||
- 源代码已在此仓库中发布。详见 [`examples/ExVideo`](./examples/ExVideo/)。
|
||||
- 模型已发布于 [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) 和 [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1)。
|
||||
- 技术报告已发布于 [arXiv](https://arxiv.org/abs/2406.14130)。
|
||||
- 您可以在此 [演示](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1) 中试用 ExVideo!
|
||||
|
||||
- **2024年6月13日** DiffSynth Studio 已迁移至 ModelScope。开发团队也从“我”转变为“我们”。当然,我仍会参与后续的开发和维护工作。
|
||||
|
||||
- **2024年1月29日** 我们提出 Diffutoon,这是一个出色的卡通着色解决方案。
|
||||
- [项目页面](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
|
||||
- 源代码已在此项目中发布。
|
||||
- 技术报告(IJCAI 2024)已发布于 [arXiv](https://arxiv.org/abs/2401.16224)。
|
||||
|
||||
- **2023年12月8日** 我们决定启动一个新项目,旨在释放扩散模型的潜力,尤其是在视频合成方面。该项目的开发工作正式开始。
|
||||
|
||||
- **2023年11月15日** 我们提出 FastBlend,一种强大的视频去闪烁算法。
|
||||
- sd-webui 扩展已发布于 [GitHub](https://github.com/Artiprocher/sd-webui-fastblend)。
|
||||
- 演示视频已在 Bilibili 上展示,包含三个任务:
|
||||
- [视频去闪烁](https://www.bilibili.com/video/BV1d94y1W7PE)
|
||||
- [视频插帧](https://www.bilibili.com/video/BV1Lw411m71p)
|
||||
- [图像驱动的视频渲染](https://www.bilibili.com/video/BV1RB4y1Z7LF)
|
||||
- 技术报告已发布于 [arXiv](https://arxiv.org/abs/2311.09265)。
|
||||
- 其他用户开发的非官方 ComfyUI 扩展已发布于 [GitHub](https://github.com/AInseven/ComfyUI-fastblend)。
|
||||
|
||||
- **2023年10月1日** 我们发布了该项目的早期版本,名为 FastSDXL。这是构建一个扩散引擎的初步尝试。
|
||||
- 源代码已发布于 [GitHub](https://github.com/Artiprocher/FastSDXL)。
|
||||
- FastSDXL 包含一个可训练的 OLSS 调度器,以提高效率。
|
||||
- OLSS 的原始仓库位于 [此处](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler)。
|
||||
- 技术报告(CIKM 2023)已发布于 [arXiv](https://arxiv.org/abs/2305.14677)。
|
||||
- 演示视频已发布于 [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj)。
|
||||
- 由于 OLSS 需要额外训练,我们未在本项目中实现它。
|
||||
|
||||
- **2023年8月29日** 我们提出 DiffSynth,一个视频合成框架。
|
||||
- [项目页面](https://ecnu-cilab.github.io/DiffSynth.github.io/)。
|
||||
- 源代码已发布在 [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth)。
|
||||
- 技术报告(ECML PKDD 2024)已发布于 [arXiv](https://arxiv.org/abs/2308.03463)。
|
||||
|
||||
</details>
|
||||
@@ -1,382 +0,0 @@
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import random
|
||||
import json
|
||||
import gradio as gr
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download, snapshot_download
|
||||
|
||||
# pip install pydantic==2.10.6
|
||||
# pip install gradio==5.4.0
|
||||
|
||||
snapshot_download("DiffSynth-Studio/Qwen-Image-EliGen", local_dir="models/DiffSynth-Studio/Qwen-Image-EliGen", allow_file_pattern="model.safetensors")
|
||||
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/qwen-image/*")
|
||||
example_json = 'data/examples/eligen/qwen-image/ui_examples.json'
|
||||
with open(example_json, 'r') as f:
|
||||
examples = json.load(f)['examples']
|
||||
|
||||
for idx in range(len(examples)):
|
||||
example_id = examples[idx]['example_id']
|
||||
entity_prompts = examples[idx]['local_prompt_list']
|
||||
examples[idx]['mask_lists'] = [Image.open(f"data/examples/eligen/qwen-image/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
|
||||
|
||||
def create_canvas_data(background, masks):
|
||||
if background.shape[-1] == 3:
|
||||
background = np.dstack([background, np.full(background.shape[:2], 255, dtype=np.uint8)])
|
||||
layers = []
|
||||
for mask in masks:
|
||||
if mask is not None:
|
||||
mask_single_channel = mask if mask.ndim == 2 else mask[..., 0]
|
||||
layer = np.zeros((mask_single_channel.shape[0], mask_single_channel.shape[1], 4), dtype=np.uint8)
|
||||
layer[..., -1] = mask_single_channel
|
||||
layers.append(layer)
|
||||
else:
|
||||
layers.append(np.zeros_like(background))
|
||||
|
||||
composite = background.copy()
|
||||
for layer in layers:
|
||||
if layer.size > 0:
|
||||
composite = np.where(layer[..., -1:] > 0, layer, composite)
|
||||
return {
|
||||
"background": background,
|
||||
"layers": layers,
|
||||
"composite": composite,
|
||||
}
|
||||
|
||||
def load_example(load_example_button):
|
||||
example_idx = int(load_example_button.split()[-1]) - 1
|
||||
example = examples[example_idx]
|
||||
result = [
|
||||
50,
|
||||
example["global_prompt"],
|
||||
example["negative_prompt"],
|
||||
example["seed"],
|
||||
*example["local_prompt_list"],
|
||||
]
|
||||
num_entities = len(example["local_prompt_list"])
|
||||
result += [""] * (config["max_num_painter_layers"] - num_entities)
|
||||
masks = []
|
||||
for mask in example["mask_lists"]:
|
||||
mask_single_channel = np.array(mask.convert("L"))
|
||||
masks.append(mask_single_channel)
|
||||
for _ in range(config["max_num_painter_layers"] - len(masks)):
|
||||
blank_mask = np.zeros_like(masks[0]) if masks else np.zeros((512, 512), dtype=np.uint8)
|
||||
masks.append(blank_mask)
|
||||
background = np.ones((masks[0].shape[0], masks[0].shape[1], 4), dtype=np.uint8) * 255
|
||||
canvas_data_list = []
|
||||
for mask in masks:
|
||||
canvas_data = create_canvas_data(background, [mask])
|
||||
canvas_data_list.append(canvas_data)
|
||||
result.extend(canvas_data_list)
|
||||
return result
|
||||
|
||||
def save_mask_prompts(masks, mask_prompts, global_prompt, seed=0, random_dir='0000000'):
|
||||
save_dir = os.path.join('workdirs/tmp_mask', random_dir)
|
||||
print(f'save to {save_dir}')
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
for i, mask in enumerate(masks):
|
||||
save_path = os.path.join(save_dir, f'{i}.png')
|
||||
mask.save(save_path)
|
||||
sample = {
|
||||
"global_prompt": global_prompt,
|
||||
"mask_prompts": mask_prompts,
|
||||
"seed": seed,
|
||||
}
|
||||
with open(os.path.join(save_dir, f"prompts.json"), 'w', encoding='utf-8') as f:
|
||||
json.dump(sample, f, ensure_ascii=False, indent=4)
|
||||
|
||||
def visualize_masks(image, masks, mask_prompts, font_size=35, use_random_colors=False):
|
||||
# Create a blank image for overlays
|
||||
overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
|
||||
colors = [
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
]
|
||||
# Generate random colors for each mask
|
||||
if use_random_colors:
|
||||
colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
|
||||
# Font settings
|
||||
try:
|
||||
font = ImageFont.truetype("wqy-zenhei.ttc", font_size) # Adjust as needed
|
||||
except IOError:
|
||||
font = ImageFont.load_default(font_size)
|
||||
# Overlay each mask onto the overlay image
|
||||
for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
|
||||
if mask is None:
|
||||
continue
|
||||
# Convert mask to RGBA mode
|
||||
mask_rgba = mask.convert('RGBA')
|
||||
mask_data = mask_rgba.getdata()
|
||||
new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
|
||||
mask_rgba.putdata(new_data)
|
||||
# Draw the mask prompt text on the mask
|
||||
draw = ImageDraw.Draw(mask_rgba)
|
||||
mask_bbox = mask.getbbox() # Get the bounding box of the mask
|
||||
if mask_bbox is None:
|
||||
continue
|
||||
text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position
|
||||
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
|
||||
# Alpha composite the overlay with this mask
|
||||
overlay = Image.alpha_composite(overlay, mask_rgba)
|
||||
# Composite the overlay onto the original image
|
||||
result = Image.alpha_composite(image.convert('RGBA'), overlay)
|
||||
return result
|
||||
|
||||
config = {
|
||||
"max_num_painter_layers": 8,
|
||||
"max_num_model_cache": 1,
|
||||
}
|
||||
|
||||
model_dict = {}
|
||||
|
||||
def load_model(model_type='qwen-image'):
|
||||
global model_dict
|
||||
model_key = f"{model_type}"
|
||||
if model_key in model_dict:
|
||||
return model_dict[model_key]
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-EliGen/model.safetensors")
|
||||
model_dict[model_key] = pipe
|
||||
return pipe
|
||||
|
||||
load_model('qwen-image')
|
||||
|
||||
with gr.Blocks() as app:
|
||||
gr.Markdown(
|
||||
"""## EliGen: Entity-Level Controllable Text-to-Image Model
|
||||
1. On the left, input the **global prompt** for the overall image, such as "a person stands by the river."
|
||||
2. On the right, input the **local prompt** for each entity, such as "person," and draw the corresponding mask in the **Entity Mask Painter**. Generally, solid rectangular masks yield better results.
|
||||
3. Click the **Generate** button to create the image. By selecting different **random seeds**, you can generate diverse images.
|
||||
4. **You can directly click the "Load Example" button on any sample at the bottom to load example inputs.**
|
||||
"""
|
||||
)
|
||||
|
||||
loading_status = gr.Textbox(label="Loading Model...", value="Loading model... Please wait...", visible=True)
|
||||
main_interface = gr.Column(visible=False)
|
||||
|
||||
def initialize_model():
|
||||
try:
|
||||
load_model('qwen-image')
|
||||
return {
|
||||
loading_status: gr.update(value="Model loaded successfully!", visible=False),
|
||||
main_interface: gr.update(visible=True),
|
||||
}
|
||||
except Exception as e:
|
||||
print(f'Failed to load model with error: {e}')
|
||||
return {
|
||||
loading_status: gr.update(value=f"Failed to load model: {str(e)}", visible=True),
|
||||
main_interface: gr.update(visible=True),
|
||||
}
|
||||
|
||||
app.load(initialize_model, inputs=None, outputs=[loading_status, main_interface])
|
||||
|
||||
with main_interface:
|
||||
with gr.Row():
|
||||
local_prompt_list = []
|
||||
canvas_list = []
|
||||
random_mask_dir = gr.State(f'{random.randint(0, 1000000):08d}')
|
||||
with gr.Column(scale=382, min_width=100):
|
||||
model_type = gr.State('qwen-image')
|
||||
with gr.Accordion(label="Global prompt"):
|
||||
prompt = gr.Textbox(label="Global Prompt", lines=3)
|
||||
negative_prompt = gr.Textbox(label="Negative prompt", value="", lines=3)
|
||||
with gr.Accordion(label="Inference Options", open=True):
|
||||
seed = gr.Number(minimum=0, maximum=10**9, value=42, interactive=True, label="Random seed", show_label=True)
|
||||
num_inference_steps = gr.Slider(minimum=1, maximum=100, value=30, step=1, interactive=True, label="Inference steps")
|
||||
cfg_scale = gr.Slider(minimum=2.0, maximum=10.0, value=4.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
|
||||
height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
|
||||
width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
|
||||
with gr.Accordion(label="Inpaint Input Image", open=False, visible=False):
|
||||
input_image = gr.Image(sources=None, show_label=False, interactive=True, type="pil")
|
||||
background_weight = gr.Slider(minimum=0.0, maximum=1000., value=0., step=1, interactive=False, label="background_weight", visible=False)
|
||||
|
||||
with gr.Column():
|
||||
reset_input_button = gr.Button(value="Reset Inpaint Input")
|
||||
send_input_to_painter = gr.Button(value="Set as painter's background")
|
||||
@gr.on(inputs=[input_image], outputs=[input_image], triggers=reset_input_button.click)
|
||||
def reset_input_image(input_image):
|
||||
return None
|
||||
|
||||
with gr.Column(scale=618, min_width=100):
|
||||
with gr.Accordion(label="Entity Painter"):
|
||||
for painter_layer_id in range(config["max_num_painter_layers"]):
|
||||
with gr.Tab(label=f"Entity {painter_layer_id}"):
|
||||
local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
|
||||
canvas = gr.ImageEditor(
|
||||
canvas_size=(1024, 1024),
|
||||
sources=None,
|
||||
layers=False,
|
||||
interactive=True,
|
||||
image_mode="RGBA",
|
||||
brush=gr.Brush(
|
||||
default_size=50,
|
||||
default_color="#000000",
|
||||
colors=["#000000"],
|
||||
),
|
||||
label="Entity Mask Painter",
|
||||
key=f"canvas_{painter_layer_id}",
|
||||
width=width,
|
||||
height=height,
|
||||
)
|
||||
@gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear], show_progress="hidden")
|
||||
def resize_canvas(height, width, canvas):
|
||||
if canvas is None or canvas["background"] is None:
|
||||
return np.ones((height, width, 3), dtype=np.uint8) * 255
|
||||
h, w = canvas["background"].shape[:2]
|
||||
if h != height or width != w:
|
||||
return np.ones((height, width, 3), dtype=np.uint8) * 255
|
||||
else:
|
||||
return canvas
|
||||
local_prompt_list.append(local_prompt)
|
||||
canvas_list.append(canvas)
|
||||
with gr.Accordion(label="Results"):
|
||||
run_button = gr.Button(value="Generate", variant="primary")
|
||||
output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
output_to_painter_button = gr.Button(value="Set as painter's background")
|
||||
with gr.Column():
|
||||
return_with_mask = gr.Checkbox(value=False, interactive=True, label="show result with mask painting")
|
||||
output_to_input_button = gr.Button(value="Set as input image", visible=False, interactive=False)
|
||||
real_output = gr.State(None)
|
||||
mask_out = gr.State(None)
|
||||
|
||||
@gr.on(
|
||||
inputs=[model_type, prompt, negative_prompt, cfg_scale, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir] + local_prompt_list + canvas_list,
|
||||
outputs=[output_image, real_output, mask_out],
|
||||
triggers=run_button.click
|
||||
)
|
||||
def generate_image(model_type, prompt, negative_prompt, cfg_scale, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir, *args, progress=gr.Progress()):
|
||||
pipe = load_model(model_type)
|
||||
input_params = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"cfg_scale": cfg_scale,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"progress_bar_cmd": progress.tqdm,
|
||||
}
|
||||
# if input_image is not None:
|
||||
# input_params["input_image"] = input_image.resize((width, height)).convert("RGB")
|
||||
# input_params["enable_eligen_inpaint"] = True
|
||||
|
||||
local_prompt_list, canvas_list = (
|
||||
args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
|
||||
args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]],
|
||||
)
|
||||
local_prompts, masks = [], []
|
||||
for local_prompt, canvas in zip(local_prompt_list, canvas_list):
|
||||
if isinstance(local_prompt, str) and len(local_prompt) > 0:
|
||||
local_prompts.append(local_prompt)
|
||||
masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB"))
|
||||
entity_prompts = None if len(local_prompts) == 0 else local_prompts
|
||||
entity_masks = None if len(masks) == 0 or entity_prompts is None else masks
|
||||
input_params.update({
|
||||
"eligen_entity_prompts": entity_prompts,
|
||||
"eligen_entity_masks": entity_masks,
|
||||
})
|
||||
torch.manual_seed(seed)
|
||||
save_mask_prompts(masks, local_prompts, prompt, seed, random_mask_dir)
|
||||
image = pipe(**input_params)
|
||||
masks = [mask.resize(image.size) for mask in masks]
|
||||
image_with_mask = visualize_masks(image, masks, local_prompts)
|
||||
|
||||
real_output = gr.State(image)
|
||||
mask_out = gr.State(image_with_mask)
|
||||
|
||||
if return_with_mask:
|
||||
return image_with_mask, real_output, mask_out
|
||||
return image, real_output, mask_out
|
||||
|
||||
@gr.on(inputs=[input_image] + canvas_list, outputs=canvas_list, triggers=send_input_to_painter.click)
|
||||
def send_input_to_painter_background(input_image, *canvas_list):
|
||||
if input_image is None:
|
||||
return tuple(canvas_list)
|
||||
for canvas in canvas_list:
|
||||
h, w = canvas["background"].shape[:2]
|
||||
canvas["background"] = input_image.resize((w, h))
|
||||
return tuple(canvas_list)
|
||||
@gr.on(inputs=[real_output] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
|
||||
def send_output_to_painter_background(real_output, *canvas_list):
|
||||
if real_output is None:
|
||||
return tuple(canvas_list)
|
||||
for canvas in canvas_list:
|
||||
h, w = canvas["background"].shape[:2]
|
||||
canvas["background"] = real_output.value.resize((w, h))
|
||||
return tuple(canvas_list)
|
||||
@gr.on(inputs=[return_with_mask, real_output, mask_out], outputs=[output_image], triggers=[return_with_mask.change], show_progress="hidden")
|
||||
def show_output(return_with_mask, real_output, mask_out):
|
||||
if return_with_mask:
|
||||
return mask_out.value
|
||||
else:
|
||||
return real_output.value
|
||||
@gr.on(inputs=[real_output], outputs=[input_image], triggers=output_to_input_button.click)
|
||||
def send_output_to_pipe_input(real_output):
|
||||
return real_output.value
|
||||
|
||||
with gr.Column():
|
||||
gr.Markdown("## Examples")
|
||||
for i in range(0, len(examples), 2):
|
||||
with gr.Row():
|
||||
if i < len(examples):
|
||||
example = examples[i]
|
||||
with gr.Column():
|
||||
example_image = gr.Image(
|
||||
value=f"data/examples/eligen/qwen-image/example_{example['example_id']}/example_image.png",
|
||||
label=example["description"],
|
||||
interactive=False,
|
||||
width=1024,
|
||||
height=512
|
||||
)
|
||||
load_example_button = gr.Button(value=f"Load Example {example['example_id']}")
|
||||
load_example_button.click(
|
||||
load_example,
|
||||
inputs=[load_example_button],
|
||||
outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list
|
||||
)
|
||||
|
||||
if i + 1 < len(examples):
|
||||
example = examples[i + 1]
|
||||
with gr.Column():
|
||||
example_image = gr.Image(
|
||||
value=f"data/examples/eligen/qwen-image/example_{example['example_id']}/example_image.png",
|
||||
label=example["description"],
|
||||
interactive=False,
|
||||
width=1024,
|
||||
height=512
|
||||
)
|
||||
load_example_button = gr.Button(value=f"Load Example {example['example_id']}")
|
||||
load_example_button.click(
|
||||
load_example,
|
||||
inputs=[load_example_button],
|
||||
outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list
|
||||
)
|
||||
app.config["show_progress"] = "hidden"
|
||||
app.launch(share=False)
|
||||
@@ -58,24 +58,8 @@ from ..models.stepvideo_dit import StepVideoModel
|
||||
from ..models.wan_video_dit import WanModel
|
||||
from ..models.wan_video_text_encoder import WanTextEncoder
|
||||
from ..models.wan_video_image_encoder import WanImageEncoder
|
||||
from ..models.wan_video_vae import WanVideoVAE, WanVideoVAE38
|
||||
from ..models.wan_video_motion_controller import WanMotionControllerModel
|
||||
from ..models.wan_video_vace import VaceWanModel
|
||||
from ..models.wan_video_vae import WanVideoVAE
|
||||
|
||||
from ..models.step1x_connector import Qwen2Connector
|
||||
|
||||
from ..models.flux_value_control import SingleValueEncoder
|
||||
|
||||
from ..lora.flux_lora import FluxLoraPatcher
|
||||
from ..models.flux_lora_encoder import FluxLoRAEncoder
|
||||
|
||||
from ..models.nexus_gen_projector import NexusGenAdapter, NexusGenImageEmbeddingMerger
|
||||
from ..models.nexus_gen import NexusGenAutoregressiveModel
|
||||
|
||||
from ..models.qwen_image_accelerate_adapter import QwenImageAccelerateAdapter
|
||||
from ..models.qwen_image_dit import QwenImageDiT
|
||||
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
|
||||
from ..models.qwen_image_vae import QwenImageVAE
|
||||
|
||||
model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
@@ -112,9 +96,7 @@ model_loader_configs = [
|
||||
(None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "3394f306c4cbf04334b712bf5aaed95f", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "d02f41c13549fa5093d3521f62a5570a", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"),
|
||||
(None, "0629116fce1472503a66992f96f3eb1a", ["flux_value_controller"], [SingleValueEncoder], "civitai"),
|
||||
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
|
||||
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
|
||||
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
|
||||
@@ -124,7 +106,6 @@ model_loader_configs = [
|
||||
(None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||
(None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||
(None, "7f9583eb8ba86642abb9a21a4b2c9e16", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||
(None, "43ad5aaa27dd4ee01b832ed16773fa52", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||
(None, "c07c0f04f5ff55e86b4e937c7a40d481", ["infiniteyou_image_projector"], [InfiniteYouImageProjector], "diffusers"),
|
||||
(None, "4daaa66cc656a8fe369908693dad0a35", ["flux_ipadapter"], [FluxIpAdapter], "diffusers"),
|
||||
(None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
||||
@@ -139,36 +120,11 @@ model_loader_configs = [
|
||||
(None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "70ddad9d3a133785da5ea371aae09504", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "26bde73488a92e64cc20b0a7485b9e5b", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "ac6a5aa74f4a0aab6f64eb9a72f19901", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "b61c605c2adbd23124d152ed28e049ae", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "1f5ab7703c6fc803fdded85ff040c316", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "5b013604280dd715f8457c6ed6d6a626", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
|
||||
(None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
|
||||
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
|
||||
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
|
||||
(None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
|
||||
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
||||
(None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
||||
(None, "e1de6c02cdac79f8b739f4d3698cd216", ["wan_video_vae"], [WanVideoVAE38], "civitai"),
|
||||
(None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"),
|
||||
(None, "d30fb9e02b1dbf4e509142f05cf7dd50", ["flux_dit", "step1x_connector"], [FluxDiT, Qwen2Connector], "civitai"),
|
||||
(None, "30143afb2dea73d1ac580e0787628f8c", ["flux_lora_patcher"], [FluxLoraPatcher], "civitai"),
|
||||
(None, "77c2e4dd2440269eb33bfaa0d004f6ab", ["flux_lora_encoder"], [FluxLoRAEncoder], "civitai"),
|
||||
(None, "3e6c61b0f9471135fc9c6d6a98e98b6d", ["flux_dit", "nexus_gen_generation_adapter"], [FluxDiT, NexusGenAdapter], "civitai"),
|
||||
(None, "63c969fd37cce769a90aa781fbff5f81", ["flux_dit", "nexus_gen_editing_adapter"], [FluxDiT, NexusGenImageEmbeddingMerger], "civitai"),
|
||||
(None, "2bd19e845116e4f875a0a048e27fc219", ["nexus_gen_llm"], [NexusGenAutoregressiveModel], "civitai"),
|
||||
(None, "0319a1cb19835fb510907dd3367c95ff", ["qwen_image_dit"], [QwenImageDiT], "civitai"),
|
||||
(None, "ae9d13bfc578702baf6445d2cf3d1d46", ["qwen_image_accelerate_adapter"], [QwenImageAccelerateAdapter], "civitai"),
|
||||
(None, "8004730443f55db63092006dd9f7110e", ["qwen_image_text_encoder"], [QwenImageTextEncoder], "diffusers"),
|
||||
(None, "ed4ea5824d55ec3107b09815e318123a", ["qwen_image_vae"], [QwenImageVAE], "diffusers"),
|
||||
]
|
||||
huggingface_model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
@@ -184,7 +140,6 @@ huggingface_model_loader_configs = [
|
||||
("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
|
||||
("LlavaForConditionalGeneration", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoMLLMEncoder"),
|
||||
("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
|
||||
("Qwen2_5_VLForConditionalGeneration", "diffsynth.models.qwenvl", "qwenvl", "Qwen25VL_7b_Embedder"),
|
||||
]
|
||||
patch_model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
|
||||
@@ -73,10 +73,9 @@ def usp_dit_forward(self,
|
||||
return custom_forward
|
||||
|
||||
# Context Parallel
|
||||
chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
|
||||
pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
|
||||
chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks]
|
||||
x = chunks[get_sequence_parallel_rank()]
|
||||
x = torch.chunk(
|
||||
x, get_sequence_parallel_world_size(),
|
||||
dim=1)[get_sequence_parallel_rank()]
|
||||
|
||||
for block in self.blocks:
|
||||
if self.training and use_gradient_checkpointing:
|
||||
@@ -100,7 +99,6 @@ def usp_dit_forward(self,
|
||||
|
||||
# Context Parallel
|
||||
x = get_sp_group().all_gather(x, dim=1)
|
||||
x = x[:, :-pad_shape] if pad_shape > 0 else x
|
||||
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, (f, h, w))
|
||||
|
||||
@@ -413,7 +413,7 @@ class BertEncoder(nn.Module):
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
if use_cache:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
class GeneralLoRALoader:
|
||||
def __init__(self, device="cpu", torch_dtype=torch.float32):
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
|
||||
|
||||
def get_name_dict(self, lora_state_dict):
|
||||
lora_name_dict = {}
|
||||
for key in lora_state_dict:
|
||||
if ".lora_B." not in key:
|
||||
continue
|
||||
keys = key.split(".")
|
||||
if len(keys) > keys.index("lora_B") + 2:
|
||||
keys.pop(keys.index("lora_B") + 1)
|
||||
keys.pop(keys.index("lora_B"))
|
||||
if keys[0] == "diffusion_model":
|
||||
keys.pop(0)
|
||||
keys.pop(-1)
|
||||
target_name = ".".join(keys)
|
||||
lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
|
||||
return lora_name_dict
|
||||
|
||||
|
||||
def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
|
||||
updated_num = 0
|
||||
lora_name_dict = self.get_name_dict(state_dict_lora)
|
||||
for name, module in model.named_modules():
|
||||
if name in lora_name_dict:
|
||||
weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=self.device, dtype=self.torch_dtype)
|
||||
weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=self.device, dtype=self.torch_dtype)
|
||||
if len(weight_up.shape) == 4:
|
||||
weight_up = weight_up.squeeze(3).squeeze(2)
|
||||
weight_down = weight_down.squeeze(3).squeeze(2)
|
||||
weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
weight_lora = alpha * torch.mm(weight_up, weight_down)
|
||||
state_dict = module.state_dict()
|
||||
state_dict["weight"] = state_dict["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora
|
||||
module.load_state_dict(state_dict)
|
||||
updated_num += 1
|
||||
print(f"{updated_num} tensors are updated by LoRA.")
|
||||
@@ -1,324 +0,0 @@
|
||||
import torch, math
|
||||
from . import GeneralLoRALoader
|
||||
from ..utils import ModelConfig
|
||||
from ..models.utils import load_state_dict
|
||||
from typing import Union
|
||||
|
||||
|
||||
class FluxLoRALoader(GeneralLoRALoader):
|
||||
def __init__(self, device="cpu", torch_dtype=torch.float32):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
|
||||
self.diffusers_rename_dict = {
|
||||
"transformer.single_transformer_blocks.blockid.attn.to_k.lora_A.weight":"single_blocks.blockid.a_to_k.lora_A.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.attn.to_k.lora_B.weight":"single_blocks.blockid.a_to_k.lora_B.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.attn.to_q.lora_A.weight":"single_blocks.blockid.a_to_q.lora_A.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.attn.to_q.lora_B.weight":"single_blocks.blockid.a_to_q.lora_B.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.attn.to_v.lora_A.weight":"single_blocks.blockid.a_to_v.lora_A.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.attn.to_v.lora_B.weight":"single_blocks.blockid.a_to_v.lora_B.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.norm.linear.lora_A.weight":"single_blocks.blockid.norm.linear.lora_A.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.norm.linear.lora_B.weight":"single_blocks.blockid.norm.linear.lora_B.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.proj_mlp.lora_A.weight":"single_blocks.blockid.proj_in_besides_attn.lora_A.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.proj_mlp.lora_B.weight":"single_blocks.blockid.proj_in_besides_attn.lora_B.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.proj_out.lora_A.weight":"single_blocks.blockid.proj_out.lora_A.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.proj_out.lora_B.weight":"single_blocks.blockid.proj_out.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.add_k_proj.lora_A.weight":"blocks.blockid.attn.b_to_k.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.add_k_proj.lora_B.weight":"blocks.blockid.attn.b_to_k.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.add_q_proj.lora_A.weight":"blocks.blockid.attn.b_to_q.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.add_q_proj.lora_B.weight":"blocks.blockid.attn.b_to_q.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.add_v_proj.lora_A.weight":"blocks.blockid.attn.b_to_v.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.add_v_proj.lora_B.weight":"blocks.blockid.attn.b_to_v.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_add_out.lora_A.weight":"blocks.blockid.attn.b_to_out.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_add_out.lora_B.weight":"blocks.blockid.attn.b_to_out.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_k.lora_A.weight":"blocks.blockid.attn.a_to_k.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_k.lora_B.weight":"blocks.blockid.attn.a_to_k.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_out.0.lora_A.weight":"blocks.blockid.attn.a_to_out.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_out.0.lora_B.weight":"blocks.blockid.attn.a_to_out.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_q.lora_A.weight":"blocks.blockid.attn.a_to_q.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_q.lora_B.weight":"blocks.blockid.attn.a_to_q.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_v.lora_A.weight":"blocks.blockid.attn.a_to_v.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_v.lora_B.weight":"blocks.blockid.attn.a_to_v.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.ff.net.0.proj.lora_A.weight":"blocks.blockid.ff_a.0.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.ff.net.0.proj.lora_B.weight":"blocks.blockid.ff_a.0.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.ff.net.2.lora_A.weight":"blocks.blockid.ff_a.2.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.ff.net.2.lora_B.weight":"blocks.blockid.ff_a.2.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_A.weight":"blocks.blockid.ff_b.0.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_B.weight":"blocks.blockid.ff_b.0.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.ff_context.net.2.lora_A.weight":"blocks.blockid.ff_b.2.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.ff_context.net.2.lora_B.weight":"blocks.blockid.ff_b.2.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.norm1.linear.lora_A.weight":"blocks.blockid.norm1_a.linear.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.norm1.linear.lora_B.weight":"blocks.blockid.norm1_a.linear.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.norm1_context.linear.lora_A.weight":"blocks.blockid.norm1_b.linear.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.norm1_context.linear.lora_B.weight":"blocks.blockid.norm1_b.linear.lora_B.default.weight",
|
||||
}
|
||||
|
||||
self.civitai_rename_dict = {
|
||||
"lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight",
|
||||
"lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight",
|
||||
"lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight",
|
||||
"lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight",
|
||||
"lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight",
|
||||
"lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight",
|
||||
"lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight",
|
||||
}
|
||||
|
||||
def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
|
||||
super().load(model, state_dict_lora, alpha)
|
||||
|
||||
|
||||
def convert_state_dict(self,state_dict):
|
||||
|
||||
def guess_block_id(name,model_resource):
|
||||
if model_resource == 'civitai':
|
||||
names = name.split("_")
|
||||
for i in names:
|
||||
if i.isdigit():
|
||||
return i, name.replace(f"_{i}_", "_blockid_")
|
||||
if model_resource == 'diffusers':
|
||||
names = name.split(".")
|
||||
for i in names:
|
||||
if i.isdigit():
|
||||
return i, name.replace(f"transformer_blocks.{i}.", "transformer_blocks.blockid.")
|
||||
return None, None
|
||||
|
||||
def guess_resource(state_dict):
|
||||
for k in state_dict:
|
||||
if "lora_unet_" in k:
|
||||
return 'civitai'
|
||||
elif k.startswith("transformer."):
|
||||
return 'diffusers'
|
||||
else:
|
||||
None
|
||||
|
||||
model_resource = guess_resource(state_dict)
|
||||
if model_resource is None:
|
||||
return state_dict
|
||||
|
||||
rename_dict = self.diffusers_rename_dict if model_resource == 'diffusers' else self.civitai_rename_dict
|
||||
def guess_alpha(state_dict):
|
||||
for name, param in state_dict.items():
|
||||
if ".alpha" in name:
|
||||
for suffix in [".lora_down.weight", ".lora_A.weight"]:
|
||||
name_ = name.replace(".alpha", suffix)
|
||||
if name_ in state_dict:
|
||||
lora_alpha = param.item() / state_dict[name_].shape[0]
|
||||
lora_alpha = math.sqrt(lora_alpha)
|
||||
return lora_alpha
|
||||
|
||||
return 1
|
||||
|
||||
alpha = guess_alpha(state_dict)
|
||||
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
block_id, source_name = guess_block_id(name,model_resource)
|
||||
if alpha != 1:
|
||||
param *= alpha
|
||||
if source_name in rename_dict:
|
||||
target_name = rename_dict[source_name]
|
||||
target_name = target_name.replace(".blockid.", f".{block_id}.")
|
||||
state_dict_[target_name] = param
|
||||
else:
|
||||
state_dict_[name] = param
|
||||
|
||||
if model_resource == 'diffusers':
|
||||
for name in list(state_dict_.keys()):
|
||||
if "single_blocks." in name and ".a_to_q." in name:
|
||||
mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
|
||||
if mlp is None:
|
||||
dim = 4
|
||||
if 'lora_A' in name:
|
||||
dim = 1
|
||||
mlp = torch.zeros(dim * state_dict_[name].shape[0],
|
||||
*state_dict_[name].shape[1:],
|
||||
dtype=state_dict_[name].dtype)
|
||||
else:
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
|
||||
if 'lora_A' in name:
|
||||
param = torch.concat([
|
||||
state_dict_.pop(name),
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
|
||||
mlp,
|
||||
], dim=0)
|
||||
elif 'lora_B' in name:
|
||||
d, r = state_dict_[name].shape
|
||||
param = torch.zeros((3*d+mlp.shape[0], 3*r+mlp.shape[1]), dtype=state_dict_[name].dtype, device=state_dict_[name].device)
|
||||
param[:d, :r] = state_dict_.pop(name)
|
||||
param[d:2*d, r:2*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_k."))
|
||||
param[2*d:3*d, 2*r:3*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_v."))
|
||||
param[3*d:, 3*r:] = mlp
|
||||
else:
|
||||
param = torch.concat([
|
||||
state_dict_.pop(name),
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
|
||||
mlp,
|
||||
], dim=0)
|
||||
name_ = name.replace(".a_to_q.", ".to_qkv_mlp.")
|
||||
state_dict_[name_] = param
|
||||
for name in list(state_dict_.keys()):
|
||||
for component in ["a", "b"]:
|
||||
if f".{component}_to_q." in name:
|
||||
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
|
||||
concat_dim = 0
|
||||
if 'lora_A' in name:
|
||||
param = torch.concat([
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
|
||||
], dim=0)
|
||||
elif 'lora_B' in name:
|
||||
origin = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")]
|
||||
d, r = origin.shape
|
||||
# print(d, r)
|
||||
param = torch.zeros((3*d, 3*r), dtype=origin.dtype, device=origin.device)
|
||||
param[:d, :r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")]
|
||||
param[d:2*d, r:2*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")]
|
||||
param[2*d:3*d, 2*r:3*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")]
|
||||
else:
|
||||
param = torch.concat([
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
|
||||
], dim=0)
|
||||
state_dict_[name_] = param
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
|
||||
return state_dict_
|
||||
|
||||
|
||||
class LoraMerger(torch.nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.weight_base = torch.nn.Parameter(torch.randn((dim,)))
|
||||
self.weight_lora = torch.nn.Parameter(torch.randn((dim,)))
|
||||
self.weight_cross = torch.nn.Parameter(torch.randn((dim,)))
|
||||
self.weight_out = torch.nn.Parameter(torch.ones((dim,)))
|
||||
self.bias = torch.nn.Parameter(torch.randn((dim,)))
|
||||
self.activation = torch.nn.Sigmoid()
|
||||
self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5)
|
||||
self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5)
|
||||
|
||||
def forward(self, base_output, lora_outputs):
|
||||
norm_base_output = self.norm_base(base_output)
|
||||
norm_lora_outputs = self.norm_lora(lora_outputs)
|
||||
gate = self.activation(
|
||||
norm_base_output * self.weight_base \
|
||||
+ norm_lora_outputs * self.weight_lora \
|
||||
+ norm_base_output * norm_lora_outputs * self.weight_cross + self.bias
|
||||
)
|
||||
output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0)
|
||||
return output
|
||||
|
||||
|
||||
class FluxLoraPatcher(torch.nn.Module):
|
||||
def __init__(self, lora_patterns=None):
|
||||
super().__init__()
|
||||
if lora_patterns is None:
|
||||
lora_patterns = self.default_lora_patterns()
|
||||
model_dict = {}
|
||||
for lora_pattern in lora_patterns:
|
||||
name, dim = lora_pattern["name"], lora_pattern["dim"]
|
||||
model_dict[name.replace(".", "___")] = LoraMerger(dim)
|
||||
self.model_dict = torch.nn.ModuleDict(model_dict)
|
||||
|
||||
def default_lora_patterns(self):
|
||||
lora_patterns = []
|
||||
lora_dict = {
|
||||
"attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432,
|
||||
"attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432,
|
||||
}
|
||||
for i in range(19):
|
||||
for suffix in lora_dict:
|
||||
lora_patterns.append({
|
||||
"name": f"blocks.{i}.{suffix}",
|
||||
"dim": lora_dict[suffix]
|
||||
})
|
||||
lora_dict = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216}
|
||||
for i in range(38):
|
||||
for suffix in lora_dict:
|
||||
lora_patterns.append({
|
||||
"name": f"single_blocks.{i}.{suffix}",
|
||||
"dim": lora_dict[suffix]
|
||||
})
|
||||
return lora_patterns
|
||||
|
||||
def forward(self, base_output, lora_outputs, name):
|
||||
return self.model_dict[name.replace(".", "___")](base_output, lora_outputs)
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FluxLoraPatcherStateDictConverter()
|
||||
|
||||
|
||||
class FluxLoraPatcherStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
|
||||
class FluxLoRAFuser:
|
||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
|
||||
def Matrix_Decomposition_lowrank(self, A, k):
|
||||
U, S, V = torch.svd_lowrank(A.float(), q=k)
|
||||
S_k = torch.diag(S[:k])
|
||||
U_hat = U @ S_k
|
||||
return U_hat, V.t()
|
||||
|
||||
def LoRA_State_Dicts_Decomposition(self, lora_state_dicts=[], q=4):
|
||||
lora_1 = lora_state_dicts[0]
|
||||
state_dict_ = {}
|
||||
for k,v in lora_1.items():
|
||||
if 'lora_A.' in k:
|
||||
lora_B_name = k.replace('lora_A.', 'lora_B.')
|
||||
lora_B = lora_1[lora_B_name]
|
||||
weight = torch.mm(lora_B, v)
|
||||
for lora_dict in lora_state_dicts[1:]:
|
||||
lora_A_ = lora_dict[k]
|
||||
lora_B_ = lora_dict[lora_B_name]
|
||||
weight_ = torch.mm(lora_B_, lora_A_)
|
||||
weight += weight_
|
||||
new_B, new_A = self.Matrix_Decomposition_lowrank(weight, q)
|
||||
state_dict_[lora_B_name] = new_B.to(dtype=torch.bfloat16)
|
||||
state_dict_[k] = new_A.to(dtype=torch.bfloat16)
|
||||
return state_dict_
|
||||
|
||||
def __call__(self, lora_configs: list[Union[ModelConfig, str]]):
|
||||
loras = []
|
||||
loader = FluxLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
|
||||
for lora_config in lora_configs:
|
||||
if isinstance(lora_config, str):
|
||||
lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
|
||||
else:
|
||||
lora_config.download_if_necessary()
|
||||
lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
|
||||
lora = loader.convert_state_dict(lora)
|
||||
loras.append(lora)
|
||||
lora = self.LoRA_State_Dicts_Decomposition(loras)
|
||||
return lora
|
||||
@@ -320,8 +320,6 @@ class FluxControlNetStateDictConverter:
|
||||
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
|
||||
elif hash_value == "7f9583eb8ba86642abb9a21a4b2c9e16":
|
||||
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 10}
|
||||
elif hash_value == "43ad5aaa27dd4ee01b832ed16773fa52":
|
||||
extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0}
|
||||
else:
|
||||
extra_kwargs = {}
|
||||
return state_dict_, extra_kwargs
|
||||
|
||||
@@ -2,7 +2,7 @@ import torch
|
||||
from .sd3_dit import TimestepEmbeddings, AdaLayerNorm, RMSNorm
|
||||
from einops import rearrange
|
||||
from .tiler import TileWorker
|
||||
from .utils import init_weights_on_device, hash_state_dict_keys
|
||||
from .utils import init_weights_on_device
|
||||
|
||||
def interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0):
|
||||
batch_size, num_tokens = hidden_states.shape[0:2]
|
||||
@@ -276,22 +276,20 @@ class AdaLayerNormContinuous(torch.nn.Module):
|
||||
|
||||
|
||||
class FluxDiT(torch.nn.Module):
|
||||
def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19):
|
||||
def __init__(self, disable_guidance_embedder=False):
|
||||
super().__init__()
|
||||
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
|
||||
self.time_embedder = TimestepEmbeddings(256, 3072)
|
||||
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
|
||||
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
|
||||
self.context_embedder = torch.nn.Linear(4096, 3072)
|
||||
self.x_embedder = torch.nn.Linear(input_dim, 3072)
|
||||
self.x_embedder = torch.nn.Linear(64, 3072)
|
||||
|
||||
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_blocks)])
|
||||
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(19)])
|
||||
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])
|
||||
|
||||
self.final_norm_out = AdaLayerNormContinuous(3072)
|
||||
self.final_proj_out = torch.nn.Linear(3072, 64)
|
||||
|
||||
self.input_dim = input_dim
|
||||
|
||||
|
||||
def patchify(self, hidden_states):
|
||||
@@ -375,7 +373,8 @@ class FluxDiT(torch.nn.Module):
|
||||
return attention_mask
|
||||
|
||||
|
||||
def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, repeat_dim):
|
||||
def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids):
|
||||
repeat_dim = hidden_states.shape[1]
|
||||
max_masks = 0
|
||||
attention_mask = None
|
||||
prompt_embs = [prompt_emb]
|
||||
@@ -661,9 +660,6 @@ class FluxDiTStateDictConverter:
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
if hash_state_dict_keys(state_dict, with_shape=True) in ["3e6c61b0f9471135fc9c6d6a98e98b6d", "63c969fd37cce769a90aa781fbff5f81"]:
|
||||
dit_state_dict = {key.replace("pipe.dit.", ""): value for key, value in state_dict.items() if key.startswith('pipe.dit.')}
|
||||
return dit_state_dict
|
||||
rename_dict = {
|
||||
"time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias",
|
||||
"time_in.in_layer.weight": "time_embedder.timestep_embedder.0.weight",
|
||||
@@ -742,7 +738,5 @@ class FluxDiTStateDictConverter:
|
||||
pass
|
||||
if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_:
|
||||
return state_dict_, {"disable_guidance_embedder": True}
|
||||
elif "blocks.8.attn.norm_k_a.weight" not in state_dict_:
|
||||
return state_dict_, {"input_dim": 196, "num_blocks": 8}
|
||||
else:
|
||||
return state_dict_
|
||||
|
||||
@@ -104,7 +104,6 @@ class InfiniteYouImageProjector(nn.Module):
|
||||
def forward(self, x):
|
||||
|
||||
latents = self.latents.repeat(x.size(0), 1, 1)
|
||||
latents = latents.to(dtype=x.dtype, device=x.device)
|
||||
|
||||
x = self.proj_in(x)
|
||||
|
||||
|
||||
@@ -1,111 +0,0 @@
|
||||
import torch
|
||||
from .sd_text_encoder import CLIPEncoderLayer
|
||||
|
||||
|
||||
class LoRALayerBlock(torch.nn.Module):
|
||||
def __init__(self, L, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.x = torch.nn.Parameter(torch.randn(1, L, dim_in))
|
||||
self.layer_norm = torch.nn.LayerNorm(dim_out)
|
||||
|
||||
def forward(self, lora_A, lora_B):
|
||||
x = self.x @ lora_A.T @ lora_B.T
|
||||
x = self.layer_norm(x)
|
||||
return x
|
||||
|
||||
|
||||
class LoRAEmbedder(torch.nn.Module):
|
||||
def __init__(self, lora_patterns=None, L=1, out_dim=2048):
|
||||
super().__init__()
|
||||
if lora_patterns is None:
|
||||
lora_patterns = self.default_lora_patterns()
|
||||
|
||||
model_dict = {}
|
||||
for lora_pattern in lora_patterns:
|
||||
name, dim = lora_pattern["name"], lora_pattern["dim"]
|
||||
model_dict[name.replace(".", "___")] = LoRALayerBlock(L, dim[0], dim[1])
|
||||
self.model_dict = torch.nn.ModuleDict(model_dict)
|
||||
|
||||
proj_dict = {}
|
||||
for lora_pattern in lora_patterns:
|
||||
layer_type, dim = lora_pattern["type"], lora_pattern["dim"]
|
||||
if layer_type not in proj_dict:
|
||||
proj_dict[layer_type.replace(".", "___")] = torch.nn.Linear(dim[1], out_dim)
|
||||
self.proj_dict = torch.nn.ModuleDict(proj_dict)
|
||||
|
||||
self.lora_patterns = lora_patterns
|
||||
|
||||
|
||||
def default_lora_patterns(self):
|
||||
lora_patterns = []
|
||||
lora_dict = {
|
||||
"attn.a_to_qkv": (3072, 9216), "attn.a_to_out": (3072, 3072), "ff_a.0": (3072, 12288), "ff_a.2": (12288, 3072), "norm1_a.linear": (3072, 18432),
|
||||
"attn.b_to_qkv": (3072, 9216), "attn.b_to_out": (3072, 3072), "ff_b.0": (3072, 12288), "ff_b.2": (12288, 3072), "norm1_b.linear": (3072, 18432),
|
||||
}
|
||||
for i in range(19):
|
||||
for suffix in lora_dict:
|
||||
lora_patterns.append({
|
||||
"name": f"blocks.{i}.{suffix}",
|
||||
"dim": lora_dict[suffix],
|
||||
"type": suffix,
|
||||
})
|
||||
lora_dict = {"to_qkv_mlp": (3072, 21504), "proj_out": (15360, 3072), "norm.linear": (3072, 9216)}
|
||||
for i in range(38):
|
||||
for suffix in lora_dict:
|
||||
lora_patterns.append({
|
||||
"name": f"single_blocks.{i}.{suffix}",
|
||||
"dim": lora_dict[suffix],
|
||||
"type": suffix,
|
||||
})
|
||||
return lora_patterns
|
||||
|
||||
def forward(self, lora):
|
||||
lora_emb = []
|
||||
for lora_pattern in self.lora_patterns:
|
||||
name, layer_type = lora_pattern["name"], lora_pattern["type"]
|
||||
lora_A = lora[name + ".lora_A.default.weight"]
|
||||
lora_B = lora[name + ".lora_B.default.weight"]
|
||||
lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B)
|
||||
lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out)
|
||||
lora_emb.append(lora_out)
|
||||
lora_emb = torch.concat(lora_emb, dim=1)
|
||||
return lora_emb
|
||||
|
||||
|
||||
class FluxLoRAEncoder(torch.nn.Module):
|
||||
def __init__(self, embed_dim=4096, encoder_intermediate_size=8192, num_encoder_layers=1, num_embeds_per_lora=16, num_special_embeds=1):
|
||||
super().__init__()
|
||||
self.num_embeds_per_lora = num_embeds_per_lora
|
||||
# embedder
|
||||
self.embedder = LoRAEmbedder(L=num_embeds_per_lora, out_dim=embed_dim)
|
||||
|
||||
# encoders
|
||||
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=32, head_dim=128) for _ in range(num_encoder_layers)])
|
||||
|
||||
# special embedding
|
||||
self.special_embeds = torch.nn.Parameter(torch.randn(1, num_special_embeds, embed_dim))
|
||||
self.num_special_embeds = num_special_embeds
|
||||
|
||||
# final layer
|
||||
self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
|
||||
self.final_linear = torch.nn.Linear(embed_dim, embed_dim)
|
||||
|
||||
def forward(self, lora):
|
||||
lora_embeds = self.embedder(lora)
|
||||
special_embeds = self.special_embeds.to(dtype=lora_embeds.dtype, device=lora_embeds.device)
|
||||
embeds = torch.concat([special_embeds, lora_embeds], dim=1)
|
||||
for encoder_id, encoder in enumerate(self.encoders):
|
||||
embeds = encoder(embeds)
|
||||
embeds = embeds[:, :self.num_special_embeds]
|
||||
embeds = self.final_layer_norm(embeds)
|
||||
embeds = self.final_linear(embeds)
|
||||
return embeds
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FluxLoRAEncoderStateDictConverter()
|
||||
|
||||
|
||||
class FluxLoRAEncoderStateDictConverter:
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict
|
||||
@@ -1,60 +0,0 @@
|
||||
import torch
|
||||
from diffsynth.models.svd_unet import TemporalTimesteps
|
||||
|
||||
|
||||
class MultiValueEncoder(torch.nn.Module):
|
||||
def __init__(self, encoders=()):
|
||||
super().__init__()
|
||||
self.encoders = torch.nn.ModuleList(encoders)
|
||||
|
||||
def __call__(self, values, dtype):
|
||||
emb = []
|
||||
for encoder, value in zip(self.encoders, values):
|
||||
if value is not None:
|
||||
value = value.unsqueeze(0)
|
||||
emb.append(encoder(value, dtype))
|
||||
emb = torch.concat(emb, dim=0)
|
||||
return emb
|
||||
|
||||
|
||||
class SingleValueEncoder(torch.nn.Module):
|
||||
def __init__(self, dim_in=256, dim_out=4096, prefer_len=32, computation_device=None):
|
||||
super().__init__()
|
||||
self.prefer_len = prefer_len
|
||||
self.prefer_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device)
|
||||
self.prefer_value_embedder = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
|
||||
)
|
||||
self.positional_embedding = torch.nn.Parameter(
|
||||
torch.randn(self.prefer_len, dim_out)
|
||||
)
|
||||
self._initialize_weights()
|
||||
|
||||
def _initialize_weights(self):
|
||||
last_linear = self.prefer_value_embedder[-1]
|
||||
torch.nn.init.zeros_(last_linear.weight)
|
||||
torch.nn.init.zeros_(last_linear.bias)
|
||||
|
||||
def forward(self, value, dtype):
|
||||
value = value * 1000
|
||||
emb = self.prefer_proj(value).to(dtype)
|
||||
emb = self.prefer_value_embedder(emb).squeeze(0)
|
||||
base_embeddings = emb.expand(self.prefer_len, -1)
|
||||
positional_embedding = self.positional_embedding.to(dtype=base_embeddings.dtype, device=base_embeddings.device)
|
||||
learned_embeddings = base_embeddings + positional_embedding
|
||||
return learned_embeddings
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SingleValueEncoderStateDictConverter()
|
||||
|
||||
|
||||
class SingleValueEncoderStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict
|
||||
@@ -1373,7 +1373,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||
elif generation_config.max_new_tokens is not None:
|
||||
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
||||
if not has_default_max_length:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
||||
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
|
||||
"Please refer to the documentation for more information. "
|
||||
|
||||
@@ -277,7 +277,7 @@ class FluxLoRAConverter:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def align_to_opensource_format(state_dict, alpha=None):
|
||||
def align_to_opensource_format(state_dict, alpha=1.0):
|
||||
prefix_rename_dict = {
|
||||
"single_blocks": "lora_unet_single_blocks",
|
||||
"blocks": "lora_unet_double_blocks",
|
||||
@@ -316,8 +316,7 @@ class FluxLoRAConverter:
|
||||
rename = prefix_rename_dict[prefix] + "_" + block_id + "_" + middle_rename_dict[middle] + "." + suffix_rename_dict[suffix]
|
||||
state_dict_[rename] = param
|
||||
if rename.endswith("lora_up.weight"):
|
||||
lora_alpha = alpha if alpha is not None else param.shape[-1]
|
||||
state_dict_[rename.replace("lora_up.weight", "alpha")] = torch.tensor((lora_alpha,))[0]
|
||||
state_dict_[rename.replace("lora_up.weight", "alpha")] = torch.tensor((alpha,))[0]
|
||||
return state_dict_
|
||||
|
||||
@staticmethod
|
||||
@@ -366,37 +365,7 @@ class FluxLoRAConverter:
|
||||
else:
|
||||
state_dict_[name] = param
|
||||
return state_dict_
|
||||
|
||||
|
||||
class WanLoRAConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def align_to_opensource_format(state_dict, **kwargs):
|
||||
state_dict = {"diffusion_model." + name.replace(".default.", "."): param for name, param in state_dict.items()}
|
||||
return state_dict
|
||||
|
||||
@staticmethod
|
||||
def align_to_diffsynth_format(state_dict, **kwargs):
|
||||
state_dict = {name.replace("diffusion_model.", "").replace(".lora_A.weight", ".lora_A.default.weight").replace(".lora_B.weight", ".lora_B.default.weight"): param for name, param in state_dict.items()}
|
||||
return state_dict
|
||||
|
||||
|
||||
class QwenImageLoRAConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def align_to_opensource_format(state_dict, **kwargs):
|
||||
state_dict = {name.replace(".default.", "."): param for name, param in state_dict.items()}
|
||||
return state_dict
|
||||
|
||||
@staticmethod
|
||||
def align_to_diffsynth_format(state_dict, **kwargs):
|
||||
state_dict = {name.replace(".lora_A.weight", ".lora_A.default.weight").replace(".lora_B.weight", ".lora_B.default.weight"): param for name, param in state_dict.items()}
|
||||
return state_dict
|
||||
|
||||
|
||||
def get_lora_loaders():
|
||||
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()]
|
||||
|
||||
@@ -426,7 +426,7 @@ class ModelManager:
|
||||
self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)
|
||||
|
||||
|
||||
def fetch_model(self, model_name, file_path=None, require_model_path=False, index=None):
|
||||
def fetch_model(self, model_name, file_path=None, require_model_path=False):
|
||||
fetched_models = []
|
||||
fetched_model_paths = []
|
||||
for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
|
||||
@@ -440,25 +440,12 @@ class ModelManager:
|
||||
return None
|
||||
if len(fetched_models) == 1:
|
||||
print(f"Using {model_name} from {fetched_model_paths[0]}.")
|
||||
model = fetched_models[0]
|
||||
path = fetched_model_paths[0]
|
||||
else:
|
||||
if index is None:
|
||||
model = fetched_models[0]
|
||||
path = fetched_model_paths[0]
|
||||
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
|
||||
elif isinstance(index, int):
|
||||
model = fetched_models[:index]
|
||||
path = fetched_model_paths[:index]
|
||||
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[:index]}.")
|
||||
else:
|
||||
model = fetched_models
|
||||
path = fetched_model_paths
|
||||
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths}.")
|
||||
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
|
||||
if require_model_path:
|
||||
return model, path
|
||||
return fetched_models[0], fetched_model_paths[0]
|
||||
else:
|
||||
return model
|
||||
return fetched_models[0]
|
||||
|
||||
|
||||
def to(self, device):
|
||||
|
||||
@@ -1,161 +0,0 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class NexusGenAutoregressiveModel(torch.nn.Module):
|
||||
def __init__(self, max_length=1024, max_pixels=262640):
|
||||
super(NexusGenAutoregressiveModel, self).__init__()
|
||||
from .nexus_gen_ar_model import Qwen2_5_VLForConditionalGeneration
|
||||
from transformers import Qwen2_5_VLConfig
|
||||
self.max_length = max_length
|
||||
self.max_pixels = max_pixels
|
||||
model_config = Qwen2_5_VLConfig(**{
|
||||
"_name_or_path": "DiffSynth-Studio/Nexus-GenV2",
|
||||
"architectures": [
|
||||
"Qwen2_5_VLForConditionalGeneration"
|
||||
],
|
||||
"attention_dropout": 0.0,
|
||||
"auto_map": {
|
||||
"AutoConfig": "configuration_qwen2_5_vl.Qwen2_5_VLConfig",
|
||||
"AutoModel": "modeling_qwen2_5_vl.Qwen2_5_VLModel",
|
||||
"AutoModelForCausalLM": "modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration"
|
||||
},
|
||||
"bos_token_id": 151643,
|
||||
"eos_token_id": 151645,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 3584,
|
||||
"image_token_id": 151655,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 18944,
|
||||
"max_position_embeddings": 128000,
|
||||
"max_window_layers": 28,
|
||||
"model_type": "qwen2_5_vl",
|
||||
"num_attention_heads": 28,
|
||||
"num_hidden_layers": 28,
|
||||
"num_key_value_heads": 4,
|
||||
"pad_token_id": 151643,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_scaling": {
|
||||
"mrope_section": [
|
||||
16,
|
||||
24,
|
||||
24
|
||||
],
|
||||
"rope_type": "default",
|
||||
"type": "default"
|
||||
},
|
||||
"rope_theta": 1000000.0,
|
||||
"sliding_window": 32768,
|
||||
"tie_word_embeddings": False,
|
||||
"torch_dtype": "bfloat16",
|
||||
"transformers_version": "4.49.0",
|
||||
"use_cache": False,
|
||||
"use_sliding_window": False,
|
||||
"video_token_id": 151656,
|
||||
"vision_config": {
|
||||
"hidden_size": 1280,
|
||||
"in_chans": 3,
|
||||
"model_type": "qwen2_5_vl",
|
||||
"spatial_patch_size": 14,
|
||||
"tokens_per_second": 2,
|
||||
"torch_dtype": "bfloat16"
|
||||
},
|
||||
"vision_end_token_id": 151653,
|
||||
"vision_start_token_id": 151652,
|
||||
"vision_token_id": 151654,
|
||||
"vocab_size": 152064
|
||||
})
|
||||
self.model = Qwen2_5_VLForConditionalGeneration(model_config)
|
||||
self.processor = None
|
||||
|
||||
|
||||
def load_processor(self, path):
|
||||
from .nexus_gen_ar_model import Qwen2_5_VLProcessor
|
||||
self.processor = Qwen2_5_VLProcessor.from_pretrained(path)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return NexusGenAutoregressiveModelStateDictConverter()
|
||||
|
||||
def bound_image(self, image, max_pixels=262640):
|
||||
from qwen_vl_utils import smart_resize
|
||||
resized_height, resized_width = smart_resize(
|
||||
image.height,
|
||||
image.width,
|
||||
max_pixels=max_pixels,
|
||||
)
|
||||
return image.resize((resized_width, resized_height))
|
||||
|
||||
def get_editing_msg(self, instruction):
|
||||
if '<image>' not in instruction:
|
||||
instruction = '<image> ' + instruction
|
||||
messages = [{"role":"user", "content":instruction}, {"role":"assistant", "content":"Here is the image: <image>"}]
|
||||
return messages
|
||||
|
||||
def get_generation_msg(self, instruction):
|
||||
instruction = "Generate an image according to the following description: {}".format(instruction)
|
||||
messages = [{"role":"user", "content":instruction}, {"role":"assistant", "content":"Here is an image based on the description: <image>"}]
|
||||
return messages
|
||||
|
||||
def forward(self, instruction, ref_image=None, num_img_tokens=81):
|
||||
"""
|
||||
Generate target embeddings for the given instruction and reference image.
|
||||
"""
|
||||
if ref_image is not None:
|
||||
messages = self.get_editing_msg(instruction)
|
||||
images = [self.bound_image(ref_image)] + [Image.new(mode='RGB', size=(252, 252), color=(255, 255, 255))]
|
||||
output_image_embeddings = self.get_target_embeddings(images, messages, self.processor, self.model, num_img_tokens)
|
||||
else:
|
||||
messages = self.get_generation_msg(instruction)
|
||||
images = [Image.new(mode='RGB', size=(252, 252), color=(255, 255, 255))]
|
||||
output_image_embeddings = self.get_target_embeddings(images, messages, self.processor, self.model, num_img_tokens)
|
||||
|
||||
return output_image_embeddings
|
||||
|
||||
def get_target_embeddings(self, images, messages, processor, model, num_img_tokens=81):
|
||||
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
|
||||
text = text.replace('<image>', '<|vision_start|><|image_pad|><|vision_end|>')
|
||||
inputs = processor(
|
||||
text=[text],
|
||||
images=images,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = inputs.to(model.device)
|
||||
|
||||
input_embeds = model.model.embed_tokens(inputs['input_ids'])
|
||||
image_embeds = model.visual(inputs['pixel_values'], grid_thw=inputs['image_grid_thw'])
|
||||
ground_truth_image_embeds = image_embeds[-num_img_tokens:]
|
||||
input_image_embeds = image_embeds[:-num_img_tokens]
|
||||
|
||||
image_mask = inputs['input_ids'] == model.config.image_token_id
|
||||
indices = image_mask.cumsum(dim=1)
|
||||
input_image_mask = torch.logical_and(indices <= (image_embeds.shape[0] - ground_truth_image_embeds.shape[0]), image_mask)
|
||||
gt_image_mask = torch.logical_and(image_mask, ~input_image_mask)
|
||||
input_image_mask = input_image_mask.unsqueeze(-1).expand_as(input_embeds)
|
||||
input_embeds = input_embeds.masked_scatter(input_image_mask, input_image_embeds)
|
||||
|
||||
image_prefill_embeds = model.image_prefill_embeds(
|
||||
torch.arange(81, device=model.device).long()
|
||||
)
|
||||
input_embeds = input_embeds.masked_scatter(gt_image_mask.unsqueeze(-1).expand_as(input_embeds), image_prefill_embeds)
|
||||
|
||||
position_ids, _ = model.get_rope_index(
|
||||
inputs['input_ids'],
|
||||
inputs['image_grid_thw'],
|
||||
attention_mask=inputs['attention_mask'])
|
||||
position_ids = position_ids.contiguous()
|
||||
outputs = model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=inputs['attention_mask'], return_dict=True)
|
||||
output_image_embeddings = outputs.image_embeddings[:, :-1, :]
|
||||
output_image_embeddings = output_image_embeddings[gt_image_mask[:, 1:]]
|
||||
return output_image_embeddings, input_image_embeds, inputs['image_grid_thw']
|
||||
|
||||
|
||||
class NexusGenAutoregressiveModelStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
state_dict = {"model." + key: value for key, value in state_dict.items()}
|
||||
return state_dict
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,417 +0,0 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
|
||||
mrope_section = mrope_section * 2
|
||||
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
|
||||
unsqueeze_dim
|
||||
)
|
||||
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
|
||||
unsqueeze_dim
|
||||
)
|
||||
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
class Qwen2_5_VLRotaryEmbedding(nn.Module):
|
||||
def __init__(self, config, device=None):
|
||||
super().__init__()
|
||||
# BC: "rope_type" was originally "type"
|
||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.config = config
|
||||
from transformers.modeling_rope_utils import _compute_default_rope_parameters
|
||||
self.rope_init_fn = _compute_default_rope_parameters
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
|
||||
def _dynamic_frequency_update(self, position_ids, device):
|
||||
"""
|
||||
dynamic RoPE layers should recompute `inv_freq` in the following situations:
|
||||
1 - growing beyond the cached sequence length (allow scaling)
|
||||
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
|
||||
"""
|
||||
seq_len = torch.max(position_ids) + 1
|
||||
if seq_len > self.max_seq_len_cached: # growth
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
||||
)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||
self.max_seq_len_cached = seq_len
|
||||
|
||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||
self.max_seq_len_cached = self.original_max_seq_len
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x, position_ids):
|
||||
if "dynamic" in self.rope_type:
|
||||
self._dynamic_frequency_update(position_ids, device=x.device)
|
||||
|
||||
# Core RoPE block. In contrast to other models, Qwen2_5_VL has different position ids for the grids
|
||||
# So we expand the inv_freq to shape (3, ...)
|
||||
inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
|
||||
position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
|
||||
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
||||
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
|
||||
cos = cos * self.attention_scaling
|
||||
sin = sin * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
class Qwen2_5_VLAttention(nn.Module):
|
||||
def __init__(self, config, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
self.is_causal = True
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.rope_scaling = config.rope_scaling
|
||||
|
||||
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||
raise ValueError(
|
||||
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
||||
f" and `num_heads`: {self.num_heads})."
|
||||
)
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_multimodal_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
|
||||
)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
# Fix precision issues in Qwen2-VL float16 inference
|
||||
# Replace inf values with zeros in attention weights to prevent NaN propagation
|
||||
if query_states.dtype == torch.float16:
|
||||
attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights)
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
class Qwen2MLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
from transformers.activations import ACT2FN
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, x):
|
||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
|
||||
|
||||
class Qwen2RMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
Qwen2RMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||
|
||||
|
||||
class Qwen2_5_VLDecoderLayer(nn.Module):
|
||||
def __init__(self, config, layer_idx):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = Qwen2_5_VLAttention(config, layer_idx)
|
||||
|
||||
self.mlp = Qwen2MLP(config)
|
||||
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class NexusGenImageEmbeddingMerger(nn.Module):
|
||||
def __init__(self, num_layers=1, out_channel=4096, expand_ratio=4, device='cpu'):
|
||||
super().__init__()
|
||||
from transformers import Qwen2_5_VLConfig
|
||||
from transformers.activations import ACT2FN
|
||||
config = Qwen2_5_VLConfig(**{
|
||||
"_name_or_path": "DiffSynth-Studio/Nexus-GenV2",
|
||||
"architectures": [
|
||||
"Qwen2_5_VLForConditionalGeneration"
|
||||
],
|
||||
"attention_dropout": 0.0,
|
||||
"auto_map": {
|
||||
"AutoConfig": "configuration_qwen2_5_vl.Qwen2_5_VLConfig",
|
||||
"AutoModel": "modeling_qwen2_5_vl.Qwen2_5_VLModel",
|
||||
"AutoModelForCausalLM": "modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration"
|
||||
},
|
||||
"bos_token_id": 151643,
|
||||
"eos_token_id": 151645,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 3584,
|
||||
"image_token_id": 151655,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 18944,
|
||||
"max_position_embeddings": 128000,
|
||||
"max_window_layers": 28,
|
||||
"model_type": "qwen2_5_vl",
|
||||
"num_attention_heads": 28,
|
||||
"num_hidden_layers": 28,
|
||||
"num_key_value_heads": 4,
|
||||
"pad_token_id": 151643,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_scaling": {
|
||||
"mrope_section": [
|
||||
16,
|
||||
24,
|
||||
24
|
||||
],
|
||||
"rope_type": "default",
|
||||
"type": "default"
|
||||
},
|
||||
"rope_theta": 1000000.0,
|
||||
"sliding_window": 32768,
|
||||
"tie_word_embeddings": False,
|
||||
"torch_dtype": "bfloat16",
|
||||
"transformers_version": "4.49.0",
|
||||
"use_cache": False,
|
||||
"use_sliding_window": False,
|
||||
"video_token_id": 151656,
|
||||
"vision_config": {
|
||||
"hidden_size": 1280,
|
||||
"in_chans": 3,
|
||||
"model_type": "qwen2_5_vl",
|
||||
"spatial_patch_size": 14,
|
||||
"tokens_per_second": 2,
|
||||
"torch_dtype": "bfloat16"
|
||||
},
|
||||
"vision_end_token_id": 151653,
|
||||
"vision_start_token_id": 151652,
|
||||
"vision_token_id": 151654,
|
||||
"vocab_size": 152064
|
||||
})
|
||||
self.config = config
|
||||
self.num_layers = num_layers
|
||||
self.layers = nn.ModuleList([Qwen2_5_VLDecoderLayer(config, layer_idx) for layer_idx in range(num_layers)])
|
||||
self.projector = nn.Sequential(Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps),
|
||||
nn.Linear(config.hidden_size, out_channel * expand_ratio),
|
||||
Qwen2RMSNorm(out_channel * expand_ratio, eps=config.rms_norm_eps),
|
||||
ACT2FN[config.hidden_act], nn.Linear(out_channel * expand_ratio, out_channel),
|
||||
Qwen2RMSNorm(out_channel, eps=config.rms_norm_eps))
|
||||
self.base_grid = torch.tensor([[1, 72, 72]], device=device)
|
||||
self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config, device=device)
|
||||
|
||||
def get_position_ids(self, image_grid_thw):
|
||||
"""
|
||||
Generates position ids for the input embeddings grid.
|
||||
modified from the qwen2_vl mrope.
|
||||
"""
|
||||
batch_size = image_grid_thw.shape[0]
|
||||
spatial_merge_size = self.config.vision_config.spatial_merge_size
|
||||
t, h, w = (
|
||||
image_grid_thw[0][0],
|
||||
image_grid_thw[0][1],
|
||||
image_grid_thw[0][2],
|
||||
)
|
||||
llm_grid_t, llm_grid_h, llm_grid_w = (
|
||||
t.item(),
|
||||
h.item() // spatial_merge_size,
|
||||
w.item() // spatial_merge_size,
|
||||
)
|
||||
scale_h = self.base_grid[0][1].item() / h.item()
|
||||
scale_w = self.base_grid[0][2].item() / w.item()
|
||||
|
||||
range_tensor = torch.arange(llm_grid_t).view(-1, 1)
|
||||
expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)
|
||||
time_tensor = expanded_range * self.config.vision_config.tokens_per_second
|
||||
t_index = time_tensor.long().flatten().to(image_grid_thw.device)
|
||||
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten().to(image_grid_thw.device) * scale_h
|
||||
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten().to(image_grid_thw.device) * scale_w
|
||||
# 3, B, L
|
||||
position_ids = torch.stack([t_index, h_index, w_index]).unsqueeze(0).repeat(batch_size, 1, 1).permute(1, 0, 2)
|
||||
return position_ids
|
||||
|
||||
def forward(self, embeds, embeds_grid, ref_embeds=None, ref_embeds_grid=None):
|
||||
position_ids = self.get_position_ids(embeds_grid)
|
||||
hidden_states = embeds
|
||||
if ref_embeds is not None:
|
||||
position_ids_ref_embeds = self.get_position_ids(ref_embeds_grid)
|
||||
position_ids = torch.cat((position_ids, position_ids_ref_embeds), dim=-1)
|
||||
hidden_states = torch.cat((embeds, ref_embeds), dim=1)
|
||||
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
for layer in self.layers:
|
||||
hidden_states = layer(hidden_states, position_embeddings)
|
||||
|
||||
hidden_states = self.projector(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return NexusGenMergerStateDictConverter()
|
||||
|
||||
|
||||
class NexusGenMergerStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
merger_state_dict = {key.replace("embedding_merger.", ""): value for key, value in state_dict.items() if key.startswith('embedding_merger.')}
|
||||
return merger_state_dict
|
||||
|
||||
|
||||
class NexusGenAdapter(nn.Module):
|
||||
"""
|
||||
Adapter for Nexus-Gen generation decoder.
|
||||
"""
|
||||
def __init__(self, input_dim=3584, output_dim=4096):
|
||||
super(NexusGenAdapter, self).__init__()
|
||||
self.adapter = nn.Sequential(nn.Linear(input_dim, output_dim),
|
||||
nn.LayerNorm(output_dim), nn.ReLU(),
|
||||
nn.Linear(output_dim, output_dim),
|
||||
nn.LayerNorm(output_dim))
|
||||
|
||||
def forward(self, x):
|
||||
return self.adapter(x)
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return NexusGenAdapterStateDictConverter()
|
||||
|
||||
|
||||
class NexusGenAdapterStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
adapter_state_dict = {key: value for key, value in state_dict.items() if key.startswith('adapter.')}
|
||||
return adapter_state_dict
|
||||
@@ -1,63 +0,0 @@
|
||||
from .qwen_image_dit import QwenImageTransformerBlock, AdaLayerNorm, TimestepEmbeddings
|
||||
from einops import rearrange
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
class QwenImageAccelerateAdapter(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_layers: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
self.proj_latents_in = torch.nn.Linear(64, 3072)
|
||||
self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=True)
|
||||
self.transformer_blocks = torch.nn.ModuleList(
|
||||
[
|
||||
QwenImageTransformerBlock(
|
||||
dim=3072,
|
||||
num_attention_heads=24,
|
||||
attention_head_dim=128,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.norm_out = AdaLayerNorm(3072, single=True)
|
||||
self.proj_out = torch.nn.Linear(3072, 64)
|
||||
self.proj_latents_out = torch.nn.Linear(64, 64)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
latents=None,
|
||||
image=None,
|
||||
text=None,
|
||||
image_rotary_emb=None,
|
||||
timestep=None,
|
||||
):
|
||||
latents = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
|
||||
image = image + self.proj_latents_in(latents)
|
||||
conditioning = self.time_text_embed(timestep, image.dtype)
|
||||
for block in self.transformer_blocks:
|
||||
text, image = block(
|
||||
image=image,
|
||||
text=text,
|
||||
temb=conditioning,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
image = self.norm_out(image, conditioning)
|
||||
image = self.proj_out(image)
|
||||
image = image + self.proj_latents_out(latents)
|
||||
return image
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return QwenImageAccelerateAdapterStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class QwenImageAccelerateAdapterStateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict
|
||||
@@ -1,458 +0,0 @@
|
||||
import torch, math
|
||||
import torch.nn as nn
|
||||
from typing import Tuple, Optional, Union, List
|
||||
from einops import rearrange
|
||||
from .sd3_dit import TimestepEmbeddings, RMSNorm
|
||||
from .flux_dit import AdaLayerNorm
|
||||
|
||||
try:
|
||||
import flash_attn_interface
|
||||
FLASH_ATTN_3_AVAILABLE = True
|
||||
except ModuleNotFoundError:
|
||||
FLASH_ATTN_3_AVAILABLE = False
|
||||
|
||||
|
||||
def qwen_image_flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, attention_mask = None, enable_fp8_attention: bool = False):
|
||||
if FLASH_ATTN_3_AVAILABLE and attention_mask is None:
|
||||
if not enable_fp8_attention:
|
||||
q = rearrange(q, "b n s d -> b s n d", n=num_heads)
|
||||
k = rearrange(k, "b n s d -> b s n d", n=num_heads)
|
||||
v = rearrange(v, "b n s d -> b s n d", n=num_heads)
|
||||
x = flash_attn_interface.flash_attn_func(q, k, v)
|
||||
if isinstance(x, tuple):
|
||||
x = x[0]
|
||||
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
|
||||
else:
|
||||
origin_dtype = q.dtype
|
||||
q_std, k_std, v_std = q.std(), k.std(), v.std()
|
||||
q, k, v = (q / q_std).to(torch.float8_e4m3fn), (k / k_std).to(torch.float8_e4m3fn), (v / v_std).to(torch.float8_e4m3fn)
|
||||
q = rearrange(q, "b n s d -> b s n d", n=num_heads)
|
||||
k = rearrange(k, "b n s d -> b s n d", n=num_heads)
|
||||
v = rearrange(v, "b n s d -> b s n d", n=num_heads)
|
||||
x = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=q_std * k_std / math.sqrt(q.size(-1)))
|
||||
if isinstance(x, tuple):
|
||||
x = x[0]
|
||||
x = x.to(origin_dtype) * v_std
|
||||
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
|
||||
else:
|
||||
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
|
||||
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
||||
return x
|
||||
|
||||
|
||||
class ApproximateGELU(nn.Module):
|
||||
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out, bias=bias)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.proj(x)
|
||||
return x * torch.sigmoid(1.702 * x)
|
||||
|
||||
def apply_rotary_emb_qwen(
|
||||
x: torch.Tensor,
|
||||
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]
|
||||
):
|
||||
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
||||
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
|
||||
return x_out.type_as(x)
|
||||
|
||||
|
||||
class QwenEmbedRope(nn.Module):
|
||||
def __init__(self, theta: int, axes_dim: list[int], scale_rope=False):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
pos_index = torch.arange(1024)
|
||||
neg_index = torch.arange(1024).flip(0) * -1 - 1
|
||||
self.pos_freqs = torch.cat([
|
||||
self.rope_params(pos_index, self.axes_dim[0], self.theta),
|
||||
self.rope_params(pos_index, self.axes_dim[1], self.theta),
|
||||
self.rope_params(pos_index, self.axes_dim[2], self.theta),
|
||||
], dim=1)
|
||||
self.neg_freqs = torch.cat([
|
||||
self.rope_params(neg_index, self.axes_dim[0], self.theta),
|
||||
self.rope_params(neg_index, self.axes_dim[1], self.theta),
|
||||
self.rope_params(neg_index, self.axes_dim[2], self.theta),
|
||||
], dim=1)
|
||||
self.rope_cache = {}
|
||||
self.scale_rope = scale_rope
|
||||
|
||||
def rope_params(self, index, dim, theta=10000):
|
||||
"""
|
||||
Args:
|
||||
index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
|
||||
"""
|
||||
assert dim % 2 == 0
|
||||
freqs = torch.outer(
|
||||
index,
|
||||
1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))
|
||||
)
|
||||
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs
|
||||
|
||||
def forward(self, video_fhw, txt_seq_lens, device):
|
||||
if self.pos_freqs.device != device:
|
||||
self.pos_freqs = self.pos_freqs.to(device)
|
||||
self.neg_freqs = self.neg_freqs.to(device)
|
||||
|
||||
if isinstance(video_fhw, list):
|
||||
video_fhw = video_fhw[0]
|
||||
frame, height, width = video_fhw
|
||||
rope_key = f"{frame}_{height}_{width}"
|
||||
|
||||
if rope_key not in self.rope_cache:
|
||||
seq_lens = frame * height * width
|
||||
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||
if self.scale_rope:
|
||||
freqs_height = torch.cat(
|
||||
[
|
||||
freqs_neg[1][-(height - height//2):],
|
||||
freqs_pos[1][:height//2]
|
||||
],
|
||||
dim=0
|
||||
)
|
||||
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = torch.cat(
|
||||
[
|
||||
freqs_neg[2][-(width - width//2):],
|
||||
freqs_pos[2][:width//2]
|
||||
],
|
||||
dim=0
|
||||
)
|
||||
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
else:
|
||||
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
||||
self.rope_cache[rope_key] = freqs.clone().contiguous()
|
||||
vid_freqs = self.rope_cache[rope_key]
|
||||
|
||||
if self.scale_rope:
|
||||
max_vid_index = max(height // 2, width // 2)
|
||||
else:
|
||||
max_vid_index = max(height, width)
|
||||
|
||||
max_len = max(txt_seq_lens)
|
||||
txt_freqs = self.pos_freqs[max_vid_index: max_vid_index + max_len, ...]
|
||||
return vid_freqs, txt_freqs
|
||||
|
||||
|
||||
class QwenFeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
dim_out: Optional[int] = None,
|
||||
dropout: float = 0.0,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * 4)
|
||||
self.net = nn.ModuleList([])
|
||||
self.net.append(ApproximateGELU(dim, inner_dim))
|
||||
self.net.append(nn.Dropout(dropout))
|
||||
self.net.append(nn.Linear(inner_dim, dim_out))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
||||
for module in self.net:
|
||||
hidden_states = module(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
class QwenDoubleStreamAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim_a,
|
||||
dim_b,
|
||||
num_heads,
|
||||
head_dim,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
|
||||
self.to_q = nn.Linear(dim_a, dim_a)
|
||||
self.to_k = nn.Linear(dim_a, dim_a)
|
||||
self.to_v = nn.Linear(dim_a, dim_a)
|
||||
self.norm_q = RMSNorm(head_dim, eps=1e-6)
|
||||
self.norm_k = RMSNorm(head_dim, eps=1e-6)
|
||||
|
||||
self.add_q_proj = nn.Linear(dim_b, dim_b)
|
||||
self.add_k_proj = nn.Linear(dim_b, dim_b)
|
||||
self.add_v_proj = nn.Linear(dim_b, dim_b)
|
||||
self.norm_added_q = RMSNorm(head_dim, eps=1e-6)
|
||||
self.norm_added_k = RMSNorm(head_dim, eps=1e-6)
|
||||
|
||||
self.to_out = torch.nn.Sequential(nn.Linear(dim_a, dim_a))
|
||||
self.to_add_out = nn.Linear(dim_b, dim_b)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image: torch.FloatTensor,
|
||||
text: torch.FloatTensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
enable_fp8_attention: bool = False,
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||
img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image)
|
||||
txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text)
|
||||
seq_txt = txt_q.shape[1]
|
||||
|
||||
img_q = rearrange(img_q, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||
img_k = rearrange(img_k, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||
img_v = rearrange(img_v, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||
|
||||
txt_q = rearrange(txt_q, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||
txt_k = rearrange(txt_k, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||
txt_v = rearrange(txt_v, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||
|
||||
img_q, img_k = self.norm_q(img_q), self.norm_k(img_k)
|
||||
txt_q, txt_k = self.norm_added_q(txt_q), self.norm_added_k(txt_k)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
img_freqs, txt_freqs = image_rotary_emb
|
||||
img_q = apply_rotary_emb_qwen(img_q, img_freqs)
|
||||
img_k = apply_rotary_emb_qwen(img_k, img_freqs)
|
||||
txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs)
|
||||
txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs)
|
||||
|
||||
joint_q = torch.cat([txt_q, img_q], dim=2)
|
||||
joint_k = torch.cat([txt_k, img_k], dim=2)
|
||||
joint_v = torch.cat([txt_v, img_v], dim=2)
|
||||
|
||||
joint_attn_out = qwen_image_flash_attention(joint_q, joint_k, joint_v, num_heads=joint_q.shape[1], attention_mask=attention_mask, enable_fp8_attention=enable_fp8_attention).to(joint_q.dtype)
|
||||
|
||||
txt_attn_output = joint_attn_out[:, :seq_txt, :]
|
||||
img_attn_output = joint_attn_out[:, seq_txt:, :]
|
||||
|
||||
img_attn_output = self.to_out(img_attn_output)
|
||||
txt_attn_output = self.to_add_out(txt_attn_output)
|
||||
|
||||
return img_attn_output, txt_attn_output
|
||||
|
||||
|
||||
class QwenImageTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
|
||||
self.img_mod = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, 6 * dim),
|
||||
)
|
||||
self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
||||
self.attn = QwenDoubleStreamAttention(
|
||||
dim_a=dim,
|
||||
dim_b=dim,
|
||||
num_heads=num_attention_heads,
|
||||
head_dim=attention_head_dim,
|
||||
)
|
||||
self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
||||
self.img_mlp = QwenFeedForward(dim=dim, dim_out=dim)
|
||||
|
||||
self.txt_mod = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, 6 * dim, bias=True),
|
||||
)
|
||||
self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
||||
self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
||||
self.txt_mlp = QwenFeedForward(dim=dim, dim_out=dim)
|
||||
|
||||
def _modulate(self, x, mod_params):
|
||||
shift, scale, gate = mod_params.chunk(3, dim=-1)
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
text: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
enable_fp8_attention = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
|
||||
txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
|
||||
|
||||
img_normed = self.img_norm1(image)
|
||||
img_modulated, img_gate = self._modulate(img_normed, img_mod_attn)
|
||||
|
||||
txt_normed = self.txt_norm1(text)
|
||||
txt_modulated, txt_gate = self._modulate(txt_normed, txt_mod_attn)
|
||||
|
||||
img_attn_out, txt_attn_out = self.attn(
|
||||
image=img_modulated,
|
||||
text=txt_modulated,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_mask=attention_mask,
|
||||
enable_fp8_attention=enable_fp8_attention,
|
||||
)
|
||||
|
||||
image = image + img_gate * img_attn_out
|
||||
text = text + txt_gate * txt_attn_out
|
||||
|
||||
img_normed_2 = self.img_norm2(image)
|
||||
img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp)
|
||||
|
||||
txt_normed_2 = self.txt_norm2(text)
|
||||
txt_modulated_2, txt_gate_2 = self._modulate(txt_normed_2, txt_mod_mlp)
|
||||
|
||||
img_mlp_out = self.img_mlp(img_modulated_2)
|
||||
txt_mlp_out = self.txt_mlp(txt_modulated_2)
|
||||
|
||||
image = image + img_gate_2 * img_mlp_out
|
||||
text = text + txt_gate_2 * txt_mlp_out
|
||||
|
||||
return text, image
|
||||
|
||||
|
||||
class QwenImageDiT(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_layers: int = 60,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)
|
||||
|
||||
self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=True)
|
||||
self.txt_norm = RMSNorm(3584, eps=1e-6)
|
||||
|
||||
self.img_in = nn.Linear(64, 3072)
|
||||
self.txt_in = nn.Linear(3584, 3072)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
QwenImageTransformerBlock(
|
||||
dim=3072,
|
||||
num_attention_heads=24,
|
||||
attention_head_dim=128,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.norm_out = AdaLayerNorm(3072, single=True)
|
||||
self.proj_out = nn.Linear(3072, 64)
|
||||
|
||||
|
||||
def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, entity_prompt_emb_mask, entity_masks, height, width, image, img_shapes):
|
||||
# prompt_emb
|
||||
all_prompt_emb = entity_prompt_emb + [prompt_emb]
|
||||
all_prompt_emb = [self.txt_in(self.txt_norm(local_prompt_emb)) for local_prompt_emb in all_prompt_emb]
|
||||
all_prompt_emb = torch.cat(all_prompt_emb, dim=1)
|
||||
|
||||
# image_rotary_emb
|
||||
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
|
||||
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
|
||||
entity_seq_lens = [emb_mask.sum(dim=1).tolist() for emb_mask in entity_prompt_emb_mask]
|
||||
entity_rotary_emb = [self.pos_embed(img_shapes, entity_seq_len, device=latents.device)[1] for entity_seq_len in entity_seq_lens]
|
||||
txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0)
|
||||
image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb)
|
||||
|
||||
# attention_mask
|
||||
repeat_dim = latents.shape[1]
|
||||
max_masks = entity_masks.shape[1]
|
||||
entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
|
||||
entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
|
||||
global_mask = torch.ones_like(entity_masks[0]).to(device=latents.device, dtype=latents.dtype)
|
||||
entity_masks = entity_masks + [global_mask]
|
||||
|
||||
N = len(entity_masks)
|
||||
batch_size = entity_masks[0].shape[0]
|
||||
seq_lens = [mask_.sum(dim=1).item() for mask_ in entity_prompt_emb_mask] + [prompt_emb_mask.sum(dim=1).item()]
|
||||
total_seq_len = sum(seq_lens) + image.shape[1]
|
||||
patched_masks = []
|
||||
for i in range(N):
|
||||
patched_mask = rearrange(entity_masks[i], "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||
patched_masks.append(patched_mask)
|
||||
attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device)
|
||||
|
||||
# prompt-image attention mask
|
||||
image_start = sum(seq_lens)
|
||||
image_end = total_seq_len
|
||||
cumsum = [0]
|
||||
for length in seq_lens:
|
||||
cumsum.append(cumsum[-1] + length)
|
||||
for i in range(N):
|
||||
prompt_start = cumsum[i]
|
||||
prompt_end = cumsum[i+1]
|
||||
image_mask = torch.sum(patched_masks[i], dim=-1) > 0
|
||||
image_mask = image_mask.unsqueeze(1).repeat(1, seq_lens[i], 1)
|
||||
# prompt update with image
|
||||
attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
|
||||
# image update with prompt
|
||||
attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
|
||||
# prompt-prompt attention mask, let the prompt tokens not attend to each other
|
||||
for i in range(N):
|
||||
for j in range(N):
|
||||
if i == j:
|
||||
continue
|
||||
start_i, end_i = cumsum[i], cumsum[i+1]
|
||||
start_j, end_j = cumsum[j], cumsum[j+1]
|
||||
attention_mask[:, start_i:end_i, start_j:end_j] = False
|
||||
|
||||
attention_mask = attention_mask.float()
|
||||
attention_mask[attention_mask == 0] = float('-inf')
|
||||
attention_mask[attention_mask == 1] = 0
|
||||
attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1)
|
||||
|
||||
return all_prompt_emb, image_rotary_emb, attention_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
latents=None,
|
||||
timestep=None,
|
||||
prompt_emb=None,
|
||||
prompt_emb_mask=None,
|
||||
height=None,
|
||||
width=None,
|
||||
):
|
||||
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
|
||||
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
|
||||
|
||||
image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (P Q C)", H=height//16, W=width//16, P=2, Q=2)
|
||||
image = self.img_in(image)
|
||||
text = self.txt_in(self.txt_norm(prompt_emb))
|
||||
|
||||
conditioning = self.time_text_embed(timestep, image.dtype)
|
||||
|
||||
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
text, image = block(
|
||||
image=image,
|
||||
text=text,
|
||||
temb=conditioning,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
image = self.norm_out(image, conditioning)
|
||||
image = self.proj_out(image)
|
||||
|
||||
latents = rearrange(image, "B (H W) (P Q C) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||
return image
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return QwenImageDiTStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class QwenImageDiTStateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict
|
||||
@@ -1,255 +0,0 @@
|
||||
from transformers import Qwen2_5_VLModel
|
||||
import torch
|
||||
from typing import Optional, Union
|
||||
|
||||
|
||||
class QwenImageTextEncoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
from transformers import Qwen2_5_VLConfig
|
||||
config = Qwen2_5_VLConfig(**{
|
||||
"architectures": [
|
||||
"Qwen2_5_VLForConditionalGeneration"
|
||||
],
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 151643,
|
||||
"eos_token_id": 151645,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 3584,
|
||||
"image_token_id": 151655,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 18944,
|
||||
"max_position_embeddings": 128000,
|
||||
"max_window_layers": 28,
|
||||
"model_type": "qwen2_5_vl",
|
||||
"num_attention_heads": 28,
|
||||
"num_hidden_layers": 28,
|
||||
"num_key_value_heads": 4,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_scaling": {
|
||||
"mrope_section": [
|
||||
16,
|
||||
24,
|
||||
24
|
||||
],
|
||||
"rope_type": "default",
|
||||
"type": "default"
|
||||
},
|
||||
"rope_theta": 1000000.0,
|
||||
"sliding_window": 32768,
|
||||
"text_config": {
|
||||
"architectures": [
|
||||
"Qwen2_5_VLForConditionalGeneration"
|
||||
],
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 151643,
|
||||
"eos_token_id": 151645,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 3584,
|
||||
"image_token_id": None,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 18944,
|
||||
"layer_types": [
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention"
|
||||
],
|
||||
"max_position_embeddings": 128000,
|
||||
"max_window_layers": 28,
|
||||
"model_type": "qwen2_5_vl_text",
|
||||
"num_attention_heads": 28,
|
||||
"num_hidden_layers": 28,
|
||||
"num_key_value_heads": 4,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_scaling": {
|
||||
"mrope_section": [
|
||||
16,
|
||||
24,
|
||||
24
|
||||
],
|
||||
"rope_type": "default",
|
||||
"type": "default"
|
||||
},
|
||||
"rope_theta": 1000000.0,
|
||||
"sliding_window": None,
|
||||
"torch_dtype": "float32",
|
||||
"use_cache": True,
|
||||
"use_sliding_window": False,
|
||||
"video_token_id": None,
|
||||
"vision_end_token_id": 151653,
|
||||
"vision_start_token_id": 151652,
|
||||
"vision_token_id": 151654,
|
||||
"vocab_size": 152064
|
||||
},
|
||||
"tie_word_embeddings": False,
|
||||
"torch_dtype": "float32",
|
||||
"transformers_version": "4.54.0",
|
||||
"use_cache": True,
|
||||
"use_sliding_window": False,
|
||||
"video_token_id": 151656,
|
||||
"vision_config": {
|
||||
"depth": 32,
|
||||
"fullatt_block_indexes": [
|
||||
7,
|
||||
15,
|
||||
23,
|
||||
31
|
||||
],
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 1280,
|
||||
"in_channels": 3,
|
||||
"in_chans": 3,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 3420,
|
||||
"model_type": "qwen2_5_vl",
|
||||
"num_heads": 16,
|
||||
"out_hidden_size": 3584,
|
||||
"patch_size": 14,
|
||||
"spatial_merge_size": 2,
|
||||
"spatial_patch_size": 14,
|
||||
"temporal_patch_size": 2,
|
||||
"tokens_per_second": 2,
|
||||
"torch_dtype": "float32",
|
||||
"window_size": 112
|
||||
},
|
||||
"vision_end_token_id": 151653,
|
||||
"vision_start_token_id": 151652,
|
||||
"vision_token_id": 151654,
|
||||
"vocab_size": 152064
|
||||
})
|
||||
self.model = Qwen2_5_VLModel(config)
|
||||
self.lm_head = torch.nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
pixel_values: Optional[torch.Tensor] = None,
|
||||
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
video_grid_thw: Optional[torch.LongTensor] = None,
|
||||
rope_deltas: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
second_per_grid_ts: Optional[torch.Tensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
|
||||
The temporal, height and width of feature shape of each image in LLM.
|
||||
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
|
||||
The temporal, height and width of feature shape of each video in LLM.
|
||||
rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
|
||||
The rope index difference between sequence length and multimodal rope.
|
||||
second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
|
||||
The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
|
||||
|
||||
>>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
|
||||
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
|
||||
|
||||
>>> messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
>>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
|
||||
```"""
|
||||
|
||||
output_attentions = False
|
||||
output_hidden_states = True
|
||||
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
pixel_values=pixel_values,
|
||||
pixel_values_videos=pixel_values_videos,
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
second_per_grid_ts=second_per_grid_ts,
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
return outputs.hidden_states
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return QwenImageTextEncoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class QwenImageTextEncoderStateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict_ = {}
|
||||
for k, v in state_dict.items():
|
||||
if k.startswith("visual."):
|
||||
k = "model." + k
|
||||
elif k.startswith("model."):
|
||||
k = k.replace("model.", "model.language_model.")
|
||||
state_dict_[k] = v
|
||||
return state_dict_
|
||||
@@ -1,736 +0,0 @@
|
||||
import torch
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from torch import nn
|
||||
|
||||
|
||||
CACHE_T = 2
|
||||
|
||||
class QwenImageCausalConv3d(torch.nn.Conv3d):
|
||||
r"""
|
||||
A custom 3D causal convolution layer with feature caching support.
|
||||
|
||||
This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature
|
||||
caching for efficient inference.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of channels in the input image
|
||||
out_channels (int): Number of channels produced by the convolution
|
||||
kernel_size (int or tuple): Size of the convolving kernel
|
||||
stride (int or tuple, optional): Stride of the convolution. Default: 1
|
||||
padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int, int]],
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int, int]] = 0,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
)
|
||||
|
||||
# Set up causal padding
|
||||
self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
|
||||
self.padding = (0, 0, 0)
|
||||
|
||||
def forward(self, x, cache_x=None):
|
||||
padding = list(self._padding)
|
||||
if cache_x is not None and self._padding[4] > 0:
|
||||
cache_x = cache_x.to(x.device)
|
||||
x = torch.cat([cache_x, x], dim=2)
|
||||
padding[4] -= cache_x.shape[2]
|
||||
x = torch.nn.functional.pad(x, padding)
|
||||
return super().forward(x)
|
||||
|
||||
|
||||
|
||||
class QwenImageRMS_norm(nn.Module):
|
||||
r"""
|
||||
A custom RMS normalization layer.
|
||||
|
||||
Args:
|
||||
dim (int): The number of dimensions to normalize over.
|
||||
channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
|
||||
Default is True.
|
||||
images (bool, optional): Whether the input represents image data. Default is True.
|
||||
bias (bool, optional): Whether to include a learnable bias term. Default is False.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
|
||||
super().__init__()
|
||||
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
|
||||
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
|
||||
|
||||
self.channel_first = channel_first
|
||||
self.scale = dim**0.5
|
||||
self.gamma = nn.Parameter(torch.ones(shape))
|
||||
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
|
||||
|
||||
def forward(self, x):
|
||||
return torch.nn.functional.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
|
||||
|
||||
|
||||
|
||||
class QwenImageResidualBlock(nn.Module):
|
||||
r"""
|
||||
A custom residual block module.
|
||||
|
||||
Args:
|
||||
in_dim (int): Number of input channels.
|
||||
out_dim (int): Number of output channels.
|
||||
dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.
|
||||
non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_dim: int,
|
||||
out_dim: int,
|
||||
dropout: float = 0.0,
|
||||
non_linearity: str = "silu",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
self.nonlinearity = torch.nn.SiLU()
|
||||
|
||||
# layers
|
||||
self.norm1 = QwenImageRMS_norm(in_dim, images=False)
|
||||
self.conv1 = QwenImageCausalConv3d(in_dim, out_dim, 3, padding=1)
|
||||
self.norm2 = QwenImageRMS_norm(out_dim, images=False)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.conv2 = QwenImageCausalConv3d(out_dim, out_dim, 3, padding=1)
|
||||
self.conv_shortcut = QwenImageCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
# Apply shortcut connection
|
||||
h = self.conv_shortcut(x)
|
||||
|
||||
# First normalization and activation
|
||||
x = self.norm1(x)
|
||||
x = self.nonlinearity(x)
|
||||
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
||||
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
|
||||
# Second normalization and activation
|
||||
x = self.norm2(x)
|
||||
x = self.nonlinearity(x)
|
||||
|
||||
# Dropout
|
||||
x = self.dropout(x)
|
||||
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
||||
|
||||
x = self.conv2(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv2(x)
|
||||
|
||||
# Add residual connection
|
||||
return x + h
|
||||
|
||||
|
||||
|
||||
class QwenImageAttentionBlock(nn.Module):
|
||||
r"""
|
||||
Causal self-attention with a single head.
|
||||
|
||||
Args:
|
||||
dim (int): The number of channels in the input tensor.
|
||||
"""
|
||||
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
# layers
|
||||
self.norm = QwenImageRMS_norm(dim)
|
||||
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
||||
self.proj = nn.Conv2d(dim, dim, 1)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
batch_size, channels, time, height, width = x.size()
|
||||
|
||||
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)
|
||||
x = self.norm(x)
|
||||
|
||||
# compute query, key, value
|
||||
qkv = self.to_qkv(x)
|
||||
qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
|
||||
qkv = qkv.permute(0, 1, 3, 2).contiguous()
|
||||
q, k, v = qkv.chunk(3, dim=-1)
|
||||
|
||||
# apply attention
|
||||
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width)
|
||||
|
||||
# output projection
|
||||
x = self.proj(x)
|
||||
|
||||
# Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]
|
||||
x = x.view(batch_size, time, channels, height, width)
|
||||
x = x.permute(0, 2, 1, 3, 4)
|
||||
|
||||
return x + identity
|
||||
|
||||
|
||||
|
||||
class QwenImageUpsample(nn.Upsample):
|
||||
r"""
|
||||
Perform upsampling while ensuring the output tensor has the same data type as the input.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor to be upsampled.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Upsampled tensor with the same data type as the input.
|
||||
"""
|
||||
|
||||
def forward(self, x):
|
||||
return super().forward(x.float()).type_as(x)
|
||||
|
||||
|
||||
|
||||
class QwenImageResample(nn.Module):
|
||||
r"""
|
||||
A custom resampling module for 2D and 3D data.
|
||||
|
||||
Args:
|
||||
dim (int): The number of input/output channels.
|
||||
mode (str): The resampling mode. Must be one of:
|
||||
- 'none': No resampling (identity operation).
|
||||
- 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
|
||||
- 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
|
||||
- 'downsample2d': 2D downsampling with zero-padding and convolution.
|
||||
- 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, mode: str) -> None:
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.mode = mode
|
||||
|
||||
# layers
|
||||
if mode == "upsample2d":
|
||||
self.resample = nn.Sequential(
|
||||
QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
|
||||
)
|
||||
elif mode == "upsample3d":
|
||||
self.resample = nn.Sequential(
|
||||
QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
|
||||
)
|
||||
self.time_conv = QwenImageCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
||||
|
||||
elif mode == "downsample2d":
|
||||
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
||||
elif mode == "downsample3d":
|
||||
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
||||
self.time_conv = QwenImageCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
||||
|
||||
else:
|
||||
self.resample = nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
b, c, t, h, w = x.size()
|
||||
if self.mode == "upsample3d":
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = "Rep"
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat(
|
||||
[feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
|
||||
)
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
|
||||
cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
|
||||
if feat_cache[idx] == "Rep":
|
||||
x = self.time_conv(x)
|
||||
else:
|
||||
x = self.time_conv(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
|
||||
x = x.reshape(b, 2, c, t, h, w)
|
||||
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
|
||||
x = x.reshape(b, c, t * 2, h, w)
|
||||
t = x.shape[2]
|
||||
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
||||
x = self.resample(x)
|
||||
x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
|
||||
|
||||
if self.mode == "downsample3d":
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = x.clone()
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
cache_x = x[:, :, -1:, :, :].clone()
|
||||
x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class QwenImageMidBlock(nn.Module):
|
||||
"""
|
||||
Middle block for WanVAE encoder and decoder.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input/output channels.
|
||||
dropout (float): Dropout rate.
|
||||
non_linearity (str): Type of non-linearity to use.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
# Create the components
|
||||
resnets = [QwenImageResidualBlock(dim, dim, dropout, non_linearity)]
|
||||
attentions = []
|
||||
for _ in range(num_layers):
|
||||
attentions.append(QwenImageAttentionBlock(dim))
|
||||
resnets.append(QwenImageResidualBlock(dim, dim, dropout, non_linearity))
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
# First residual block
|
||||
x = self.resnets[0](x, feat_cache, feat_idx)
|
||||
|
||||
# Process through attention and residual blocks
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
if attn is not None:
|
||||
x = attn(x)
|
||||
|
||||
x = resnet(x, feat_cache, feat_idx)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class QwenImageEncoder3d(nn.Module):
|
||||
r"""
|
||||
A 3D encoder module.
|
||||
|
||||
Args:
|
||||
dim (int): The base number of channels in the first layer.
|
||||
z_dim (int): The dimensionality of the latent space.
|
||||
dim_mult (list of int): Multipliers for the number of channels in each block.
|
||||
num_res_blocks (int): Number of residual blocks in each block.
|
||||
attn_scales (list of float): Scales at which to apply attention mechanisms.
|
||||
temperal_downsample (list of bool): Whether to downsample temporally in each block.
|
||||
dropout (float): Dropout rate for the dropout layers.
|
||||
non_linearity (str): Type of non-linearity to use.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[True, True, False],
|
||||
dropout=0.0,
|
||||
non_linearity: str = "silu",
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_downsample = temperal_downsample
|
||||
self.nonlinearity = torch.nn.SiLU()
|
||||
|
||||
# dimensions
|
||||
dims = [dim * u for u in [1] + dim_mult]
|
||||
scale = 1.0
|
||||
|
||||
# init block
|
||||
self.conv_in = QwenImageCausalConv3d(3, dims[0], 3, padding=1)
|
||||
|
||||
# downsample blocks
|
||||
self.down_blocks = torch.nn.ModuleList([])
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
# residual (+attention) blocks
|
||||
for _ in range(num_res_blocks):
|
||||
self.down_blocks.append(QwenImageResidualBlock(in_dim, out_dim, dropout))
|
||||
if scale in attn_scales:
|
||||
self.down_blocks.append(QwenImageAttentionBlock(out_dim))
|
||||
in_dim = out_dim
|
||||
|
||||
# downsample block
|
||||
if i != len(dim_mult) - 1:
|
||||
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
|
||||
self.down_blocks.append(QwenImageResample(out_dim, mode=mode))
|
||||
scale /= 2.0
|
||||
|
||||
# middle blocks
|
||||
self.mid_block = QwenImageMidBlock(out_dim, dropout, non_linearity, num_layers=1)
|
||||
|
||||
# output blocks
|
||||
self.norm_out = QwenImageRMS_norm(out_dim, images=False)
|
||||
self.conv_out = QwenImageCausalConv3d(out_dim, z_dim, 3, padding=1)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
||||
x = self.conv_in(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv_in(x)
|
||||
|
||||
## downsamples
|
||||
for layer in self.down_blocks:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## middle
|
||||
x = self.mid_block(x, feat_cache, feat_idx)
|
||||
|
||||
## head
|
||||
x = self.norm_out(x)
|
||||
x = self.nonlinearity(x)
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
||||
x = self.conv_out(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv_out(x)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class QwenImageUpBlock(nn.Module):
|
||||
"""
|
||||
A block that handles upsampling for the WanVAE decoder.
|
||||
|
||||
Args:
|
||||
in_dim (int): Input dimension
|
||||
out_dim (int): Output dimension
|
||||
num_res_blocks (int): Number of residual blocks
|
||||
dropout (float): Dropout rate
|
||||
upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
|
||||
non_linearity (str): Type of non-linearity to use
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_dim: int,
|
||||
out_dim: int,
|
||||
num_res_blocks: int,
|
||||
dropout: float = 0.0,
|
||||
upsample_mode: Optional[str] = None,
|
||||
non_linearity: str = "silu",
|
||||
):
|
||||
super().__init__()
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
|
||||
# Create layers list
|
||||
resnets = []
|
||||
# Add residual blocks and attention if needed
|
||||
current_dim = in_dim
|
||||
for _ in range(num_res_blocks + 1):
|
||||
resnets.append(QwenImageResidualBlock(current_dim, out_dim, dropout, non_linearity))
|
||||
current_dim = out_dim
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
# Add upsampling layer if needed
|
||||
self.upsamplers = None
|
||||
if upsample_mode is not None:
|
||||
self.upsamplers = nn.ModuleList([QwenImageResample(out_dim, mode=upsample_mode)])
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
"""
|
||||
Forward pass through the upsampling block.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor
|
||||
feat_cache (list, optional): Feature cache for causal convolutions
|
||||
feat_idx (list, optional): Feature index for cache management
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor
|
||||
"""
|
||||
for resnet in self.resnets:
|
||||
if feat_cache is not None:
|
||||
x = resnet(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = resnet(x)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
if feat_cache is not None:
|
||||
x = self.upsamplers[0](x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = self.upsamplers[0](x)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class QwenImageDecoder3d(nn.Module):
|
||||
r"""
|
||||
A 3D decoder module.
|
||||
|
||||
Args:
|
||||
dim (int): The base number of channels in the first layer.
|
||||
z_dim (int): The dimensionality of the latent space.
|
||||
dim_mult (list of int): Multipliers for the number of channels in each block.
|
||||
num_res_blocks (int): Number of residual blocks in each block.
|
||||
attn_scales (list of float): Scales at which to apply attention mechanisms.
|
||||
temperal_upsample (list of bool): Whether to upsample temporally in each block.
|
||||
dropout (float): Dropout rate for the dropout layers.
|
||||
non_linearity (str): Type of non-linearity to use.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_upsample=[False, True, True],
|
||||
dropout=0.0,
|
||||
non_linearity: str = "silu",
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_upsample = temperal_upsample
|
||||
|
||||
self.nonlinearity = torch.nn.SiLU()
|
||||
|
||||
# dimensions
|
||||
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
||||
scale = 1.0 / 2 ** (len(dim_mult) - 2)
|
||||
|
||||
# init block
|
||||
self.conv_in = QwenImageCausalConv3d(z_dim, dims[0], 3, padding=1)
|
||||
|
||||
# middle blocks
|
||||
self.mid_block = QwenImageMidBlock(dims[0], dropout, non_linearity, num_layers=1)
|
||||
|
||||
# upsample blocks
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
# residual (+attention) blocks
|
||||
if i > 0:
|
||||
in_dim = in_dim // 2
|
||||
|
||||
# Determine if we need upsampling
|
||||
upsample_mode = None
|
||||
if i != len(dim_mult) - 1:
|
||||
upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
|
||||
|
||||
# Create and add the upsampling block
|
||||
up_block = QwenImageUpBlock(
|
||||
in_dim=in_dim,
|
||||
out_dim=out_dim,
|
||||
num_res_blocks=num_res_blocks,
|
||||
dropout=dropout,
|
||||
upsample_mode=upsample_mode,
|
||||
non_linearity=non_linearity,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
|
||||
# Update scale for next iteration
|
||||
if upsample_mode is not None:
|
||||
scale *= 2.0
|
||||
|
||||
# output blocks
|
||||
self.norm_out = QwenImageRMS_norm(out_dim, images=False)
|
||||
self.conv_out = QwenImageCausalConv3d(out_dim, 3, 3, padding=1)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
## conv1
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
||||
x = self.conv_in(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv_in(x)
|
||||
|
||||
## middle
|
||||
x = self.mid_block(x, feat_cache, feat_idx)
|
||||
|
||||
## upsamples
|
||||
for up_block in self.up_blocks:
|
||||
x = up_block(x, feat_cache, feat_idx)
|
||||
|
||||
## head
|
||||
x = self.norm_out(x)
|
||||
x = self.nonlinearity(x)
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
||||
x = self.conv_out(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv_out(x)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class QwenImageVAE(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
base_dim: int = 96,
|
||||
z_dim: int = 16,
|
||||
dim_mult: Tuple[int] = [1, 2, 4, 4],
|
||||
num_res_blocks: int = 2,
|
||||
attn_scales: List[float] = [],
|
||||
temperal_downsample: List[bool] = [False, True, True],
|
||||
dropout: float = 0.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.z_dim = z_dim
|
||||
self.temperal_downsample = temperal_downsample
|
||||
self.temperal_upsample = temperal_downsample[::-1]
|
||||
|
||||
self.encoder = QwenImageEncoder3d(
|
||||
base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
|
||||
)
|
||||
self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||
self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1)
|
||||
|
||||
self.decoder = QwenImageDecoder3d(
|
||||
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
|
||||
)
|
||||
|
||||
mean = [
|
||||
-0.7571,
|
||||
-0.7089,
|
||||
-0.9113,
|
||||
0.1075,
|
||||
-0.1745,
|
||||
0.9653,
|
||||
-0.1517,
|
||||
1.5508,
|
||||
0.4134,
|
||||
-0.0715,
|
||||
0.5517,
|
||||
-0.3632,
|
||||
-0.1922,
|
||||
-0.9497,
|
||||
0.2503,
|
||||
-0.2921,
|
||||
]
|
||||
std = [
|
||||
2.8184,
|
||||
1.4541,
|
||||
2.3275,
|
||||
2.6558,
|
||||
1.2196,
|
||||
1.7708,
|
||||
2.6052,
|
||||
2.0743,
|
||||
3.2687,
|
||||
2.1526,
|
||||
2.8652,
|
||||
1.5579,
|
||||
1.6382,
|
||||
1.1253,
|
||||
2.8251,
|
||||
1.9160,
|
||||
]
|
||||
self.mean = torch.tensor(mean).view(1, 16, 1, 1, 1)
|
||||
self.std = 1 / torch.tensor(std).view(1, 16, 1, 1, 1)
|
||||
|
||||
def encode(self, x, **kwargs):
|
||||
x = x.unsqueeze(2)
|
||||
x = self.encoder(x)
|
||||
x = self.quant_conv(x)
|
||||
x = x[:, :16]
|
||||
mean, std = self.mean.to(dtype=x.dtype, device=x.device), self.std.to(dtype=x.dtype, device=x.device)
|
||||
x = (x - mean) * std
|
||||
x = x.squeeze(2)
|
||||
return x
|
||||
|
||||
def decode(self, x, **kwargs):
|
||||
x = x.unsqueeze(2)
|
||||
mean, std = self.mean.to(dtype=x.dtype, device=x.device), self.std.to(dtype=x.dtype, device=x.device)
|
||||
x = x / std + mean
|
||||
x = self.post_quant_conv(x)
|
||||
x = self.decoder(x)
|
||||
x = x.squeeze(2)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return QwenImageVAEStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class QwenImageVAEStateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
@@ -1,168 +0,0 @@
|
||||
import torch
|
||||
|
||||
|
||||
class Qwen25VL_7b_Embedder(torch.nn.Module):
|
||||
def __init__(self, model_path, max_length=640, dtype=torch.bfloat16, device="cuda"):
|
||||
super(Qwen25VL_7b_Embedder, self).__init__()
|
||||
self.max_length = max_length
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
|
||||
|
||||
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=dtype,
|
||||
).to(torch.cuda.current_device())
|
||||
|
||||
self.model.requires_grad_(False)
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
model_path, min_pixels=256 * 28 * 28, max_pixels=324 * 28 * 28
|
||||
)
|
||||
|
||||
Qwen25VL_7b_PREFIX = '''Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:
|
||||
- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.
|
||||
- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n
|
||||
Here are examples of how to transform or refine prompts:
|
||||
- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.
|
||||
- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n
|
||||
Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:
|
||||
User Prompt:'''
|
||||
|
||||
self.prefix = Qwen25VL_7b_PREFIX
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(path, torch_dtype=torch.bfloat16, device="cuda"):
|
||||
return Qwen25VL_7b_Embedder(path, dtype=torch_dtype, device=device)
|
||||
|
||||
def forward(self, caption, ref_images):
|
||||
text_list = caption
|
||||
embs = torch.zeros(
|
||||
len(text_list),
|
||||
self.max_length,
|
||||
self.model.config.hidden_size,
|
||||
dtype=torch.bfloat16,
|
||||
device=torch.cuda.current_device(),
|
||||
)
|
||||
hidden_states = torch.zeros(
|
||||
len(text_list),
|
||||
self.max_length,
|
||||
self.model.config.hidden_size,
|
||||
dtype=torch.bfloat16,
|
||||
device=torch.cuda.current_device(),
|
||||
)
|
||||
masks = torch.zeros(
|
||||
len(text_list),
|
||||
self.max_length,
|
||||
dtype=torch.long,
|
||||
device=torch.cuda.current_device(),
|
||||
)
|
||||
input_ids_list = []
|
||||
attention_mask_list = []
|
||||
emb_list = []
|
||||
|
||||
def split_string(s):
|
||||
s = s.replace("“", '"').replace("”", '"').replace("'", '''"''') # use english quotes
|
||||
result = []
|
||||
in_quotes = False
|
||||
temp = ""
|
||||
|
||||
for idx,char in enumerate(s):
|
||||
if char == '"' and idx>155:
|
||||
temp += char
|
||||
if not in_quotes:
|
||||
result.append(temp)
|
||||
temp = ""
|
||||
|
||||
in_quotes = not in_quotes
|
||||
continue
|
||||
if in_quotes:
|
||||
if char.isspace():
|
||||
pass # have space token
|
||||
|
||||
result.append("“" + char + "”")
|
||||
else:
|
||||
temp += char
|
||||
|
||||
if temp:
|
||||
result.append(temp)
|
||||
|
||||
return result
|
||||
|
||||
for idx, (txt, imgs) in enumerate(zip(text_list, ref_images)):
|
||||
|
||||
messages = [{"role": "user", "content": []}]
|
||||
|
||||
messages[0]["content"].append({"type": "text", "text": f"{self.prefix}"})
|
||||
|
||||
messages[0]["content"].append({"type": "image", "image": imgs})
|
||||
|
||||
# 再添加 text
|
||||
messages[0]["content"].append({"type": "text", "text": f"{txt}"})
|
||||
|
||||
# Preparation for inference
|
||||
text = self.processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True, add_vision_id=True
|
||||
)
|
||||
|
||||
image_inputs = [imgs]
|
||||
|
||||
inputs = self.processor(
|
||||
text=[text],
|
||||
images=image_inputs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
old_inputs_ids = inputs.input_ids
|
||||
text_split_list = split_string(text)
|
||||
|
||||
token_list = []
|
||||
for text_each in text_split_list:
|
||||
txt_inputs = self.processor(
|
||||
text=text_each,
|
||||
images=None,
|
||||
videos=None,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
token_each = txt_inputs.input_ids
|
||||
if token_each[0][0] == 2073 and token_each[0][-1] == 854:
|
||||
token_each = token_each[:, 1:-1]
|
||||
token_list.append(token_each)
|
||||
else:
|
||||
token_list.append(token_each)
|
||||
|
||||
new_txt_ids = torch.cat(token_list, dim=1).to("cuda")
|
||||
|
||||
new_txt_ids = new_txt_ids.to(old_inputs_ids.device)
|
||||
|
||||
idx1 = (old_inputs_ids == 151653).nonzero(as_tuple=True)[1][0]
|
||||
idx2 = (new_txt_ids == 151653).nonzero(as_tuple=True)[1][0]
|
||||
inputs.input_ids = (
|
||||
torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0)
|
||||
.unsqueeze(0)
|
||||
.to("cuda")
|
||||
)
|
||||
inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda")
|
||||
outputs = self.model(
|
||||
input_ids=inputs.input_ids,
|
||||
attention_mask=inputs.attention_mask,
|
||||
pixel_values=inputs.pixel_values.to("cuda"),
|
||||
image_grid_thw=inputs.image_grid_thw.to("cuda"),
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
emb = outputs["hidden_states"][-1]
|
||||
|
||||
embs[idx, : min(self.max_length, emb.shape[1] - 217)] = emb[0, 217:][
|
||||
: self.max_length
|
||||
]
|
||||
|
||||
masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones(
|
||||
(min(self.max_length, emb.shape[1] - 217)),
|
||||
dtype=torch.long,
|
||||
device=torch.cuda.current_device(),
|
||||
)
|
||||
|
||||
return embs, masks
|
||||
@@ -50,30 +50,14 @@ class PatchEmbed(torch.nn.Module):
|
||||
return latent + pos_embed
|
||||
|
||||
|
||||
class DiffusersCompatibleTimestepProj(torch.nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.linear_1 = torch.nn.Linear(dim_in, dim_out)
|
||||
self.act = torch.nn.SiLU()
|
||||
self.linear_2 = torch.nn.Linear(dim_out, dim_out)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear_1(x)
|
||||
x = self.act(x)
|
||||
x = self.linear_2(x)
|
||||
return x
|
||||
|
||||
|
||||
class TimestepEmbeddings(torch.nn.Module):
|
||||
def __init__(self, dim_in, dim_out, computation_device=None, diffusers_compatible_format=False, scale=1, align_dtype_to_timestep=False):
|
||||
def __init__(self, dim_in, dim_out, computation_device=None):
|
||||
super().__init__()
|
||||
self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device, scale=scale, align_dtype_to_timestep=align_dtype_to_timestep)
|
||||
if diffusers_compatible_format:
|
||||
self.timestep_embedder = DiffusersCompatibleTimestepProj(dim_in, dim_out)
|
||||
else:
|
||||
self.timestep_embedder = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
|
||||
)
|
||||
self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device)
|
||||
self.timestep_embedder = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
|
||||
)
|
||||
|
||||
def forward(self, timestep, dtype):
|
||||
time_emb = self.time_proj(timestep).to(dtype)
|
||||
|
||||
@@ -1,683 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch, math
|
||||
import torch.nn
|
||||
from einops import rearrange
|
||||
from torch import nn
|
||||
from functools import partial
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
|
||||
def attention(q, k, v, attn_mask, mode="torch"):
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
x = rearrange(x, "b n s d -> b s (n d)")
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
hidden_channels=None,
|
||||
out_features=None,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=None,
|
||||
bias=True,
|
||||
drop=0.0,
|
||||
use_conv=False,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
super().__init__()
|
||||
out_features = out_features or in_channels
|
||||
hidden_channels = hidden_channels or in_channels
|
||||
bias = (bias, bias)
|
||||
drop_probs = (drop, drop)
|
||||
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
||||
|
||||
self.fc1 = linear_layer(
|
||||
in_channels, hidden_channels, bias=bias[0], device=device, dtype=dtype
|
||||
)
|
||||
self.act = act_layer()
|
||||
self.drop1 = nn.Dropout(drop_probs[0])
|
||||
self.norm = (
|
||||
norm_layer(hidden_channels, device=device, dtype=dtype)
|
||||
if norm_layer is not None
|
||||
else nn.Identity()
|
||||
)
|
||||
self.fc2 = linear_layer(
|
||||
hidden_channels, out_features, bias=bias[1], device=device, dtype=dtype
|
||||
)
|
||||
self.drop2 = nn.Dropout(drop_probs[1])
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop1(x)
|
||||
x = self.norm(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop2(x)
|
||||
return x
|
||||
|
||||
|
||||
class TextProjection(nn.Module):
|
||||
"""
|
||||
Projects text embeddings. Also handles dropout for classifier-free guidance.
|
||||
|
||||
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
|
||||
factory_kwargs = {"dtype": dtype, "device": device}
|
||||
super().__init__()
|
||||
self.linear_1 = nn.Linear(
|
||||
in_features=in_channels,
|
||||
out_features=hidden_size,
|
||||
bias=True,
|
||||
**factory_kwargs,
|
||||
)
|
||||
self.act_1 = act_layer()
|
||||
self.linear_2 = nn.Linear(
|
||||
in_features=hidden_size,
|
||||
out_features=hidden_size,
|
||||
bias=True,
|
||||
**factory_kwargs,
|
||||
)
|
||||
|
||||
def forward(self, caption):
|
||||
hidden_states = self.linear_1(caption)
|
||||
hidden_states = self.act_1(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
"""
|
||||
Embeds scalar timesteps into vector representations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
act_layer,
|
||||
frequency_embedding_size=256,
|
||||
max_period=10000,
|
||||
out_size=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
):
|
||||
factory_kwargs = {"dtype": dtype, "device": device}
|
||||
super().__init__()
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
self.max_period = max_period
|
||||
if out_size is None:
|
||||
out_size = hidden_size
|
||||
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(
|
||||
frequency_embedding_size, hidden_size, bias=True, **factory_kwargs
|
||||
),
|
||||
act_layer(),
|
||||
nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
|
||||
)
|
||||
nn.init.normal_(self.mlp[0].weight, std=0.02) # type: ignore
|
||||
nn.init.normal_(self.mlp[2].weight, std=0.02) # type: ignore
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
|
||||
Args:
|
||||
t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
|
||||
dim (int): the dimension of the output.
|
||||
max_period (int): controls the minimum frequency of the embeddings.
|
||||
|
||||
Returns:
|
||||
embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
|
||||
|
||||
.. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
||||
"""
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period)
|
||||
* torch.arange(start=0, end=half, dtype=torch.float32)
|
||||
/ half
|
||||
).to(device=t.device)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat(
|
||||
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
||||
)
|
||||
return embedding
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = self.timestep_embedding(
|
||||
t, self.frequency_embedding_size, self.max_period
|
||||
).type(t.dtype) # type: ignore
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
def apply_gate(x, gate=None, tanh=False):
|
||||
"""AI is creating summary for apply_gate
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): input tensor.
|
||||
gate (torch.Tensor, optional): gate tensor. Defaults to None.
|
||||
tanh (bool, optional): whether to use tanh function. Defaults to False.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: the output tensor after apply gate.
|
||||
"""
|
||||
if gate is None:
|
||||
return x
|
||||
if tanh:
|
||||
return x * gate.unsqueeze(1).tanh()
|
||||
else:
|
||||
return x * gate.unsqueeze(1)
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
elementwise_affine=True,
|
||||
eps: float = 1e-6,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
"""
|
||||
Initialize the RMSNorm normalization layer.
|
||||
|
||||
Args:
|
||||
dim (int): The dimension of the input tensor.
|
||||
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
|
||||
|
||||
Attributes:
|
||||
eps (float): A small value added to the denominator for numerical stability.
|
||||
weight (nn.Parameter): Learnable scaling parameter.
|
||||
|
||||
"""
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
if elementwise_affine:
|
||||
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
|
||||
|
||||
def _norm(self, x):
|
||||
"""
|
||||
Apply the RMSNorm normalization to the input tensor.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The normalized tensor.
|
||||
|
||||
"""
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass through the RMSNorm layer.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The output tensor after applying RMSNorm.
|
||||
|
||||
"""
|
||||
output = self._norm(x.float()).type_as(x)
|
||||
if hasattr(self, "weight"):
|
||||
output = output * self.weight
|
||||
return output
|
||||
|
||||
|
||||
def get_norm_layer(norm_layer):
|
||||
"""
|
||||
Get the normalization layer.
|
||||
|
||||
Args:
|
||||
norm_layer (str): The type of normalization layer.
|
||||
|
||||
Returns:
|
||||
norm_layer (nn.Module): The normalization layer.
|
||||
"""
|
||||
if norm_layer == "layer":
|
||||
return nn.LayerNorm
|
||||
elif norm_layer == "rms":
|
||||
return RMSNorm
|
||||
else:
|
||||
raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
|
||||
|
||||
|
||||
def get_activation_layer(act_type):
|
||||
"""get activation layer
|
||||
|
||||
Args:
|
||||
act_type (str): the activation type
|
||||
|
||||
Returns:
|
||||
torch.nn.functional: the activation layer
|
||||
"""
|
||||
if act_type == "gelu":
|
||||
return lambda: nn.GELU()
|
||||
elif act_type == "gelu_tanh":
|
||||
return lambda: nn.GELU(approximate="tanh")
|
||||
elif act_type == "relu":
|
||||
return nn.ReLU
|
||||
elif act_type == "silu":
|
||||
return nn.SiLU
|
||||
else:
|
||||
raise ValueError(f"Unknown activation type: {act_type}")
|
||||
|
||||
class IndividualTokenRefinerBlock(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
heads_num,
|
||||
mlp_width_ratio: str = 4.0,
|
||||
mlp_drop_rate: float = 0.0,
|
||||
act_type: str = "silu",
|
||||
qk_norm: bool = False,
|
||||
qk_norm_type: str = "layer",
|
||||
qkv_bias: bool = True,
|
||||
need_CA: bool = False,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.need_CA = need_CA
|
||||
self.heads_num = heads_num
|
||||
head_dim = hidden_size // heads_num
|
||||
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
|
||||
|
||||
self.norm1 = nn.LayerNorm(
|
||||
hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
|
||||
)
|
||||
self.self_attn_qkv = nn.Linear(
|
||||
hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
|
||||
)
|
||||
qk_norm_layer = get_norm_layer(qk_norm_type)
|
||||
self.self_attn_q_norm = (
|
||||
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
||||
if qk_norm
|
||||
else nn.Identity()
|
||||
)
|
||||
self.self_attn_k_norm = (
|
||||
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
||||
if qk_norm
|
||||
else nn.Identity()
|
||||
)
|
||||
self.self_attn_proj = nn.Linear(
|
||||
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
|
||||
)
|
||||
|
||||
self.norm2 = nn.LayerNorm(
|
||||
hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
|
||||
)
|
||||
act_layer = get_activation_layer(act_type)
|
||||
self.mlp = MLP(
|
||||
in_channels=hidden_size,
|
||||
hidden_channels=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=mlp_drop_rate,
|
||||
**factory_kwargs,
|
||||
)
|
||||
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
act_layer(),
|
||||
nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
|
||||
)
|
||||
|
||||
if self.need_CA:
|
||||
self.cross_attnblock=CrossAttnBlock(hidden_size=hidden_size,
|
||||
heads_num=heads_num,
|
||||
mlp_width_ratio=mlp_width_ratio,
|
||||
mlp_drop_rate=mlp_drop_rate,
|
||||
act_type=act_type,
|
||||
qk_norm=qk_norm,
|
||||
qk_norm_type=qk_norm_type,
|
||||
qkv_bias=qkv_bias,
|
||||
**factory_kwargs,)
|
||||
# Zero-initialize the modulation
|
||||
nn.init.zeros_(self.adaLN_modulation[1].weight)
|
||||
nn.init.zeros_(self.adaLN_modulation[1].bias)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
c: torch.Tensor, # timestep_aware_representations + context_aware_representations
|
||||
attn_mask: torch.Tensor = None,
|
||||
y: torch.Tensor = None,
|
||||
):
|
||||
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
|
||||
norm_x = self.norm1(x)
|
||||
qkv = self.self_attn_qkv(norm_x)
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
||||
# Apply QK-Norm if needed
|
||||
q = self.self_attn_q_norm(q).to(v)
|
||||
k = self.self_attn_k_norm(k).to(v)
|
||||
|
||||
# Self-Attention
|
||||
attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
|
||||
|
||||
x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
|
||||
|
||||
if self.need_CA:
|
||||
x = self.cross_attnblock(x, c, attn_mask, y)
|
||||
|
||||
# FFN Layer
|
||||
x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
|
||||
|
||||
class CrossAttnBlock(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
heads_num,
|
||||
mlp_width_ratio: str = 4.0,
|
||||
mlp_drop_rate: float = 0.0,
|
||||
act_type: str = "silu",
|
||||
qk_norm: bool = False,
|
||||
qk_norm_type: str = "layer",
|
||||
qkv_bias: bool = True,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.heads_num = heads_num
|
||||
head_dim = hidden_size // heads_num
|
||||
|
||||
self.norm1 = nn.LayerNorm(
|
||||
hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
|
||||
)
|
||||
self.norm1_2 = nn.LayerNorm(
|
||||
hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
|
||||
)
|
||||
self.self_attn_q = nn.Linear(
|
||||
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
|
||||
)
|
||||
self.self_attn_kv = nn.Linear(
|
||||
hidden_size, hidden_size*2, bias=qkv_bias, **factory_kwargs
|
||||
)
|
||||
qk_norm_layer = get_norm_layer(qk_norm_type)
|
||||
self.self_attn_q_norm = (
|
||||
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
||||
if qk_norm
|
||||
else nn.Identity()
|
||||
)
|
||||
self.self_attn_k_norm = (
|
||||
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
||||
if qk_norm
|
||||
else nn.Identity()
|
||||
)
|
||||
self.self_attn_proj = nn.Linear(
|
||||
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
|
||||
)
|
||||
|
||||
self.norm2 = nn.LayerNorm(
|
||||
hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
|
||||
)
|
||||
act_layer = get_activation_layer(act_type)
|
||||
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
act_layer(),
|
||||
nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
|
||||
)
|
||||
# Zero-initialize the modulation
|
||||
nn.init.zeros_(self.adaLN_modulation[1].weight)
|
||||
nn.init.zeros_(self.adaLN_modulation[1].bias)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
c: torch.Tensor, # timestep_aware_representations + context_aware_representations
|
||||
attn_mask: torch.Tensor = None,
|
||||
y: torch.Tensor=None,
|
||||
|
||||
):
|
||||
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
|
||||
norm_x = self.norm1(x)
|
||||
norm_y = self.norm1_2(y)
|
||||
q = self.self_attn_q(norm_x)
|
||||
q = rearrange(q, "B L (H D) -> B L H D", H=self.heads_num)
|
||||
kv = self.self_attn_kv(norm_y)
|
||||
k, v = rearrange(kv, "B L (K H D) -> K B L H D", K=2, H=self.heads_num)
|
||||
# Apply QK-Norm if needed
|
||||
q = self.self_attn_q_norm(q).to(v)
|
||||
k = self.self_attn_k_norm(k).to(v)
|
||||
|
||||
# Self-Attention
|
||||
attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
|
||||
|
||||
x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class IndividualTokenRefiner(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
heads_num,
|
||||
depth,
|
||||
mlp_width_ratio: float = 4.0,
|
||||
mlp_drop_rate: float = 0.0,
|
||||
act_type: str = "silu",
|
||||
qk_norm: bool = False,
|
||||
qk_norm_type: str = "layer",
|
||||
qkv_bias: bool = True,
|
||||
need_CA:bool=False,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.need_CA = need_CA
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
IndividualTokenRefinerBlock(
|
||||
hidden_size=hidden_size,
|
||||
heads_num=heads_num,
|
||||
mlp_width_ratio=mlp_width_ratio,
|
||||
mlp_drop_rate=mlp_drop_rate,
|
||||
act_type=act_type,
|
||||
qk_norm=qk_norm,
|
||||
qk_norm_type=qk_norm_type,
|
||||
qkv_bias=qkv_bias,
|
||||
need_CA=self.need_CA,
|
||||
**factory_kwargs,
|
||||
)
|
||||
for _ in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
c: torch.LongTensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
y:torch.Tensor=None,
|
||||
):
|
||||
self_attn_mask = None
|
||||
if mask is not None:
|
||||
batch_size = mask.shape[0]
|
||||
seq_len = mask.shape[1]
|
||||
mask = mask.to(x.device)
|
||||
# batch_size x 1 x seq_len x seq_len
|
||||
self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(
|
||||
1, 1, seq_len, 1
|
||||
)
|
||||
# batch_size x 1 x seq_len x seq_len
|
||||
self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
|
||||
# batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
|
||||
self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
|
||||
# avoids self-attention weight being NaN for padding tokens
|
||||
self_attn_mask[:, :, :, 0] = True
|
||||
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, c, self_attn_mask,y)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SingleTokenRefiner(torch.nn.Module):
|
||||
"""
|
||||
A single token refiner block for llm text embedding refine.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
hidden_size,
|
||||
heads_num,
|
||||
depth,
|
||||
mlp_width_ratio: float = 4.0,
|
||||
mlp_drop_rate: float = 0.0,
|
||||
act_type: str = "silu",
|
||||
qk_norm: bool = False,
|
||||
qk_norm_type: str = "layer",
|
||||
qkv_bias: bool = True,
|
||||
need_CA:bool=False,
|
||||
attn_mode: str = "torch",
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.attn_mode = attn_mode
|
||||
self.need_CA = need_CA
|
||||
assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner."
|
||||
|
||||
self.input_embedder = nn.Linear(
|
||||
in_channels, hidden_size, bias=True, **factory_kwargs
|
||||
)
|
||||
if self.need_CA:
|
||||
self.input_embedder_CA = nn.Linear(
|
||||
in_channels, hidden_size, bias=True, **factory_kwargs
|
||||
)
|
||||
|
||||
act_layer = get_activation_layer(act_type)
|
||||
# Build timestep embedding layer
|
||||
self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
|
||||
# Build context embedding layer
|
||||
self.c_embedder = TextProjection(
|
||||
in_channels, hidden_size, act_layer, **factory_kwargs
|
||||
)
|
||||
|
||||
self.individual_token_refiner = IndividualTokenRefiner(
|
||||
hidden_size=hidden_size,
|
||||
heads_num=heads_num,
|
||||
depth=depth,
|
||||
mlp_width_ratio=mlp_width_ratio,
|
||||
mlp_drop_rate=mlp_drop_rate,
|
||||
act_type=act_type,
|
||||
qk_norm=qk_norm,
|
||||
qk_norm_type=qk_norm_type,
|
||||
qkv_bias=qkv_bias,
|
||||
need_CA=need_CA,
|
||||
**factory_kwargs,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
t: torch.LongTensor,
|
||||
mask: Optional[torch.LongTensor] = None,
|
||||
y: torch.LongTensor=None,
|
||||
):
|
||||
timestep_aware_representations = self.t_embedder(t)
|
||||
|
||||
if mask is None:
|
||||
context_aware_representations = x.mean(dim=1)
|
||||
else:
|
||||
mask_float = mask.unsqueeze(-1) # [b, s1, 1]
|
||||
context_aware_representations = (x * mask_float).sum(
|
||||
dim=1
|
||||
) / mask_float.sum(dim=1)
|
||||
context_aware_representations = self.c_embedder(context_aware_representations)
|
||||
c = timestep_aware_representations + context_aware_representations
|
||||
|
||||
x = self.input_embedder(x)
|
||||
if self.need_CA:
|
||||
y = self.input_embedder_CA(y)
|
||||
x = self.individual_token_refiner(x, c, mask, y)
|
||||
else:
|
||||
x = self.individual_token_refiner(x, c, mask)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Qwen2Connector(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
# biclip_dim=1024,
|
||||
in_channels=3584,
|
||||
hidden_size=4096,
|
||||
heads_num=32,
|
||||
depth=2,
|
||||
need_CA=False,
|
||||
device=None,
|
||||
dtype=torch.bfloat16,
|
||||
):
|
||||
super().__init__()
|
||||
factory_kwargs = {"device": device, "dtype":dtype}
|
||||
|
||||
self.S =SingleTokenRefiner(in_channels=in_channels,hidden_size=hidden_size,heads_num=heads_num,depth=depth,need_CA=need_CA,**factory_kwargs)
|
||||
self.global_proj_out=nn.Linear(in_channels,768)
|
||||
|
||||
self.scale_factor = nn.Parameter(torch.zeros(1))
|
||||
with torch.no_grad():
|
||||
self.scale_factor.data += -(1 - 0.09)
|
||||
|
||||
def forward(self, x,t,mask):
|
||||
mask_float = mask.unsqueeze(-1) # [b, s1, 1]
|
||||
x_mean = (x * mask_float).sum(
|
||||
dim=1
|
||||
) / mask_float.sum(dim=1) * (1 + self.scale_factor.to(dtype=x.dtype, device=x.device))
|
||||
|
||||
global_out=self.global_proj_out(x_mean)
|
||||
encoder_hidden_states = self.S(x,t,mask)
|
||||
return encoder_hidden_states,global_out
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return Qwen2ConnectorStateDictConverter()
|
||||
|
||||
|
||||
class Qwen2ConnectorStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name.startswith("connector."):
|
||||
name_ = name[len("connector."):]
|
||||
state_dict_[name_] = param
|
||||
return state_dict_
|
||||
@@ -45,7 +45,6 @@ def get_timestep_embedding(
|
||||
scale: float = 1,
|
||||
max_period: int = 10000,
|
||||
computation_device = None,
|
||||
align_dtype_to_timestep = False,
|
||||
):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
||||
@@ -64,8 +63,6 @@ def get_timestep_embedding(
|
||||
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||
|
||||
emb = torch.exp(exponent).to(timesteps.device)
|
||||
if align_dtype_to_timestep:
|
||||
emb = emb.to(timesteps.dtype)
|
||||
emb = timesteps[:, None].float() * emb[None, :]
|
||||
|
||||
# scale embeddings
|
||||
@@ -85,14 +82,12 @@ def get_timestep_embedding(
|
||||
|
||||
|
||||
class TemporalTimesteps(torch.nn.Module):
|
||||
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, computation_device = None, scale=1, align_dtype_to_timestep=False):
|
||||
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, computation_device = None):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
self.flip_sin_to_cos = flip_sin_to_cos
|
||||
self.downscale_freq_shift = downscale_freq_shift
|
||||
self.computation_device = computation_device
|
||||
self.scale = scale
|
||||
self.align_dtype_to_timestep = align_dtype_to_timestep
|
||||
|
||||
def forward(self, timesteps):
|
||||
t_emb = get_timestep_embedding(
|
||||
@@ -101,8 +96,6 @@ class TemporalTimesteps(torch.nn.Module):
|
||||
flip_sin_to_cos=self.flip_sin_to_cos,
|
||||
downscale_freq_shift=self.downscale_freq_shift,
|
||||
computation_device=self.computation_device,
|
||||
scale=self.scale,
|
||||
align_dtype_to_timestep=self.align_dtype_to_timestep,
|
||||
)
|
||||
return t_emb
|
||||
|
||||
|
||||
@@ -62,16 +62,16 @@ def load_state_dict_from_folder(file_path, torch_dtype=None):
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_state_dict(file_path, torch_dtype=None, device="cpu"):
|
||||
def load_state_dict(file_path, torch_dtype=None):
|
||||
if file_path.endswith(".safetensors"):
|
||||
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
|
||||
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
|
||||
else:
|
||||
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
|
||||
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
|
||||
|
||||
|
||||
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
|
||||
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
|
||||
state_dict = {}
|
||||
with safe_open(file_path, framework="pt", device=str(device)) as f:
|
||||
with safe_open(file_path, framework="pt", device="cpu") as f:
|
||||
for k in f.keys():
|
||||
state_dict[k] = f.get_tensor(k)
|
||||
if torch_dtype is not None:
|
||||
@@ -79,8 +79,8 @@ def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
|
||||
state_dict = torch.load(file_path, map_location=device, weights_only=True)
|
||||
def load_state_dict_from_bin(file_path, torch_dtype=None):
|
||||
state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
|
||||
if torch_dtype is not None:
|
||||
for i in state_dict:
|
||||
if isinstance(state_dict[i], torch.Tensor):
|
||||
|
||||
@@ -1,202 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from einops import rearrange
|
||||
import os
|
||||
from typing_extensions import Literal
|
||||
|
||||
class SimpleAdapter(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1):
|
||||
super(SimpleAdapter, self).__init__()
|
||||
|
||||
# Pixel Unshuffle: reduce spatial dimensions by a factor of 8
|
||||
self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8)
|
||||
|
||||
# Convolution: reduce spatial dimensions by a factor
|
||||
# of 2 (without overlap)
|
||||
self.conv = nn.Conv2d(in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0)
|
||||
|
||||
# Residual blocks for feature extraction
|
||||
self.residual_blocks = nn.Sequential(
|
||||
*[ResidualBlock(out_dim) for _ in range(num_residual_blocks)]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# Reshape to merge the frame dimension into batch
|
||||
bs, c, f, h, w = x.size()
|
||||
x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w)
|
||||
|
||||
# Pixel Unshuffle operation
|
||||
x_unshuffled = self.pixel_unshuffle(x)
|
||||
|
||||
# Convolution operation
|
||||
x_conv = self.conv(x_unshuffled)
|
||||
|
||||
# Feature extraction with residual blocks
|
||||
out = self.residual_blocks(x_conv)
|
||||
|
||||
# Reshape to restore original bf dimension
|
||||
out = out.view(bs, f, out.size(1), out.size(2), out.size(3))
|
||||
|
||||
# Permute dimensions to reorder (if needed), e.g., swap channels and feature frames
|
||||
out = out.permute(0, 2, 1, 3, 4)
|
||||
|
||||
return out
|
||||
|
||||
def process_camera_coordinates(
|
||||
self,
|
||||
direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"],
|
||||
length: int,
|
||||
height: int,
|
||||
width: int,
|
||||
speed: float = 1/54,
|
||||
origin=(0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)
|
||||
):
|
||||
if origin is None:
|
||||
origin = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)
|
||||
coordinates = generate_camera_coordinates(direction, length, speed, origin)
|
||||
plucker_embedding = process_pose_file(coordinates, width, height)
|
||||
return plucker_embedding
|
||||
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super(ResidualBlock, self).__init__()
|
||||
self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
out = self.relu(self.conv1(x))
|
||||
out = self.conv2(out)
|
||||
out += residual
|
||||
return out
|
||||
|
||||
class Camera(object):
|
||||
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
||||
"""
|
||||
def __init__(self, entry):
|
||||
fx, fy, cx, cy = entry[1:5]
|
||||
self.fx = fx
|
||||
self.fy = fy
|
||||
self.cx = cx
|
||||
self.cy = cy
|
||||
w2c_mat = np.array(entry[7:]).reshape(3, 4)
|
||||
w2c_mat_4x4 = np.eye(4)
|
||||
w2c_mat_4x4[:3, :] = w2c_mat
|
||||
self.w2c_mat = w2c_mat_4x4
|
||||
self.c2w_mat = np.linalg.inv(w2c_mat_4x4)
|
||||
|
||||
def get_relative_pose(cam_params):
|
||||
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
||||
"""
|
||||
abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
|
||||
abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
|
||||
cam_to_origin = 0
|
||||
target_cam_c2w = np.array([
|
||||
[1, 0, 0, 0],
|
||||
[0, 1, 0, -cam_to_origin],
|
||||
[0, 0, 1, 0],
|
||||
[0, 0, 0, 1]
|
||||
])
|
||||
abs2rel = target_cam_c2w @ abs_w2cs[0]
|
||||
ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
|
||||
ret_poses = np.array(ret_poses, dtype=np.float32)
|
||||
return ret_poses
|
||||
|
||||
def custom_meshgrid(*args):
|
||||
# torch>=2.0.0 only
|
||||
return torch.meshgrid(*args, indexing='ij')
|
||||
|
||||
|
||||
def ray_condition(K, c2w, H, W, device):
|
||||
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
||||
"""
|
||||
# c2w: B, V, 4, 4
|
||||
# K: B, V, 4
|
||||
|
||||
B = K.shape[0]
|
||||
|
||||
j, i = custom_meshgrid(
|
||||
torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
|
||||
torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
|
||||
)
|
||||
i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
|
||||
j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
|
||||
|
||||
fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
|
||||
|
||||
zs = torch.ones_like(i) # [B, HxW]
|
||||
xs = (i - cx) / fx * zs
|
||||
ys = (j - cy) / fy * zs
|
||||
zs = zs.expand_as(ys)
|
||||
|
||||
directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
|
||||
directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
|
||||
|
||||
rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW
|
||||
rays_o = c2w[..., :3, 3] # B, V, 3
|
||||
rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
|
||||
# c2w @ dirctions
|
||||
rays_dxo = torch.linalg.cross(rays_o, rays_d)
|
||||
plucker = torch.cat([rays_dxo, rays_d], dim=-1)
|
||||
plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
|
||||
# plucker = plucker.permute(0, 1, 4, 2, 3)
|
||||
return plucker
|
||||
|
||||
|
||||
def process_pose_file(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False):
|
||||
if return_poses:
|
||||
return cam_params
|
||||
else:
|
||||
cam_params = [Camera(cam_param) for cam_param in cam_params]
|
||||
|
||||
sample_wh_ratio = width / height
|
||||
pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed
|
||||
|
||||
if pose_wh_ratio > sample_wh_ratio:
|
||||
resized_ori_w = height * pose_wh_ratio
|
||||
for cam_param in cam_params:
|
||||
cam_param.fx = resized_ori_w * cam_param.fx / width
|
||||
else:
|
||||
resized_ori_h = width / pose_wh_ratio
|
||||
for cam_param in cam_params:
|
||||
cam_param.fy = resized_ori_h * cam_param.fy / height
|
||||
|
||||
intrinsic = np.asarray([[cam_param.fx * width,
|
||||
cam_param.fy * height,
|
||||
cam_param.cx * width,
|
||||
cam_param.cy * height]
|
||||
for cam_param in cam_params], dtype=np.float32)
|
||||
|
||||
K = torch.as_tensor(intrinsic)[None] # [1, 1, 4]
|
||||
c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere
|
||||
c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
|
||||
plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W
|
||||
plucker_embedding = plucker_embedding[None]
|
||||
plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
|
||||
return plucker_embedding
|
||||
|
||||
|
||||
|
||||
def generate_camera_coordinates(
|
||||
direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"],
|
||||
length: int,
|
||||
speed: float = 1/54,
|
||||
origin=(0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)
|
||||
):
|
||||
coordinates = [list(origin)]
|
||||
while len(coordinates) < length:
|
||||
coor = coordinates[-1].copy()
|
||||
if "Left" in direction:
|
||||
coor[9] += speed
|
||||
if "Right" in direction:
|
||||
coor[9] -= speed
|
||||
if "Up" in direction:
|
||||
coor[13] += speed
|
||||
if "Down" in direction:
|
||||
coor[13] -= speed
|
||||
coordinates.append(coor)
|
||||
return coordinates
|
||||
@@ -5,7 +5,6 @@ import math
|
||||
from typing import Tuple, Optional
|
||||
from einops import rearrange
|
||||
from .utils import hash_state_dict_keys
|
||||
from .wan_video_camera_controller import SimpleAdapter
|
||||
try:
|
||||
import flash_attn_interface
|
||||
FLASH_ATTN_3_AVAILABLE = True
|
||||
@@ -37,8 +36,6 @@ def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads
|
||||
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
|
||||
v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
|
||||
x = flash_attn_interface.flash_attn_func(q, k, v)
|
||||
if isinstance(x,tuple):
|
||||
x = x[0]
|
||||
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
|
||||
elif FLASH_ATTN_2_AVAILABLE:
|
||||
q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
|
||||
@@ -212,16 +209,9 @@ class DiTBlock(nn.Module):
|
||||
self.gate = GateModule()
|
||||
|
||||
def forward(self, x, context, t_mod, freqs):
|
||||
has_seq = len(t_mod.shape) == 4
|
||||
chunk_dim = 2 if has_seq else 1
|
||||
# msa: multi-head self-attention mlp: multi-layer perceptron
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=chunk_dim)
|
||||
if has_seq:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
shift_msa.squeeze(2), scale_msa.squeeze(2), gate_msa.squeeze(2),
|
||||
shift_mlp.squeeze(2), scale_mlp.squeeze(2), gate_mlp.squeeze(2),
|
||||
)
|
||||
self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
|
||||
input_x = modulate(self.norm1(x), shift_msa, scale_msa)
|
||||
x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
|
||||
x = x + self.cross_attn(self.norm3(x), context)
|
||||
@@ -231,7 +221,7 @@ class DiTBlock(nn.Module):
|
||||
|
||||
|
||||
class MLP(torch.nn.Module):
|
||||
def __init__(self, in_dim, out_dim, has_pos_emb=False):
|
||||
def __init__(self, in_dim, out_dim):
|
||||
super().__init__()
|
||||
self.proj = torch.nn.Sequential(
|
||||
nn.LayerNorm(in_dim),
|
||||
@@ -240,13 +230,8 @@ class MLP(torch.nn.Module):
|
||||
nn.Linear(in_dim, out_dim),
|
||||
nn.LayerNorm(out_dim)
|
||||
)
|
||||
self.has_pos_emb = has_pos_emb
|
||||
if has_pos_emb:
|
||||
self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280)))
|
||||
|
||||
def forward(self, x):
|
||||
if self.has_pos_emb:
|
||||
x = x + self.emb_pos.to(dtype=x.dtype, device=x.device)
|
||||
return self.proj(x)
|
||||
|
||||
|
||||
@@ -260,12 +245,8 @@ class Head(nn.Module):
|
||||
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
|
||||
|
||||
def forward(self, x, t_mod):
|
||||
if len(t_mod.shape) == 3:
|
||||
shift, scale = (self.modulation.unsqueeze(0).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(2)).chunk(2, dim=2)
|
||||
x = (self.head(self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2)))
|
||||
else:
|
||||
shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1)
|
||||
x = (self.head(self.norm(x) * (1 + scale) + shift))
|
||||
shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1)
|
||||
x = (self.head(self.norm(x) * (1 + scale) + shift))
|
||||
return x
|
||||
|
||||
|
||||
@@ -283,24 +264,12 @@ class WanModel(torch.nn.Module):
|
||||
num_heads: int,
|
||||
num_layers: int,
|
||||
has_image_input: bool,
|
||||
has_image_pos_emb: bool = False,
|
||||
has_ref_conv: bool = False,
|
||||
add_control_adapter: bool = False,
|
||||
in_dim_control_adapter: int = 24,
|
||||
seperated_timestep: bool = False,
|
||||
require_vae_embedding: bool = True,
|
||||
require_clip_embedding: bool = True,
|
||||
fuse_vae_embedding_in_latents: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.freq_dim = freq_dim
|
||||
self.has_image_input = has_image_input
|
||||
self.patch_size = patch_size
|
||||
self.seperated_timestep = seperated_timestep
|
||||
self.require_vae_embedding = require_vae_embedding
|
||||
self.require_clip_embedding = require_clip_embedding
|
||||
self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
|
||||
|
||||
self.patch_embedding = nn.Conv3d(
|
||||
in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
||||
@@ -325,22 +294,10 @@ class WanModel(torch.nn.Module):
|
||||
self.freqs = precompute_freqs_cis_3d(head_dim)
|
||||
|
||||
if has_image_input:
|
||||
self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280
|
||||
if has_ref_conv:
|
||||
self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2))
|
||||
self.has_image_pos_emb = has_image_pos_emb
|
||||
self.has_ref_conv = has_ref_conv
|
||||
if add_control_adapter:
|
||||
self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:])
|
||||
else:
|
||||
self.control_adapter = None
|
||||
self.img_emb = MLP(1280, dim) # clip_feature_dim = 1280
|
||||
|
||||
def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None):
|
||||
def patchify(self, x: torch.Tensor):
|
||||
x = self.patch_embedding(x)
|
||||
if self.control_adapter is not None and control_camera_latents_input is not None:
|
||||
y_camera = self.control_adapter(control_camera_latents_input)
|
||||
x = [u + v for u, v in zip(x, y_camera)]
|
||||
x = x[0].unsqueeze(0)
|
||||
grid_size = x.shape[2:]
|
||||
x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
|
||||
return x, grid_size # x, grid_size: (f, h, w)
|
||||
@@ -494,7 +451,6 @@ class WanModelStateDictConverter:
|
||||
return state_dict_, config
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
state_dict = {name: param for name, param in state_dict.items() if not name.startswith("vace")}
|
||||
if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
|
||||
config = {
|
||||
"has_image_input": False,
|
||||
@@ -537,182 +493,6 @@ class WanModelStateDictConverter:
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "6d6ccde6845b95ad9114ab993d917893":
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 36,
|
||||
"dim": 1536,
|
||||
"ffn_dim": 8960,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 12,
|
||||
"num_layers": 30,
|
||||
"eps": 1e-6
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 36,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
|
||||
# 1.3B PAI control
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 48,
|
||||
"dim": 1536,
|
||||
"ffn_dim": 8960,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 12,
|
||||
"num_layers": 30,
|
||||
"eps": 1e-6
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
|
||||
# 14B PAI control
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 48,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "3ef3b1f8e1dab83d5b71fd7b617f859f":
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 36,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6,
|
||||
"has_image_pos_emb": True
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "70ddad9d3a133785da5ea371aae09504":
|
||||
# 1.3B PAI control v1.1
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 48,
|
||||
"dim": 1536,
|
||||
"ffn_dim": 8960,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 12,
|
||||
"num_layers": 30,
|
||||
"eps": 1e-6,
|
||||
"has_ref_conv": True
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "26bde73488a92e64cc20b0a7485b9e5b":
|
||||
# 14B PAI control v1.1
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 48,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6,
|
||||
"has_ref_conv": True
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "ac6a5aa74f4a0aab6f64eb9a72f19901":
|
||||
# 1.3B PAI control-camera v1.1
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 32,
|
||||
"dim": 1536,
|
||||
"ffn_dim": 8960,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 12,
|
||||
"num_layers": 30,
|
||||
"eps": 1e-6,
|
||||
"has_ref_conv": False,
|
||||
"add_control_adapter": True,
|
||||
"in_dim_control_adapter": 24,
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "b61c605c2adbd23124d152ed28e049ae":
|
||||
# 14B PAI control-camera v1.1
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 32,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6,
|
||||
"has_ref_conv": False,
|
||||
"add_control_adapter": True,
|
||||
"in_dim_control_adapter": 24,
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "1f5ab7703c6fc803fdded85ff040c316":
|
||||
# Wan-AI/Wan2.2-TI2V-5B
|
||||
config = {
|
||||
"has_image_input": False,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 48,
|
||||
"dim": 3072,
|
||||
"ffn_dim": 14336,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 48,
|
||||
"num_heads": 24,
|
||||
"num_layers": 30,
|
||||
"eps": 1e-6,
|
||||
"seperated_timestep": True,
|
||||
"require_clip_embedding": False,
|
||||
"require_vae_embedding": False,
|
||||
"fuse_vae_embedding_in_latents": True,
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "5b013604280dd715f8457c6ed6d6a626":
|
||||
# Wan-AI/Wan2.2-I2V-A14B
|
||||
config = {
|
||||
"has_image_input": False,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 36,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6,
|
||||
"require_clip_embedding": False,
|
||||
}
|
||||
else:
|
||||
config = {}
|
||||
return state_dict, config
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .wan_video_dit import sinusoidal_embedding_1d
|
||||
|
||||
|
||||
|
||||
class WanMotionControllerModel(torch.nn.Module):
|
||||
def __init__(self, freq_dim=256, dim=1536):
|
||||
super().__init__()
|
||||
self.freq_dim = freq_dim
|
||||
self.linear = nn.Sequential(
|
||||
nn.Linear(freq_dim, dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, dim * 6),
|
||||
)
|
||||
|
||||
def forward(self, motion_bucket_id):
|
||||
emb = sinusoidal_embedding_1d(self.freq_dim, motion_bucket_id * 10)
|
||||
emb = self.linear(emb)
|
||||
return emb
|
||||
|
||||
def init(self):
|
||||
state_dict = self.linear[-1].state_dict()
|
||||
state_dict = {i: state_dict[i] * 0 for i in state_dict}
|
||||
self.linear[-1].load_state_dict(state_dict)
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return WanMotionControllerModelDictConverter()
|
||||
|
||||
|
||||
|
||||
class WanMotionControllerModelDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
@@ -1,113 +0,0 @@
|
||||
import torch
|
||||
from .wan_video_dit import DiTBlock
|
||||
from .utils import hash_state_dict_keys
|
||||
|
||||
class VaceWanAttentionBlock(DiTBlock):
|
||||
def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
|
||||
super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps)
|
||||
self.block_id = block_id
|
||||
if block_id == 0:
|
||||
self.before_proj = torch.nn.Linear(self.dim, self.dim)
|
||||
self.after_proj = torch.nn.Linear(self.dim, self.dim)
|
||||
|
||||
def forward(self, c, x, context, t_mod, freqs):
|
||||
if self.block_id == 0:
|
||||
c = self.before_proj(c) + x
|
||||
all_c = []
|
||||
else:
|
||||
all_c = list(torch.unbind(c))
|
||||
c = all_c.pop(-1)
|
||||
c = super().forward(c, context, t_mod, freqs)
|
||||
c_skip = self.after_proj(c)
|
||||
all_c += [c_skip, c]
|
||||
c = torch.stack(all_c)
|
||||
return c
|
||||
|
||||
|
||||
class VaceWanModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vace_layers=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28),
|
||||
vace_in_dim=96,
|
||||
patch_size=(1, 2, 2),
|
||||
has_image_input=False,
|
||||
dim=1536,
|
||||
num_heads=12,
|
||||
ffn_dim=8960,
|
||||
eps=1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
self.vace_layers = vace_layers
|
||||
self.vace_in_dim = vace_in_dim
|
||||
self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}
|
||||
|
||||
# vace blocks
|
||||
self.vace_blocks = torch.nn.ModuleList([
|
||||
VaceWanAttentionBlock(has_image_input, dim, num_heads, ffn_dim, eps, block_id=i)
|
||||
for i in self.vace_layers
|
||||
])
|
||||
|
||||
# vace patch embeddings
|
||||
self.vace_patch_embedding = torch.nn.Conv3d(vace_in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
def forward(
|
||||
self, x, vace_context, context, t_mod, freqs,
|
||||
use_gradient_checkpointing: bool = False,
|
||||
use_gradient_checkpointing_offload: bool = False,
|
||||
):
|
||||
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
||||
c = [u.flatten(2).transpose(1, 2) for u in c]
|
||||
c = torch.cat([
|
||||
torch.cat([u, u.new_zeros(1, x.shape[1] - u.size(1), u.size(2))],
|
||||
dim=1) for u in c
|
||||
])
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
for block in self.vace_blocks:
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
c = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
c, x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
elif use_gradient_checkpointing:
|
||||
c = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
c, x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
c = block(c, x, context, t_mod, freqs)
|
||||
hints = torch.unbind(c)[:-1]
|
||||
return hints
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return VaceWanModelDictConverter()
|
||||
|
||||
|
||||
class VaceWanModelDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
state_dict_ = {name: param for name, param in state_dict.items() if name.startswith("vace")}
|
||||
if hash_state_dict_keys(state_dict_) == '3b2726384e4f64837bdf216eea3f310d': # vace 14B
|
||||
config = {
|
||||
"vace_layers": (0, 5, 10, 15, 20, 25, 30, 35),
|
||||
"vace_in_dim": 96,
|
||||
"patch_size": (1, 2, 2),
|
||||
"has_image_input": False,
|
||||
"dim": 5120,
|
||||
"num_heads": 40,
|
||||
"ffn_dim": 13824,
|
||||
"eps": 1e-06,
|
||||
}
|
||||
else:
|
||||
config = {}
|
||||
return state_dict_, config
|
||||
@@ -195,75 +195,6 @@ class Resample(nn.Module):
|
||||
nn.init.zeros_(conv.bias.data)
|
||||
|
||||
|
||||
|
||||
def patchify(x, patch_size):
|
||||
if patch_size == 1:
|
||||
return x
|
||||
if x.dim() == 4:
|
||||
x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
|
||||
elif x.dim() == 5:
|
||||
x = rearrange(x,
|
||||
"b c f (h q) (w r) -> b (c r q) f h w",
|
||||
q=patch_size,
|
||||
r=patch_size)
|
||||
else:
|
||||
raise ValueError(f"Invalid input shape: {x.shape}")
|
||||
return x
|
||||
|
||||
|
||||
def unpatchify(x, patch_size):
|
||||
if patch_size == 1:
|
||||
return x
|
||||
if x.dim() == 4:
|
||||
x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
|
||||
elif x.dim() == 5:
|
||||
x = rearrange(x,
|
||||
"b (c r q) f h w -> b c f (h q) (w r)",
|
||||
q=patch_size,
|
||||
r=patch_size)
|
||||
return x
|
||||
|
||||
|
||||
class Resample38(Resample):
|
||||
|
||||
def __init__(self, dim, mode):
|
||||
assert mode in (
|
||||
"none",
|
||||
"upsample2d",
|
||||
"upsample3d",
|
||||
"downsample2d",
|
||||
"downsample3d",
|
||||
)
|
||||
super(Resample, self).__init__()
|
||||
self.dim = dim
|
||||
self.mode = mode
|
||||
|
||||
# layers
|
||||
if mode == "upsample2d":
|
||||
self.resample = nn.Sequential(
|
||||
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
|
||||
nn.Conv2d(dim, dim, 3, padding=1),
|
||||
)
|
||||
elif mode == "upsample3d":
|
||||
self.resample = nn.Sequential(
|
||||
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
|
||||
nn.Conv2d(dim, dim, 3, padding=1),
|
||||
)
|
||||
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
||||
elif mode == "downsample2d":
|
||||
self.resample = nn.Sequential(
|
||||
nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
|
||||
)
|
||||
elif mode == "downsample3d":
|
||||
self.resample = nn.Sequential(
|
||||
nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
|
||||
)
|
||||
self.time_conv = CausalConv3d(
|
||||
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
|
||||
)
|
||||
else:
|
||||
self.resample = nn.Identity()
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
|
||||
def __init__(self, in_dim, out_dim, dropout=0.0):
|
||||
@@ -342,178 +273,6 @@ class AttentionBlock(nn.Module):
|
||||
return x + identity
|
||||
|
||||
|
||||
class AvgDown3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
factor_t,
|
||||
factor_s=1,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.factor_t = factor_t
|
||||
self.factor_s = factor_s
|
||||
self.factor = self.factor_t * self.factor_s * self.factor_s
|
||||
|
||||
assert in_channels * self.factor % out_channels == 0
|
||||
self.group_size = in_channels * self.factor // out_channels
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
|
||||
pad = (0, 0, 0, 0, pad_t, 0)
|
||||
x = F.pad(x, pad)
|
||||
B, C, T, H, W = x.shape
|
||||
x = x.view(
|
||||
B,
|
||||
C,
|
||||
T // self.factor_t,
|
||||
self.factor_t,
|
||||
H // self.factor_s,
|
||||
self.factor_s,
|
||||
W // self.factor_s,
|
||||
self.factor_s,
|
||||
)
|
||||
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
|
||||
x = x.view(
|
||||
B,
|
||||
C * self.factor,
|
||||
T // self.factor_t,
|
||||
H // self.factor_s,
|
||||
W // self.factor_s,
|
||||
)
|
||||
x = x.view(
|
||||
B,
|
||||
self.out_channels,
|
||||
self.group_size,
|
||||
T // self.factor_t,
|
||||
H // self.factor_s,
|
||||
W // self.factor_s,
|
||||
)
|
||||
x = x.mean(dim=2)
|
||||
return x
|
||||
|
||||
|
||||
class DupUp3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
factor_t,
|
||||
factor_s=1,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.factor_t = factor_t
|
||||
self.factor_s = factor_s
|
||||
self.factor = self.factor_t * self.factor_s * self.factor_s
|
||||
|
||||
assert out_channels * self.factor % in_channels == 0
|
||||
self.repeats = out_channels * self.factor // in_channels
|
||||
|
||||
def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
|
||||
x = x.repeat_interleave(self.repeats, dim=1)
|
||||
x = x.view(
|
||||
x.size(0),
|
||||
self.out_channels,
|
||||
self.factor_t,
|
||||
self.factor_s,
|
||||
self.factor_s,
|
||||
x.size(2),
|
||||
x.size(3),
|
||||
x.size(4),
|
||||
)
|
||||
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
|
||||
x = x.view(
|
||||
x.size(0),
|
||||
self.out_channels,
|
||||
x.size(2) * self.factor_t,
|
||||
x.size(4) * self.factor_s,
|
||||
x.size(6) * self.factor_s,
|
||||
)
|
||||
if first_chunk:
|
||||
x = x[:, :, self.factor_t - 1 :, :, :]
|
||||
return x
|
||||
|
||||
|
||||
class Down_ResidualBlock(nn.Module):
|
||||
def __init__(
|
||||
self, in_dim, out_dim, dropout, mult, temperal_downsample=False, down_flag=False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Shortcut path with downsample
|
||||
self.avg_shortcut = AvgDown3D(
|
||||
in_dim,
|
||||
out_dim,
|
||||
factor_t=2 if temperal_downsample else 1,
|
||||
factor_s=2 if down_flag else 1,
|
||||
)
|
||||
|
||||
# Main path with residual blocks and downsample
|
||||
downsamples = []
|
||||
for _ in range(mult):
|
||||
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
||||
in_dim = out_dim
|
||||
|
||||
# Add the final downsample block
|
||||
if down_flag:
|
||||
mode = "downsample3d" if temperal_downsample else "downsample2d"
|
||||
downsamples.append(Resample38(out_dim, mode=mode))
|
||||
|
||||
self.downsamples = nn.Sequential(*downsamples)
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
x_copy = x.clone()
|
||||
for module in self.downsamples:
|
||||
x = module(x, feat_cache, feat_idx)
|
||||
|
||||
return x + self.avg_shortcut(x_copy)
|
||||
|
||||
|
||||
class Up_ResidualBlock(nn.Module):
|
||||
def __init__(
|
||||
self, in_dim, out_dim, dropout, mult, temperal_upsample=False, up_flag=False
|
||||
):
|
||||
super().__init__()
|
||||
# Shortcut path with upsample
|
||||
if up_flag:
|
||||
self.avg_shortcut = DupUp3D(
|
||||
in_dim,
|
||||
out_dim,
|
||||
factor_t=2 if temperal_upsample else 1,
|
||||
factor_s=2 if up_flag else 1,
|
||||
)
|
||||
else:
|
||||
self.avg_shortcut = None
|
||||
|
||||
# Main path with residual blocks and upsample
|
||||
upsamples = []
|
||||
for _ in range(mult):
|
||||
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
||||
in_dim = out_dim
|
||||
|
||||
# Add the final upsample block
|
||||
if up_flag:
|
||||
mode = "upsample3d" if temperal_upsample else "upsample2d"
|
||||
upsamples.append(Resample38(out_dim, mode=mode))
|
||||
|
||||
self.upsamples = nn.Sequential(*upsamples)
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
|
||||
x_main = x.clone()
|
||||
for module in self.upsamples:
|
||||
x_main = module(x_main, feat_cache, feat_idx)
|
||||
if self.avg_shortcut is not None:
|
||||
x_shortcut = self.avg_shortcut(x, first_chunk)
|
||||
return x_main + x_shortcut
|
||||
else:
|
||||
return x_main
|
||||
|
||||
|
||||
class Encoder3d(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
@@ -617,122 +376,6 @@ class Encoder3d(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class Encoder3d_38(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[False, True, True],
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_downsample = temperal_downsample
|
||||
|
||||
# dimensions
|
||||
dims = [dim * u for u in [1] + dim_mult]
|
||||
scale = 1.0
|
||||
|
||||
# init block
|
||||
self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)
|
||||
|
||||
# downsample blocks
|
||||
downsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
t_down_flag = (
|
||||
temperal_downsample[i] if i < len(temperal_downsample) else False
|
||||
)
|
||||
downsamples.append(
|
||||
Down_ResidualBlock(
|
||||
in_dim=in_dim,
|
||||
out_dim=out_dim,
|
||||
dropout=dropout,
|
||||
mult=num_res_blocks,
|
||||
temperal_downsample=t_down_flag,
|
||||
down_flag=i != len(dim_mult) - 1,
|
||||
)
|
||||
)
|
||||
scale /= 2.0
|
||||
self.downsamples = nn.Sequential(*downsamples)
|
||||
|
||||
# middle blocks
|
||||
self.middle = nn.Sequential(
|
||||
ResidualBlock(out_dim, out_dim, dropout),
|
||||
AttentionBlock(out_dim),
|
||||
ResidualBlock(out_dim, out_dim, dropout),
|
||||
)
|
||||
|
||||
# # output blocks
|
||||
self.head = nn.Sequential(
|
||||
RMS_norm(out_dim, images=False),
|
||||
nn.SiLU(),
|
||||
CausalConv3d(out_dim, z_dim, 3, padding=1),
|
||||
)
|
||||
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = torch.cat(
|
||||
[
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
|
||||
cache_x,
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
|
||||
## downsamples
|
||||
for layer in self.downsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## head
|
||||
for layer in self.head:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = torch.cat(
|
||||
[
|
||||
feat_cache[idx][:, :, -1, :, :]
|
||||
.unsqueeze(2)
|
||||
.to(cache_x.device),
|
||||
cache_x,
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Decoder3d(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
@@ -838,112 +481,10 @@ class Decoder3d(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class Decoder3d_38(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_upsample=[False, True, True],
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_upsample = temperal_upsample
|
||||
|
||||
# dimensions
|
||||
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
||||
scale = 1.0 / 2 ** (len(dim_mult) - 2)
|
||||
# init block
|
||||
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
||||
|
||||
# middle blocks
|
||||
self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout),
|
||||
AttentionBlock(dims[0]),
|
||||
ResidualBlock(dims[0], dims[0], dropout))
|
||||
|
||||
# upsample blocks
|
||||
upsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
t_up_flag = temperal_upsample[i] if i < len(temperal_upsample) else False
|
||||
upsamples.append(
|
||||
Up_ResidualBlock(in_dim=in_dim,
|
||||
out_dim=out_dim,
|
||||
dropout=dropout,
|
||||
mult=num_res_blocks + 1,
|
||||
temperal_upsample=t_up_flag,
|
||||
up_flag=i != len(dim_mult) - 1))
|
||||
self.upsamples = nn.Sequential(*upsamples)
|
||||
|
||||
# output blocks
|
||||
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(out_dim, 12, 3, padding=1))
|
||||
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = torch.cat(
|
||||
[
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
|
||||
cache_x,
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
|
||||
for layer in self.middle:
|
||||
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## upsamples
|
||||
for layer in self.upsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx, first_chunk)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## head
|
||||
for layer in self.head:
|
||||
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = torch.cat(
|
||||
[
|
||||
feat_cache[idx][:, :, -1, :, :]
|
||||
.unsqueeze(2)
|
||||
.to(cache_x.device),
|
||||
cache_x,
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
def count_conv3d(model):
|
||||
count = 0
|
||||
for m in model.modules():
|
||||
if isinstance(m, CausalConv3d):
|
||||
if check_is_instance(m, CausalConv3d):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
@@ -1075,7 +616,6 @@ class WanVideoVAE(nn.Module):
|
||||
# init model
|
||||
self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False)
|
||||
self.upsampling_factor = 8
|
||||
self.z_dim = z_dim
|
||||
|
||||
|
||||
def build_1d_mask(self, length, left_bound, right_bound, border_width):
|
||||
@@ -1171,7 +711,7 @@ class WanVideoVAE(nn.Module):
|
||||
|
||||
out_T = (T + 3) // 4
|
||||
weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
|
||||
values = torch.zeros((1, self.z_dim, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
|
||||
values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
|
||||
|
||||
for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
|
||||
hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
|
||||
@@ -1222,8 +762,8 @@ class WanVideoVAE(nn.Module):
|
||||
for video in videos:
|
||||
video = video.unsqueeze(0)
|
||||
if tiled:
|
||||
tile_size = (tile_size[0] * self.upsampling_factor, tile_size[1] * self.upsampling_factor)
|
||||
tile_stride = (tile_stride[0] * self.upsampling_factor, tile_stride[1] * self.upsampling_factor)
|
||||
tile_size = (tile_size[0] * 8, tile_size[1] * 8)
|
||||
tile_stride = (tile_stride[0] * 8, tile_stride[1] * 8)
|
||||
hidden_state = self.tiled_encode(video, device, tile_size, tile_stride)
|
||||
else:
|
||||
hidden_state = self.single_encode(video, device)
|
||||
@@ -1234,11 +774,18 @@ class WanVideoVAE(nn.Module):
|
||||
|
||||
|
||||
def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
if tiled:
|
||||
video = self.tiled_decode(hidden_states, device, tile_size, tile_stride)
|
||||
else:
|
||||
video = self.single_decode(hidden_states, device)
|
||||
return video
|
||||
hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
|
||||
videos = []
|
||||
for hidden_state in hidden_states:
|
||||
hidden_state = hidden_state.unsqueeze(0)
|
||||
if tiled:
|
||||
video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
|
||||
else:
|
||||
video = self.single_decode(hidden_state, device)
|
||||
video = video.squeeze(0)
|
||||
videos.append(video)
|
||||
videos = torch.stack(videos)
|
||||
return videos
|
||||
|
||||
|
||||
@staticmethod
|
||||
@@ -1258,119 +805,3 @@ class WanVideoVAEStateDictConverter:
|
||||
for name in state_dict:
|
||||
state_dict_['model.' + name] = state_dict[name]
|
||||
return state_dict_
|
||||
|
||||
|
||||
class VideoVAE38_(VideoVAE_):
|
||||
|
||||
def __init__(self,
|
||||
dim=160,
|
||||
z_dim=48,
|
||||
dec_dim=256,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[False, True, True],
|
||||
dropout=0.0):
|
||||
super(VideoVAE_, self).__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_downsample = temperal_downsample
|
||||
self.temperal_upsample = temperal_downsample[::-1]
|
||||
|
||||
# modules
|
||||
self.encoder = Encoder3d_38(dim, z_dim * 2, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_downsample, dropout)
|
||||
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
||||
self.decoder = Decoder3d_38(dec_dim, z_dim, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_upsample, dropout)
|
||||
|
||||
|
||||
def encode(self, x, scale):
|
||||
self.clear_cache()
|
||||
x = patchify(x, patch_size=2)
|
||||
t = x.shape[2]
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
for i in range(iter_):
|
||||
self._enc_conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.encoder(x[:, :, :1, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
else:
|
||||
out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
||||
if isinstance(scale[0], torch.Tensor):
|
||||
scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale]
|
||||
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
|
||||
1, self.z_dim, 1, 1, 1)
|
||||
else:
|
||||
scale = scale.to(dtype=mu.dtype, device=mu.device)
|
||||
mu = (mu - scale[0]) * scale[1]
|
||||
self.clear_cache()
|
||||
return mu
|
||||
|
||||
|
||||
def decode(self, z, scale):
|
||||
self.clear_cache()
|
||||
if isinstance(scale[0], torch.Tensor):
|
||||
scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
|
||||
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
|
||||
1, self.z_dim, 1, 1, 1)
|
||||
else:
|
||||
scale = scale.to(dtype=z.dtype, device=z.device)
|
||||
z = z / scale[1] + scale[0]
|
||||
iter_ = z.shape[2]
|
||||
x = self.conv2(z)
|
||||
for i in range(iter_):
|
||||
self._conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.decoder(x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx,
|
||||
first_chunk=True)
|
||||
else:
|
||||
out_ = self.decoder(x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
out = unpatchify(out, patch_size=2)
|
||||
self.clear_cache()
|
||||
return out
|
||||
|
||||
|
||||
class WanVideoVAE38(WanVideoVAE):
|
||||
|
||||
def __init__(self, z_dim=48, dim=160):
|
||||
super(WanVideoVAE, self).__init__()
|
||||
|
||||
mean = [
|
||||
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
|
||||
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
|
||||
-0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
|
||||
-0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
|
||||
-0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
|
||||
0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667
|
||||
]
|
||||
std = [
|
||||
0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
|
||||
0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
|
||||
0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
|
||||
0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
|
||||
0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
|
||||
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744
|
||||
]
|
||||
self.mean = torch.tensor(mean)
|
||||
self.std = torch.tensor(std)
|
||||
self.scale = [self.mean, 1.0 / self.std]
|
||||
|
||||
# init model
|
||||
self.model = VideoVAE38_(z_dim=z_dim, dim=dim).eval().requires_grad_(False)
|
||||
self.upsampling_factor = 16
|
||||
self.z_dim = z_dim
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder, FluxIpAdapter
|
||||
from ..models.step1x_connector import Qwen2Connector
|
||||
from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
||||
from ..prompters import FluxPrompter
|
||||
from ..schedulers import FlowMatchScheduler
|
||||
@@ -33,112 +32,105 @@ class FluxImagePipeline(BasePipeline):
|
||||
self.ipadapter: FluxIpAdapter = None
|
||||
self.ipadapter_image_encoder: SiglipVisionModel = None
|
||||
self.infinityou_processor: InfinitYou = None
|
||||
self.qwenvl = None
|
||||
self.step1x_connector: Qwen2Connector = None
|
||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder', 'qwenvl', 'step1x_connector']
|
||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder']
|
||||
|
||||
|
||||
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
||||
if self.text_encoder_1 is not None:
|
||||
dtype = next(iter(self.text_encoder_1.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.text_encoder_1,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Embedding: AutoWrappedModule,
|
||||
torch.nn.LayerNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
if self.text_encoder_2 is not None:
|
||||
dtype = next(iter(self.text_encoder_2.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.text_encoder_2,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Embedding: AutoWrappedModule,
|
||||
T5LayerNorm: AutoWrappedModule,
|
||||
T5DenseActDense: AutoWrappedModule,
|
||||
T5DenseGatedActDense: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
if self.dit is not None:
|
||||
dtype = next(iter(self.dit.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.dit,
|
||||
module_map = {
|
||||
RMSNorm: AutoWrappedModule,
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cuda",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
max_num_param=num_persistent_param_in_dit,
|
||||
overflow_module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
if self.vae_decoder is not None:
|
||||
dtype = next(iter(self.vae_decoder.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.vae_decoder,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
torch.nn.GroupNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
if self.vae_encoder is not None:
|
||||
dtype = next(iter(self.vae_encoder.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.vae_encoder,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
torch.nn.GroupNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
dtype = next(iter(self.text_encoder_1.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.text_encoder_1,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Embedding: AutoWrappedModule,
|
||||
torch.nn.LayerNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
dtype = next(iter(self.text_encoder_2.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.text_encoder_2,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Embedding: AutoWrappedModule,
|
||||
T5LayerNorm: AutoWrappedModule,
|
||||
T5DenseActDense: AutoWrappedModule,
|
||||
T5DenseGatedActDense: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
dtype = next(iter(self.dit.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.dit,
|
||||
module_map = {
|
||||
RMSNorm: AutoWrappedModule,
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cuda",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
max_num_param=num_persistent_param_in_dit,
|
||||
overflow_module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
dtype = next(iter(self.vae_decoder.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.vae_decoder,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
torch.nn.GroupNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
dtype = next(iter(self.vae_encoder.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.vae_encoder,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
torch.nn.GroupNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
self.enable_cpu_offload()
|
||||
|
||||
|
||||
@@ -175,10 +167,6 @@ class FluxImagePipeline(BasePipeline):
|
||||
self.image_proj_model = model_manager.fetch_model("infiniteyou_image_projector")
|
||||
if self.image_proj_model is not None:
|
||||
self.infinityou_processor = InfinitYou(device=self.device)
|
||||
|
||||
# Step1x
|
||||
self.qwenvl = model_manager.fetch_model("qwenvl")
|
||||
self.step1x_connector = model_manager.fetch_model("step1x_connector")
|
||||
|
||||
|
||||
@staticmethod
|
||||
@@ -203,13 +191,10 @@ class FluxImagePipeline(BasePipeline):
|
||||
|
||||
|
||||
def encode_prompt(self, prompt, positive=True, t5_sequence_length=512):
|
||||
if self.text_encoder_1 is not None and self.text_encoder_2 is not None:
|
||||
prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
|
||||
prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length
|
||||
)
|
||||
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
|
||||
else:
|
||||
return {}
|
||||
prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
|
||||
prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length
|
||||
)
|
||||
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
|
||||
|
||||
|
||||
def prepare_extra_input(self, latents=None, guidance=1.0):
|
||||
@@ -375,46 +360,6 @@ class FluxImagePipeline(BasePipeline):
|
||||
return self.infinityou_processor.prepare_infinite_you(self.image_proj_model, id_image, controlnet_image, infinityou_guidance, height, width)
|
||||
else:
|
||||
return {}, controlnet_image
|
||||
|
||||
|
||||
def prepare_flex_kwargs(self, latents, flex_inpaint_image=None, flex_inpaint_mask=None, flex_control_image=None, flex_control_strength=0.5, flex_control_stop=0.5, tiled=False, tile_size=64, tile_stride=32):
|
||||
if self.dit.input_dim == 196:
|
||||
if flex_inpaint_image is None:
|
||||
flex_inpaint_image = torch.zeros_like(latents)
|
||||
else:
|
||||
flex_inpaint_image = self.preprocess_image(flex_inpaint_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
flex_inpaint_image = self.encode_image(flex_inpaint_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
if flex_inpaint_mask is None:
|
||||
flex_inpaint_mask = torch.ones_like(latents)[:, 0:1, :, :]
|
||||
else:
|
||||
flex_inpaint_mask = flex_inpaint_mask.resize((latents.shape[3], latents.shape[2]))
|
||||
flex_inpaint_mask = self.preprocess_image(flex_inpaint_mask).to(device=self.device, dtype=self.torch_dtype)
|
||||
flex_inpaint_mask = (flex_inpaint_mask[:, 0:1, :, :] + 1) / 2
|
||||
flex_inpaint_image = flex_inpaint_image * (1 - flex_inpaint_mask)
|
||||
if flex_control_image is None:
|
||||
flex_control_image = torch.zeros_like(latents)
|
||||
else:
|
||||
flex_control_image = self.preprocess_image(flex_control_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
flex_control_image = self.encode_image(flex_control_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) * flex_control_strength
|
||||
flex_condition = torch.concat([flex_inpaint_image, flex_inpaint_mask, flex_control_image], dim=1)
|
||||
flex_uncondition = torch.concat([flex_inpaint_image, flex_inpaint_mask, torch.zeros_like(flex_control_image)], dim=1)
|
||||
flex_control_stop_timestep = self.scheduler.timesteps[int(flex_control_stop * (len(self.scheduler.timesteps) - 1))]
|
||||
flex_kwargs = {"flex_condition": flex_condition, "flex_uncondition": flex_uncondition, "flex_control_stop_timestep": flex_control_stop_timestep}
|
||||
else:
|
||||
flex_kwargs = {}
|
||||
return flex_kwargs
|
||||
|
||||
|
||||
def prepare_step1x_kwargs(self, prompt, negative_prompt, image):
|
||||
if image is None:
|
||||
return {}, {}
|
||||
self.load_models_to_device(["qwenvl", "vae_encoder"])
|
||||
captions = [prompt, negative_prompt]
|
||||
ref_images = [image, image]
|
||||
embs, masks = self.qwenvl(captions, ref_images)
|
||||
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
|
||||
image = self.encode_image(image)
|
||||
return {"step1x_llm_embedding": embs[0:1], "step1x_mask": masks[0:1], "step1x_reference_latents": image}, {"step1x_llm_embedding": embs[1:2], "step1x_mask": masks[1:2], "step1x_reference_latents": image}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -453,14 +398,6 @@ class FluxImagePipeline(BasePipeline):
|
||||
# InfiniteYou
|
||||
infinityou_id_image=None,
|
||||
infinityou_guidance=1.0,
|
||||
# Flex
|
||||
flex_inpaint_image=None,
|
||||
flex_inpaint_mask=None,
|
||||
flex_control_image=None,
|
||||
flex_control_strength=0.5,
|
||||
flex_control_stop=0.5,
|
||||
# Step1x
|
||||
step1x_reference_image=None,
|
||||
# TeaCache
|
||||
tea_cache_l1_thresh=None,
|
||||
# Tile
|
||||
@@ -499,26 +436,20 @@ class FluxImagePipeline(BasePipeline):
|
||||
|
||||
# ControlNets
|
||||
controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
|
||||
|
||||
# Flex
|
||||
flex_kwargs = self.prepare_flex_kwargs(latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image, flex_control_strength=flex_control_strength, flex_control_stop=flex_control_stop, **tiler_kwargs)
|
||||
|
||||
# Step1x
|
||||
step1x_kwargs_posi, step1x_kwargs_nega = self.prepare_step1x_kwargs(prompt, negative_prompt, image=step1x_reference_image)
|
||||
|
||||
# TeaCache
|
||||
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(['dit', 'controlnet', 'step1x_connector'])
|
||||
self.load_models_to_device(['dit', 'controlnet'])
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(self.device)
|
||||
|
||||
# Positive side
|
||||
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
|
||||
dit=self.dit, controlnet=self.controlnet, step1x_connector=self.step1x_connector,
|
||||
dit=self.dit, controlnet=self.controlnet,
|
||||
hidden_states=latents, timestep=timestep,
|
||||
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs, **flex_kwargs, **step1x_kwargs_posi,
|
||||
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs
|
||||
)
|
||||
noise_pred_posi = self.control_noise_via_local_prompts(
|
||||
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
|
||||
@@ -533,9 +464,9 @@ class FluxImagePipeline(BasePipeline):
|
||||
if cfg_scale != 1.0:
|
||||
# Negative side
|
||||
noise_pred_nega = lets_dance_flux(
|
||||
dit=self.dit, controlnet=self.controlnet, step1x_connector=self.step1x_connector,
|
||||
dit=self.dit, controlnet=self.controlnet,
|
||||
hidden_states=latents, timestep=timestep,
|
||||
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs, **flex_kwargs, **step1x_kwargs_nega,
|
||||
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs,
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
@@ -655,7 +586,6 @@ class TeaCache:
|
||||
def lets_dance_flux(
|
||||
dit: FluxDiT,
|
||||
controlnet: FluxMultiControlNetManager = None,
|
||||
step1x_connector: Qwen2Connector = None,
|
||||
hidden_states=None,
|
||||
timestep=None,
|
||||
prompt_emb=None,
|
||||
@@ -672,12 +602,6 @@ def lets_dance_flux(
|
||||
ipadapter_kwargs_list={},
|
||||
id_emb=None,
|
||||
infinityou_guidance=None,
|
||||
flex_condition=None,
|
||||
flex_uncondition=None,
|
||||
flex_control_stop_timestep=None,
|
||||
step1x_llm_embedding=None,
|
||||
step1x_mask=None,
|
||||
step1x_reference_latents=None,
|
||||
tea_cache: TeaCache = None,
|
||||
**kwargs
|
||||
):
|
||||
@@ -728,18 +652,6 @@ def lets_dance_flux(
|
||||
controlnet_res_stack, controlnet_single_res_stack = controlnet(
|
||||
controlnet_frames, **controlnet_extra_kwargs
|
||||
)
|
||||
|
||||
# Flex
|
||||
if flex_condition is not None:
|
||||
if timestep.tolist()[0] >= flex_control_stop_timestep:
|
||||
hidden_states = torch.concat([hidden_states, flex_condition], dim=1)
|
||||
else:
|
||||
hidden_states = torch.concat([hidden_states, flex_uncondition], dim=1)
|
||||
|
||||
# Step1x
|
||||
if step1x_llm_embedding is not None:
|
||||
prompt_emb, pooled_prompt_emb = step1x_connector(step1x_llm_embedding, timestep / 1000, step1x_mask)
|
||||
text_ids = torch.zeros((1, prompt_emb.shape[1], 3), dtype=prompt_emb.dtype, device=prompt_emb.device)
|
||||
|
||||
if image_ids is None:
|
||||
image_ids = dit.prepare_image_ids(hidden_states)
|
||||
@@ -751,18 +663,10 @@ def lets_dance_flux(
|
||||
|
||||
height, width = hidden_states.shape[-2:]
|
||||
hidden_states = dit.patchify(hidden_states)
|
||||
|
||||
# Step1x
|
||||
if step1x_reference_latents is not None:
|
||||
step1x_reference_image_ids = dit.prepare_image_ids(step1x_reference_latents)
|
||||
step1x_reference_latents = dit.patchify(step1x_reference_latents)
|
||||
image_ids = torch.concat([image_ids, step1x_reference_image_ids], dim=-2)
|
||||
hidden_states = torch.concat([hidden_states, step1x_reference_latents], dim=1)
|
||||
|
||||
hidden_states = dit.x_embedder(hidden_states)
|
||||
|
||||
if entity_prompt_emb is not None and entity_masks is not None:
|
||||
prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, 16)
|
||||
prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
|
||||
else:
|
||||
prompt_emb = dit.context_embedder(prompt_emb)
|
||||
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||
@@ -813,11 +717,6 @@ def lets_dance_flux(
|
||||
|
||||
hidden_states = dit.final_norm_out(hidden_states, conditioning)
|
||||
hidden_states = dit.final_proj_out(hidden_states)
|
||||
|
||||
# Step1x
|
||||
if step1x_reference_latents is not None:
|
||||
hidden_states = hidden_states[:, :hidden_states.shape[1] // 2]
|
||||
|
||||
hidden_states = dit.unpatchify(hidden_states, height, width)
|
||||
|
||||
return hidden_states
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,493 +0,0 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from typing import Union
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from einops import rearrange
|
||||
|
||||
from ..models import ModelManager, load_state_dict
|
||||
from ..models.qwen_image_accelerate_adapter import QwenImageAccelerateAdapter
|
||||
from ..models.qwen_image_dit import QwenImageDiT
|
||||
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
|
||||
from ..models.qwen_image_vae import QwenImageVAE
|
||||
from ..schedulers import FlowMatchScheduler
|
||||
from ..utils import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
|
||||
from ..lora import GeneralLoRALoader
|
||||
|
||||
from ..vram_management import gradient_checkpoint_forward, enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
||||
|
||||
|
||||
|
||||
class QwenImagePipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||
super().__init__(
|
||||
device=device, torch_dtype=torch_dtype,
|
||||
height_division_factor=16, width_division_factor=16,
|
||||
)
|
||||
from transformers import Qwen2Tokenizer
|
||||
|
||||
self.scheduler = FlowMatchScheduler(sigma_min=0, sigma_max=1, extra_one_step=True, exponential_shift=True, exponential_shift_mu=0.8, shift_terminal=0.02)
|
||||
self.accelerate_adapter: QwenImageAccelerateAdapter = None
|
||||
self.text_encoder: QwenImageTextEncoder = None
|
||||
self.dit: QwenImageDiT = None
|
||||
self.vae: QwenImageVAE = None
|
||||
self.tokenizer: Qwen2Tokenizer = None
|
||||
self.unit_runner = PipelineUnitRunner()
|
||||
self.in_iteration_models = ("accelerate_adapter", "dit",)
|
||||
self.units = [
|
||||
QwenImageUnit_ShapeChecker(),
|
||||
QwenImageUnit_NoiseInitializer(),
|
||||
QwenImageUnit_InputImageEmbedder(),
|
||||
QwenImageUnit_PromptEmbedder(),
|
||||
QwenImageUnit_EntityControl(),
|
||||
]
|
||||
self.model_fn = model_fn_qwen_image
|
||||
|
||||
|
||||
def load_lora(self, module, path, alpha=1):
|
||||
loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
|
||||
lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
|
||||
loader.load(module, lora, alpha=alpha)
|
||||
|
||||
|
||||
def training_loss(self, **inputs):
|
||||
timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
|
||||
timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep)
|
||||
training_target = self.scheduler.training_target(inputs["input_latents"], inputs["noise"], timestep)
|
||||
|
||||
noise_pred = self.model_fn(**inputs, timestep=timestep)
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
||||
loss = loss * self.scheduler.training_weight(timestep)
|
||||
return loss
|
||||
|
||||
|
||||
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5, enable_dit_fp8_computation=False):
|
||||
self.vram_management_enabled = True
|
||||
if vram_limit is None:
|
||||
vram_limit = self.get_vram()
|
||||
vram_limit = vram_limit - vram_buffer
|
||||
|
||||
if self.text_encoder is not None:
|
||||
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm
|
||||
dtype = next(iter(self.text_encoder.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.text_encoder,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Embedding: AutoWrappedModule,
|
||||
Qwen2_5_VLRotaryEmbedding: AutoWrappedModule,
|
||||
Qwen2RMSNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
if self.dit is not None:
|
||||
from ..models.qwen_image_dit import RMSNorm
|
||||
dtype = next(iter(self.dit.parameters())).dtype
|
||||
device = "cpu" if vram_limit is not None else self.device
|
||||
if not enable_dit_fp8_computation:
|
||||
enable_vram_management(
|
||||
self.dit,
|
||||
module_map = {
|
||||
RMSNorm: AutoWrappedModule,
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device=device,
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
else:
|
||||
enable_vram_management(
|
||||
self.dit,
|
||||
module_map = {
|
||||
RMSNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device=device,
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
enable_vram_management(
|
||||
self.dit,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device=device,
|
||||
computation_dtype=dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
if self.vae is not None:
|
||||
from ..models.qwen_image_vae import QwenImageRMS_norm
|
||||
dtype = next(iter(self.vae.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.vae,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv3d: AutoWrappedModule,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
QwenImageRMS_norm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
model_configs: list[ModelConfig] = [],
|
||||
tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
):
|
||||
# Download and load models
|
||||
model_manager = ModelManager()
|
||||
for model_config in model_configs:
|
||||
model_config.download_if_necessary()
|
||||
model_manager.load_model(
|
||||
model_config.path,
|
||||
device=model_config.offload_device or device,
|
||||
torch_dtype=model_config.offload_dtype or torch_dtype
|
||||
)
|
||||
|
||||
# Initialize pipeline
|
||||
pipe = QwenImagePipeline(device=device, torch_dtype=torch_dtype)
|
||||
pipe.text_encoder = model_manager.fetch_model("qwen_image_text_encoder")
|
||||
pipe.dit = model_manager.fetch_model("qwen_image_dit")
|
||||
pipe.vae = model_manager.fetch_model("qwen_image_vae")
|
||||
pipe.accelerate_adapter = model_manager.fetch_model("qwen_image_accelerate_adapter")
|
||||
if tokenizer_config is not None and pipe.text_encoder is not None:
|
||||
tokenizer_config.download_if_necessary()
|
||||
from transformers import Qwen2Tokenizer
|
||||
pipe.tokenizer = Qwen2Tokenizer.from_pretrained(tokenizer_config.path)
|
||||
return pipe
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
# Prompt
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
cfg_scale: float = 4.0,
|
||||
# Image
|
||||
input_image: Image.Image = None,
|
||||
denoising_strength: float = 1.0,
|
||||
# Shape
|
||||
height: int = 1328,
|
||||
width: int = 1328,
|
||||
# Randomness
|
||||
seed: int = None,
|
||||
rand_device: str = "cpu",
|
||||
# Steps
|
||||
num_inference_steps: int = 30,
|
||||
# EliGen
|
||||
eligen_entity_prompts: list[str] = None,
|
||||
eligen_entity_masks: list[Image.Image] = None,
|
||||
eligen_enable_on_negative: bool = False,
|
||||
# FP8
|
||||
enable_fp8_attention: bool = False,
|
||||
# Tile
|
||||
tiled: bool = False,
|
||||
tile_size: int = 128,
|
||||
tile_stride: int = 64,
|
||||
# Progress bar
|
||||
progress_bar_cmd = tqdm,
|
||||
):
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=(height // 16) * (width // 16))
|
||||
|
||||
# Parameters
|
||||
inputs_posi = {
|
||||
"prompt": prompt,
|
||||
}
|
||||
inputs_nega = {
|
||||
"negative_prompt": negative_prompt,
|
||||
}
|
||||
inputs_shared = {
|
||||
"cfg_scale": cfg_scale,
|
||||
"input_image": input_image, "denoising_strength": denoising_strength,
|
||||
"height": height, "width": width,
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"enable_fp8_attention": enable_fp8_attention,
|
||||
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
|
||||
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
# Inference
|
||||
noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep, progress_id=progress_id)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep, progress_id=progress_id)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
|
||||
# Scheduler
|
||||
inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"])
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
image = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
image = self.vae_output_to_image(image)
|
||||
self.load_models_to_device([])
|
||||
|
||||
return image
|
||||
|
||||
|
||||
|
||||
class QwenImageUnit_ShapeChecker(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(input_params=("height", "width"))
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, height, width):
|
||||
height, width = pipe.check_resize_height_width(height, width)
|
||||
return {"height": height, "width": width}
|
||||
|
||||
|
||||
|
||||
class QwenImageUnit_NoiseInitializer(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(input_params=("height", "width", "seed", "rand_device"))
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, height, width, seed, rand_device):
|
||||
noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
|
||||
return {"noise": noise}
|
||||
|
||||
|
||||
|
||||
class QwenImageUnit_InputImageEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, input_image, noise, tiled, tile_size, tile_stride):
|
||||
if input_image is None:
|
||||
return {"latents": noise, "input_latents": None}
|
||||
pipe.load_models_to_device(['vae'])
|
||||
image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
input_latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
if pipe.scheduler.training:
|
||||
return {"latents": noise, "input_latents": input_latents}
|
||||
else:
|
||||
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
|
||||
return {"latents": latents, "input_latents": None}
|
||||
|
||||
|
||||
|
||||
class QwenImageUnit_PromptEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params_posi={"prompt": "prompt"},
|
||||
input_params_nega={"prompt": "negative_prompt"},
|
||||
onload_model_names=("text_encoder",)
|
||||
)
|
||||
|
||||
def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
|
||||
bool_mask = mask.bool()
|
||||
valid_lengths = bool_mask.sum(dim=1)
|
||||
selected = hidden_states[bool_mask]
|
||||
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
|
||||
return split_result
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, prompt) -> dict:
|
||||
if pipe.text_encoder is not None:
|
||||
prompt = [prompt]
|
||||
template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
drop_idx = 34
|
||||
txt = [template.format(e) for e in prompt]
|
||||
txt_tokens = pipe.tokenizer(txt, max_length=1024+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device)
|
||||
hidden_states = pipe.text_encoder(input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, output_hidden_states=True,)[-1]
|
||||
|
||||
split_hidden_states = self.extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
|
||||
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
|
||||
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
|
||||
max_seq_len = max([e.size(0) for e in split_hidden_states])
|
||||
prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states])
|
||||
encoder_attention_mask = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list])
|
||||
prompt_embeds = prompt_embeds.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
return {"prompt_emb": prompt_embeds, "prompt_emb_mask": encoder_attention_mask}
|
||||
else:
|
||||
return {}
|
||||
|
||||
|
||||
class QwenImageUnit_EntityControl(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
take_over=True,
|
||||
onload_model_names=("text_encoder")
|
||||
)
|
||||
|
||||
def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
|
||||
bool_mask = mask.bool()
|
||||
valid_lengths = bool_mask.sum(dim=1)
|
||||
selected = hidden_states[bool_mask]
|
||||
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
|
||||
return split_result
|
||||
|
||||
def get_prompt_emb(self, pipe: QwenImagePipeline, prompt) -> dict:
|
||||
if pipe.text_encoder is not None:
|
||||
prompt = [prompt]
|
||||
template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
drop_idx = 34
|
||||
txt = [template.format(e) for e in prompt]
|
||||
txt_tokens = pipe.tokenizer(txt, max_length=1024+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device)
|
||||
hidden_states = pipe.text_encoder(input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, output_hidden_states=True,)[-1]
|
||||
|
||||
split_hidden_states = self.extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
|
||||
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
|
||||
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
|
||||
max_seq_len = max([e.size(0) for e in split_hidden_states])
|
||||
prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states])
|
||||
encoder_attention_mask = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list])
|
||||
prompt_embeds = prompt_embeds.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
return {"prompt_emb": prompt_embeds, "prompt_emb_mask": encoder_attention_mask}
|
||||
else:
|
||||
return {}
|
||||
|
||||
def preprocess_masks(self, pipe, masks, height, width, dim):
|
||||
out_masks = []
|
||||
for mask in masks:
|
||||
mask = pipe.preprocess_image(mask.resize((width, height), resample=Image.NEAREST)).mean(dim=1, keepdim=True) > 0
|
||||
mask = mask.repeat(1, dim, 1, 1).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
out_masks.append(mask)
|
||||
return out_masks
|
||||
|
||||
def prepare_entity_inputs(self, pipe, entity_prompts, entity_masks, width, height):
|
||||
entity_masks = self.preprocess_masks(pipe, entity_masks, height//8, width//8, 1)
|
||||
entity_masks = torch.cat(entity_masks, dim=0).unsqueeze(0) # b, n_mask, c, h, w
|
||||
prompt_embs, prompt_emb_masks = [], []
|
||||
for entity_prompt in entity_prompts:
|
||||
prompt_emb_dict = self.get_prompt_emb(pipe, entity_prompt)
|
||||
prompt_embs.append(prompt_emb_dict['prompt_emb'])
|
||||
prompt_emb_masks.append(prompt_emb_dict['prompt_emb_mask'])
|
||||
return prompt_embs, prompt_emb_masks, entity_masks
|
||||
|
||||
def prepare_eligen(self, pipe, prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, enable_eligen_on_negative, cfg_scale):
|
||||
entity_prompt_emb_posi, entity_prompt_emb_posi_mask, entity_masks_posi = self.prepare_entity_inputs(pipe, eligen_entity_prompts, eligen_entity_masks, width, height)
|
||||
if enable_eligen_on_negative and cfg_scale != 1.0:
|
||||
entity_prompt_emb_nega = [prompt_emb_nega['prompt_emb']] * len(entity_prompt_emb_posi)
|
||||
entity_prompt_emb_nega_mask = [prompt_emb_nega['prompt_emb_mask']] * len(entity_prompt_emb_posi)
|
||||
entity_masks_nega = entity_masks_posi
|
||||
else:
|
||||
entity_prompt_emb_nega, entity_prompt_emb_nega_mask, entity_masks_nega = None, None, None
|
||||
eligen_kwargs_posi = {"entity_prompt_emb": entity_prompt_emb_posi, "entity_masks": entity_masks_posi, "entity_prompt_emb_mask": entity_prompt_emb_posi_mask}
|
||||
eligen_kwargs_nega = {"entity_prompt_emb": entity_prompt_emb_nega, "entity_masks": entity_masks_nega, "entity_prompt_emb_mask": entity_prompt_emb_nega_mask}
|
||||
return eligen_kwargs_posi, eligen_kwargs_nega
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, inputs_shared, inputs_posi, inputs_nega):
|
||||
eligen_entity_prompts, eligen_entity_masks = inputs_shared.get("eligen_entity_prompts", None), inputs_shared.get("eligen_entity_masks", None)
|
||||
if eligen_entity_prompts is None or eligen_entity_masks is None or len(eligen_entity_prompts) == 0 or len(eligen_entity_masks) == 0:
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
eligen_enable_on_negative = inputs_shared.get("eligen_enable_on_negative", False)
|
||||
eligen_kwargs_posi, eligen_kwargs_nega = self.prepare_eligen(pipe, inputs_nega,
|
||||
eligen_entity_prompts, eligen_entity_masks, inputs_shared["width"], inputs_shared["height"],
|
||||
eligen_enable_on_negative, inputs_shared["cfg_scale"])
|
||||
inputs_posi.update(eligen_kwargs_posi)
|
||||
if inputs_shared.get("cfg_scale", 1.0) != 1.0:
|
||||
inputs_nega.update(eligen_kwargs_nega)
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
def model_fn_qwen_image(
|
||||
dit: QwenImageDiT = None,
|
||||
accelerate_adapter: QwenImageAccelerateAdapter = None,
|
||||
latents=None,
|
||||
timestep=None,
|
||||
prompt_emb=None,
|
||||
prompt_emb_mask=None,
|
||||
height=None,
|
||||
width=None,
|
||||
entity_prompt_emb=None,
|
||||
entity_prompt_emb_mask=None,
|
||||
entity_masks=None,
|
||||
enable_fp8_attention=False,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
**kwargs
|
||||
):
|
||||
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
|
||||
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
|
||||
timestep = timestep / 1000
|
||||
|
||||
image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||
|
||||
image = dit.img_in(image)
|
||||
conditioning = dit.time_text_embed(timestep, image.dtype)
|
||||
|
||||
if entity_prompt_emb is not None:
|
||||
text, image_rotary_emb, attention_mask = dit.process_entity_masks(
|
||||
latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, entity_prompt_emb_mask,
|
||||
entity_masks, height, width, image, img_shapes,
|
||||
)
|
||||
else:
|
||||
text = dit.txt_in(dit.txt_norm(prompt_emb))
|
||||
image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
|
||||
attention_mask = None
|
||||
|
||||
for block in dit.transformer_blocks:
|
||||
text, image = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
image=image,
|
||||
text=text,
|
||||
temb=conditioning,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_mask=attention_mask,
|
||||
enable_fp8_attention=enable_fp8_attention,
|
||||
)
|
||||
|
||||
if accelerate_adapter is not None:
|
||||
image = accelerate_adapter(latents, image, text, image_rotary_emb, timestep)
|
||||
else:
|
||||
image = dit.norm_out(image, conditioning)
|
||||
image = dit.proj_out(image)
|
||||
|
||||
latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||
return latents
|
||||
@@ -4,7 +4,6 @@ from ..models.wan_video_dit import WanModel
|
||||
from ..models.wan_video_text_encoder import WanTextEncoder
|
||||
from ..models.wan_video_vae import WanVideoVAE
|
||||
from ..models.wan_video_image_encoder import WanImageEncoder
|
||||
from ..models.wan_video_vace import VaceWanModel
|
||||
from ..schedulers.flow_match import FlowMatchScheduler
|
||||
from .base import BasePipeline
|
||||
from ..prompters import WanPrompter
|
||||
@@ -19,7 +18,6 @@ from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWra
|
||||
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
|
||||
from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
|
||||
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
|
||||
from ..models.wan_video_motion_controller import WanMotionControllerModel
|
||||
|
||||
|
||||
|
||||
@@ -33,9 +31,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
self.image_encoder: WanImageEncoder = None
|
||||
self.dit: WanModel = None
|
||||
self.vae: WanVideoVAE = None
|
||||
self.motion_controller: WanMotionControllerModel = None
|
||||
self.vace: VaceWanModel = None
|
||||
self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'motion_controller', 'vace']
|
||||
self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder']
|
||||
self.height_division_factor = 16
|
||||
self.width_division_factor = 16
|
||||
self.use_unified_sequence_parallel = False
|
||||
@@ -68,7 +64,6 @@ class WanVideoPipeline(BasePipeline):
|
||||
torch.nn.Conv3d: AutoWrappedModule,
|
||||
torch.nn.LayerNorm: AutoWrappedModule,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
@@ -127,40 +122,6 @@ class WanVideoPipeline(BasePipeline):
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
if self.motion_controller is not None:
|
||||
dtype = next(iter(self.motion_controller.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.motion_controller,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
if self.vace is not None:
|
||||
enable_vram_management(
|
||||
self.vace,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv3d: AutoWrappedModule,
|
||||
torch.nn.LayerNorm: AutoWrappedModule,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device=self.device,
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
self.enable_cpu_offload()
|
||||
|
||||
|
||||
@@ -173,8 +134,6 @@ class WanVideoPipeline(BasePipeline):
|
||||
self.dit = model_manager.fetch_model("wan_video_dit")
|
||||
self.vae = model_manager.fetch_model("wan_video_vae")
|
||||
self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
|
||||
self.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
|
||||
self.vace = model_manager.fetch_model("wan_video_vace")
|
||||
|
||||
|
||||
@staticmethod
|
||||
@@ -204,62 +163,22 @@ class WanVideoPipeline(BasePipeline):
|
||||
return {"context": prompt_emb}
|
||||
|
||||
|
||||
def encode_image(self, image, end_image, num_frames, height, width, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
def encode_image(self, image, num_frames, height, width):
|
||||
image = self.preprocess_image(image.resize((width, height))).to(self.device)
|
||||
clip_context = self.image_encoder.encode_image([image])
|
||||
msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
|
||||
msk[:, 1:] = 0
|
||||
if end_image is not None:
|
||||
end_image = self.preprocess_image(end_image.resize((width, height))).to(self.device)
|
||||
vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
|
||||
if self.dit.has_image_pos_emb:
|
||||
clip_context = torch.concat([clip_context, self.image_encoder.encode_image([end_image])], dim=1)
|
||||
msk[:, -1:] = 1
|
||||
else:
|
||||
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
|
||||
|
||||
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
|
||||
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
|
||||
msk = msk.transpose(1, 2)[0]
|
||||
|
||||
y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||
y = y.to(dtype=self.torch_dtype, device=self.device)
|
||||
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
|
||||
y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0]
|
||||
y = torch.concat([msk, y])
|
||||
y = y.unsqueeze(0)
|
||||
clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device)
|
||||
y = y.to(dtype=self.torch_dtype, device=self.device)
|
||||
return {"clip_feature": clip_context, "y": y}
|
||||
|
||||
|
||||
def encode_control_video(self, control_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
control_video = self.preprocess_images(control_video)
|
||||
control_video = torch.stack(control_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
|
||||
latents = self.encode_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
|
||||
return latents
|
||||
|
||||
|
||||
def prepare_reference_image(self, reference_image, height, width):
|
||||
if reference_image is not None:
|
||||
self.load_models_to_device(["vae"])
|
||||
reference_image = reference_image.resize((width, height))
|
||||
reference_image = self.preprocess_images([reference_image])
|
||||
reference_image = torch.stack(reference_image, dim=2).to(dtype=self.torch_dtype, device=self.device)
|
||||
reference_latents = self.vae.encode(reference_image, device=self.device)
|
||||
return {"reference_latents": reference_latents}
|
||||
else:
|
||||
return {}
|
||||
|
||||
|
||||
def prepare_controlnet_kwargs(self, control_video, num_frames, height, width, clip_feature=None, y=None, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
if control_video is not None:
|
||||
control_latents = self.encode_control_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
if clip_feature is None or y is None:
|
||||
clip_feature = torch.zeros((1, 257, 1280), dtype=self.torch_dtype, device=self.device)
|
||||
y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=self.torch_dtype, device=self.device)
|
||||
else:
|
||||
y = y[:, -16:]
|
||||
y = torch.concat([control_latents, y], dim=1)
|
||||
return {"clip_feature": clip_feature, "y": y}
|
||||
|
||||
|
||||
def tensor2video(self, frames):
|
||||
@@ -285,62 +204,6 @@ class WanVideoPipeline(BasePipeline):
|
||||
|
||||
def prepare_unified_sequence_parallel(self):
|
||||
return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
|
||||
|
||||
|
||||
def prepare_motion_bucket_id(self, motion_bucket_id):
|
||||
motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device)
|
||||
return {"motion_bucket_id": motion_bucket_id}
|
||||
|
||||
|
||||
def prepare_vace_kwargs(
|
||||
self,
|
||||
latents,
|
||||
vace_video=None, vace_mask=None, vace_reference_image=None, vace_scale=1.0,
|
||||
height=480, width=832, num_frames=81,
|
||||
seed=None, rand_device="cpu",
|
||||
tiled=True, tile_size=(34, 34), tile_stride=(18, 16)
|
||||
):
|
||||
if vace_video is not None or vace_mask is not None or vace_reference_image is not None:
|
||||
self.load_models_to_device(["vae"])
|
||||
if vace_video is None:
|
||||
vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=self.torch_dtype, device=self.device)
|
||||
else:
|
||||
vace_video = self.preprocess_images(vace_video)
|
||||
vace_video = torch.stack(vace_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
if vace_mask is None:
|
||||
vace_mask = torch.ones_like(vace_video)
|
||||
else:
|
||||
vace_mask = self.preprocess_images(vace_mask)
|
||||
vace_mask = torch.stack(vace_mask, dim=2).to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
inactive = vace_video * (1 - vace_mask) + 0 * vace_mask
|
||||
reactive = vace_video * vace_mask + 0 * (1 - vace_mask)
|
||||
inactive = self.encode_video(inactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
|
||||
reactive = self.encode_video(reactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
|
||||
vace_video_latents = torch.concat((inactive, reactive), dim=1)
|
||||
|
||||
vace_mask_latents = rearrange(vace_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8)
|
||||
vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact')
|
||||
|
||||
if vace_reference_image is None:
|
||||
pass
|
||||
else:
|
||||
vace_reference_image = self.preprocess_images([vace_reference_image])
|
||||
vace_reference_image = torch.stack(vace_reference_image, dim=2).to(dtype=self.torch_dtype, device=self.device)
|
||||
vace_reference_latents = self.encode_video(vace_reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
|
||||
vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1)
|
||||
vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2)
|
||||
vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2)
|
||||
|
||||
noise = self.generate_noise((1, 16, 1, latents.shape[3], latents.shape[4]), seed=seed, device=rand_device, dtype=torch.float32)
|
||||
noise = noise.to(dtype=self.torch_dtype, device=self.device)
|
||||
latents = torch.concat((noise, latents), dim=2)
|
||||
|
||||
vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1)
|
||||
return latents, {"vace_context": vace_context, "vace_scale": vace_scale}
|
||||
else:
|
||||
return latents, {"vace_context": None, "vace_scale": vace_scale}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -349,14 +212,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
input_image=None,
|
||||
end_image=None,
|
||||
input_video=None,
|
||||
control_video=None,
|
||||
reference_image=None,
|
||||
vace_video=None,
|
||||
vace_video_mask=None,
|
||||
vace_reference_image=None,
|
||||
vace_scale=1.0,
|
||||
denoising_strength=1.0,
|
||||
seed=None,
|
||||
rand_device="cpu",
|
||||
@@ -366,7 +222,6 @@ class WanVideoPipeline(BasePipeline):
|
||||
cfg_scale=5.0,
|
||||
num_inference_steps=50,
|
||||
sigma_shift=5.0,
|
||||
motion_bucket_id=None,
|
||||
tiled=True,
|
||||
tile_size=(30, 52),
|
||||
tile_stride=(15, 26),
|
||||
@@ -379,7 +234,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
height, width = self.check_resize_height_width(height, width)
|
||||
if num_frames % 4 != 1:
|
||||
num_frames = (num_frames + 2) // 4 * 4 + 1
|
||||
print(f"Only `num_frames % 4 == 1` is acceptable. We round it up to {num_frames}.")
|
||||
print(f"Only `num_frames % 4 != 1` is acceptable. We round it up to {num_frames}.")
|
||||
|
||||
# Tiler parameters
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
@@ -408,33 +263,13 @@ class WanVideoPipeline(BasePipeline):
|
||||
# Encode image
|
||||
if input_image is not None and self.image_encoder is not None:
|
||||
self.load_models_to_device(["image_encoder", "vae"])
|
||||
image_emb = self.encode_image(input_image, end_image, num_frames, height, width, **tiler_kwargs)
|
||||
image_emb = self.encode_image(input_image, num_frames, height, width)
|
||||
else:
|
||||
image_emb = {}
|
||||
|
||||
# Reference image
|
||||
reference_image_kwargs = self.prepare_reference_image(reference_image, height, width)
|
||||
|
||||
# ControlNet
|
||||
if control_video is not None:
|
||||
self.load_models_to_device(["image_encoder", "vae"])
|
||||
image_emb = self.prepare_controlnet_kwargs(control_video, num_frames, height, width, **image_emb, **tiler_kwargs)
|
||||
|
||||
# Motion Controller
|
||||
if self.motion_controller is not None and motion_bucket_id is not None:
|
||||
motion_kwargs = self.prepare_motion_bucket_id(motion_bucket_id)
|
||||
else:
|
||||
motion_kwargs = {}
|
||||
|
||||
# Extra input
|
||||
extra_input = self.prepare_extra_input(latents)
|
||||
|
||||
# VACE
|
||||
latents, vace_kwargs = self.prepare_vace_kwargs(
|
||||
latents, vace_video, vace_video_mask, vace_reference_image, vace_scale,
|
||||
height=height, width=width, num_frames=num_frames, seed=seed, rand_device=rand_device, **tiler_kwargs
|
||||
)
|
||||
|
||||
# TeaCache
|
||||
tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
|
||||
tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
|
||||
@@ -443,33 +278,20 @@ class WanVideoPipeline(BasePipeline):
|
||||
usp_kwargs = self.prepare_unified_sequence_parallel()
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(["dit", "motion_controller", "vace"])
|
||||
self.load_models_to_device(["dit"])
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
# Inference
|
||||
noise_pred_posi = model_fn_wan_video(
|
||||
self.dit, motion_controller=self.motion_controller, vace=self.vace,
|
||||
x=latents, timestep=timestep,
|
||||
**prompt_emb_posi, **image_emb, **extra_input,
|
||||
**tea_cache_posi, **usp_kwargs, **motion_kwargs, **vace_kwargs, **reference_image_kwargs,
|
||||
)
|
||||
noise_pred_posi = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **tea_cache_posi, **usp_kwargs)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = model_fn_wan_video(
|
||||
self.dit, motion_controller=self.motion_controller, vace=self.vace,
|
||||
x=latents, timestep=timestep,
|
||||
**prompt_emb_nega, **image_emb, **extra_input,
|
||||
**tea_cache_nega, **usp_kwargs, **motion_kwargs, **vace_kwargs, **reference_image_kwargs,
|
||||
)
|
||||
noise_pred_nega = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, **tea_cache_nega, **usp_kwargs)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
|
||||
# Scheduler
|
||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||
|
||||
if vace_reference_image is not None:
|
||||
latents = latents[:, :, 1:]
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
@@ -536,19 +358,13 @@ class TeaCache:
|
||||
|
||||
def model_fn_wan_video(
|
||||
dit: WanModel,
|
||||
motion_controller: WanMotionControllerModel = None,
|
||||
vace: VaceWanModel = None,
|
||||
x: torch.Tensor = None,
|
||||
timestep: torch.Tensor = None,
|
||||
context: torch.Tensor = None,
|
||||
x: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
clip_feature: Optional[torch.Tensor] = None,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
reference_latents = None,
|
||||
vace_context = None,
|
||||
vace_scale = 1.0,
|
||||
tea_cache: TeaCache = None,
|
||||
use_unified_sequence_parallel: bool = False,
|
||||
motion_bucket_id: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if use_unified_sequence_parallel:
|
||||
@@ -559,8 +375,6 @@ def model_fn_wan_video(
|
||||
|
||||
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
|
||||
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
|
||||
if motion_bucket_id is not None and motion_controller is not None:
|
||||
t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
|
||||
context = dit.text_embedding(context)
|
||||
|
||||
if dit.has_image_input:
|
||||
@@ -570,12 +384,6 @@ def model_fn_wan_video(
|
||||
|
||||
x, (f, h, w) = dit.patchify(x)
|
||||
|
||||
# Reference image
|
||||
if reference_latents is not None:
|
||||
reference_latents = dit.ref_conv(reference_latents[:, :, 0]).flatten(2).transpose(1, 2)
|
||||
x = torch.concat([reference_latents, x], dim=1)
|
||||
f += 1
|
||||
|
||||
freqs = torch.cat([
|
||||
dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
||||
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||
@@ -587,40 +395,22 @@ def model_fn_wan_video(
|
||||
tea_cache_update = tea_cache.check(dit, x, t_mod)
|
||||
else:
|
||||
tea_cache_update = False
|
||||
|
||||
if vace_context is not None:
|
||||
vace_hints = vace(x, vace_context, context, t_mod, freqs)
|
||||
|
||||
# blocks
|
||||
if use_unified_sequence_parallel:
|
||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||
chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
|
||||
pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
|
||||
chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks]
|
||||
x = chunks[get_sequence_parallel_rank()]
|
||||
|
||||
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
||||
if tea_cache_update:
|
||||
x = tea_cache.update(x)
|
||||
else:
|
||||
for block_id, block in enumerate(dit.blocks):
|
||||
for block in dit.blocks:
|
||||
x = block(x, context, t_mod, freqs)
|
||||
if vace_context is not None and block_id in vace.vace_layers_mapping:
|
||||
current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
|
||||
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
|
||||
current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
||||
current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0)
|
||||
x = x + current_vace_hint * vace_scale
|
||||
if tea_cache is not None:
|
||||
tea_cache.store(x)
|
||||
|
||||
|
||||
x = dit.head(x, t)
|
||||
if use_unified_sequence_parallel:
|
||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||
x = get_sp_group().all_gather(x, dim=1)
|
||||
x = x[:, :-pad_shape] if pad_shape > 0 else x
|
||||
# Remove reference latents
|
||||
if reference_latents is not None:
|
||||
x = x[:, reference_latents.shape[1]:]
|
||||
f -= 1
|
||||
x = dit.unpatchify(x, (f, h, w))
|
||||
return x
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,23 +1,10 @@
|
||||
import torch, math
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
class FlowMatchScheduler():
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_inference_steps=100,
|
||||
num_train_timesteps=1000,
|
||||
shift=3.0,
|
||||
sigma_max=1.0,
|
||||
sigma_min=0.003/1.002,
|
||||
inverse_timesteps=False,
|
||||
extra_one_step=False,
|
||||
reverse_sigmas=False,
|
||||
exponential_shift=False,
|
||||
exponential_shift_mu=None,
|
||||
shift_terminal=None,
|
||||
):
|
||||
def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False, extra_one_step=False, reverse_sigmas=False):
|
||||
self.num_train_timesteps = num_train_timesteps
|
||||
self.shift = shift
|
||||
self.sigma_max = sigma_max
|
||||
@@ -25,33 +12,20 @@ class FlowMatchScheduler():
|
||||
self.inverse_timesteps = inverse_timesteps
|
||||
self.extra_one_step = extra_one_step
|
||||
self.reverse_sigmas = reverse_sigmas
|
||||
self.exponential_shift = exponential_shift
|
||||
self.exponential_shift_mu = exponential_shift_mu
|
||||
self.shift_terminal = shift_terminal
|
||||
self.set_timesteps(num_inference_steps)
|
||||
|
||||
|
||||
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None, dynamic_shift_len=None, random_sigmas=False):
|
||||
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None):
|
||||
if shift is not None:
|
||||
self.shift = shift
|
||||
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
|
||||
if random_sigmas:
|
||||
self.sigmas = torch.Tensor(sorted([torch.rand((1,)).item() * (sigma_start - self.sigma_min) for i in range(num_inference_steps - 1)] + [sigma_start], reverse=True))
|
||||
elif self.extra_one_step:
|
||||
if self.extra_one_step:
|
||||
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
|
||||
else:
|
||||
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
|
||||
if self.inverse_timesteps:
|
||||
self.sigmas = torch.flip(self.sigmas, dims=[0])
|
||||
if self.exponential_shift:
|
||||
mu = self.calculate_shift(dynamic_shift_len) if dynamic_shift_len is not None else self.exponential_shift_mu
|
||||
self.sigmas = math.exp(mu) / (math.exp(mu) + (1 / self.sigmas - 1))
|
||||
else:
|
||||
self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
|
||||
if self.shift_terminal is not None:
|
||||
one_minus_z = 1 - self.sigmas
|
||||
scale_factor = one_minus_z[-1] / (1 - self.shift_terminal)
|
||||
self.sigmas = 1 - (one_minus_z / scale_factor)
|
||||
self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
|
||||
if self.reverse_sigmas:
|
||||
self.sigmas = 1 - self.sigmas
|
||||
self.timesteps = self.sigmas * self.num_train_timesteps
|
||||
@@ -61,9 +35,6 @@ class FlowMatchScheduler():
|
||||
y_shifted = y - y.min()
|
||||
bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
|
||||
self.linear_timesteps_weights = bsmntw_weighing
|
||||
self.training = True
|
||||
else:
|
||||
self.training = False
|
||||
|
||||
|
||||
def step(self, model_output, timestep, sample, to_final=False, **kwargs):
|
||||
@@ -106,17 +77,3 @@ class FlowMatchScheduler():
|
||||
timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
|
||||
weights = self.linear_timesteps_weights[timestep_id]
|
||||
return weights
|
||||
|
||||
|
||||
def calculate_shift(
|
||||
self,
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 8192,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 0.9,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
|
||||
|
||||
@@ -1,541 +0,0 @@
|
||||
import imageio, os, torch, warnings, torchvision, argparse, json
|
||||
from peft import LoraConfig, inject_adapter_in_model
|
||||
from PIL import Image
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
|
||||
|
||||
|
||||
class ImageDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
base_path=None, metadata_path=None,
|
||||
max_pixels=1920*1080, height=None, width=None,
|
||||
height_division_factor=16, width_division_factor=16,
|
||||
data_file_keys=("image",),
|
||||
image_file_extension=("jpg", "jpeg", "png", "webp"),
|
||||
repeat=1,
|
||||
args=None,
|
||||
):
|
||||
if args is not None:
|
||||
base_path = args.dataset_base_path
|
||||
metadata_path = args.dataset_metadata_path
|
||||
height = args.height
|
||||
width = args.width
|
||||
max_pixels = args.max_pixels
|
||||
data_file_keys = args.data_file_keys.split(",")
|
||||
repeat = args.dataset_repeat
|
||||
|
||||
self.base_path = base_path
|
||||
self.max_pixels = max_pixels
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.height_division_factor = height_division_factor
|
||||
self.width_division_factor = width_division_factor
|
||||
self.data_file_keys = data_file_keys
|
||||
self.image_file_extension = image_file_extension
|
||||
self.repeat = repeat
|
||||
|
||||
if height is not None and width is not None:
|
||||
print("Height and width are fixed. Setting `dynamic_resolution` to False.")
|
||||
self.dynamic_resolution = False
|
||||
elif height is None and width is None:
|
||||
print("Height and width are none. Setting `dynamic_resolution` to True.")
|
||||
self.dynamic_resolution = True
|
||||
|
||||
if metadata_path is None:
|
||||
print("No metadata. Trying to generate it.")
|
||||
metadata = self.generate_metadata(base_path)
|
||||
print(f"{len(metadata)} lines in metadata.")
|
||||
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
|
||||
elif metadata_path.endswith(".json"):
|
||||
with open(metadata_path, "r") as f:
|
||||
metadata = json.load(f)
|
||||
self.data = metadata
|
||||
elif metadata_path.endswith(".jsonl"):
|
||||
metadata = []
|
||||
with open(metadata_path, 'r') as f:
|
||||
for line in tqdm(f):
|
||||
metadata.append(json.loads(line.strip()))
|
||||
self.data = metadata
|
||||
else:
|
||||
metadata = pd.read_csv(metadata_path)
|
||||
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
|
||||
|
||||
|
||||
def generate_metadata(self, folder):
|
||||
image_list, prompt_list = [], []
|
||||
file_set = set(os.listdir(folder))
|
||||
for file_name in file_set:
|
||||
if "." not in file_name:
|
||||
continue
|
||||
file_ext_name = file_name.split(".")[-1].lower()
|
||||
file_base_name = file_name[:-len(file_ext_name)-1]
|
||||
if file_ext_name not in self.image_file_extension:
|
||||
continue
|
||||
prompt_file_name = file_base_name + ".txt"
|
||||
if prompt_file_name not in file_set:
|
||||
continue
|
||||
with open(os.path.join(folder, prompt_file_name), "r", encoding="utf-8") as f:
|
||||
prompt = f.read().strip()
|
||||
image_list.append(file_name)
|
||||
prompt_list.append(prompt)
|
||||
metadata = pd.DataFrame()
|
||||
metadata["image"] = image_list
|
||||
metadata["prompt"] = prompt_list
|
||||
return metadata
|
||||
|
||||
|
||||
def crop_and_resize(self, image, target_height, target_width):
|
||||
width, height = image.size
|
||||
scale = max(target_width / width, target_height / height)
|
||||
image = torchvision.transforms.functional.resize(
|
||||
image,
|
||||
(round(height*scale), round(width*scale)),
|
||||
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
|
||||
)
|
||||
image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
|
||||
return image
|
||||
|
||||
|
||||
def get_height_width(self, image):
|
||||
if self.dynamic_resolution:
|
||||
width, height = image.size
|
||||
if width * height > self.max_pixels:
|
||||
scale = (width * height / self.max_pixels) ** 0.5
|
||||
height, width = int(height / scale), int(width / scale)
|
||||
height = height // self.height_division_factor * self.height_division_factor
|
||||
width = width // self.width_division_factor * self.width_division_factor
|
||||
else:
|
||||
height, width = self.height, self.width
|
||||
return height, width
|
||||
|
||||
|
||||
def load_image(self, file_path):
|
||||
image = Image.open(file_path).convert("RGB")
|
||||
image = self.crop_and_resize(image, *self.get_height_width(image))
|
||||
return image
|
||||
|
||||
|
||||
def load_data(self, file_path):
|
||||
return self.load_image(file_path)
|
||||
|
||||
|
||||
def __getitem__(self, data_id):
|
||||
data = self.data[data_id % len(self.data)].copy()
|
||||
for key in self.data_file_keys:
|
||||
if key in data:
|
||||
if isinstance(data[key], list):
|
||||
path = [os.path.join(self.base_path, p) for p in data[key]]
|
||||
data[key] = [self.load_data(p) for p in path]
|
||||
else:
|
||||
path = os.path.join(self.base_path, data[key])
|
||||
data[key] = self.load_data(path)
|
||||
if data[key] is None:
|
||||
warnings.warn(f"cannot load file {data[key]}.")
|
||||
return None
|
||||
return data
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data) * self.repeat
|
||||
|
||||
|
||||
|
||||
class VideoDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
base_path=None, metadata_path=None,
|
||||
num_frames=81,
|
||||
time_division_factor=4, time_division_remainder=1,
|
||||
max_pixels=1920*1080, height=None, width=None,
|
||||
height_division_factor=16, width_division_factor=16,
|
||||
data_file_keys=("video",),
|
||||
image_file_extension=("jpg", "jpeg", "png", "webp"),
|
||||
video_file_extension=("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"),
|
||||
repeat=1,
|
||||
args=None,
|
||||
):
|
||||
if args is not None:
|
||||
base_path = args.dataset_base_path
|
||||
metadata_path = args.dataset_metadata_path
|
||||
height = args.height
|
||||
width = args.width
|
||||
max_pixels = args.max_pixels
|
||||
num_frames = args.num_frames
|
||||
data_file_keys = args.data_file_keys.split(",")
|
||||
repeat = args.dataset_repeat
|
||||
|
||||
self.base_path = base_path
|
||||
self.num_frames = num_frames
|
||||
self.time_division_factor = time_division_factor
|
||||
self.time_division_remainder = time_division_remainder
|
||||
self.max_pixels = max_pixels
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.height_division_factor = height_division_factor
|
||||
self.width_division_factor = width_division_factor
|
||||
self.data_file_keys = data_file_keys
|
||||
self.image_file_extension = image_file_extension
|
||||
self.video_file_extension = video_file_extension
|
||||
self.repeat = repeat
|
||||
|
||||
if height is not None and width is not None:
|
||||
print("Height and width are fixed. Setting `dynamic_resolution` to False.")
|
||||
self.dynamic_resolution = False
|
||||
elif height is None and width is None:
|
||||
print("Height and width are none. Setting `dynamic_resolution` to True.")
|
||||
self.dynamic_resolution = True
|
||||
|
||||
if metadata_path is None:
|
||||
print("No metadata. Trying to generate it.")
|
||||
metadata = self.generate_metadata(base_path)
|
||||
print(f"{len(metadata)} lines in metadata.")
|
||||
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
|
||||
elif metadata_path.endswith(".json"):
|
||||
with open(metadata_path, "r") as f:
|
||||
metadata = json.load(f)
|
||||
self.data = metadata
|
||||
else:
|
||||
metadata = pd.read_csv(metadata_path)
|
||||
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
|
||||
|
||||
|
||||
def generate_metadata(self, folder):
|
||||
video_list, prompt_list = [], []
|
||||
file_set = set(os.listdir(folder))
|
||||
for file_name in file_set:
|
||||
if "." not in file_name:
|
||||
continue
|
||||
file_ext_name = file_name.split(".")[-1].lower()
|
||||
file_base_name = file_name[:-len(file_ext_name)-1]
|
||||
if file_ext_name not in self.image_file_extension and file_ext_name not in self.video_file_extension:
|
||||
continue
|
||||
prompt_file_name = file_base_name + ".txt"
|
||||
if prompt_file_name not in file_set:
|
||||
continue
|
||||
with open(os.path.join(folder, prompt_file_name), "r", encoding="utf-8") as f:
|
||||
prompt = f.read().strip()
|
||||
video_list.append(file_name)
|
||||
prompt_list.append(prompt)
|
||||
metadata = pd.DataFrame()
|
||||
metadata["video"] = video_list
|
||||
metadata["prompt"] = prompt_list
|
||||
return metadata
|
||||
|
||||
|
||||
def crop_and_resize(self, image, target_height, target_width):
|
||||
width, height = image.size
|
||||
scale = max(target_width / width, target_height / height)
|
||||
image = torchvision.transforms.functional.resize(
|
||||
image,
|
||||
(round(height*scale), round(width*scale)),
|
||||
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
|
||||
)
|
||||
image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
|
||||
return image
|
||||
|
||||
|
||||
def get_height_width(self, image):
|
||||
if self.dynamic_resolution:
|
||||
width, height = image.size
|
||||
if width * height > self.max_pixels:
|
||||
scale = (width * height / self.max_pixels) ** 0.5
|
||||
height, width = int(height / scale), int(width / scale)
|
||||
height = height // self.height_division_factor * self.height_division_factor
|
||||
width = width // self.width_division_factor * self.width_division_factor
|
||||
else:
|
||||
height, width = self.height, self.width
|
||||
return height, width
|
||||
|
||||
|
||||
def get_num_frames(self, reader):
|
||||
num_frames = self.num_frames
|
||||
if int(reader.count_frames()) < num_frames:
|
||||
num_frames = int(reader.count_frames())
|
||||
while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
|
||||
num_frames -= 1
|
||||
return num_frames
|
||||
|
||||
|
||||
def load_video(self, file_path):
|
||||
reader = imageio.get_reader(file_path)
|
||||
num_frames = self.get_num_frames(reader)
|
||||
frames = []
|
||||
for frame_id in range(num_frames):
|
||||
frame = reader.get_data(frame_id)
|
||||
frame = Image.fromarray(frame)
|
||||
frame = self.crop_and_resize(frame, *self.get_height_width(frame))
|
||||
frames.append(frame)
|
||||
reader.close()
|
||||
return frames
|
||||
|
||||
|
||||
def load_image(self, file_path):
|
||||
image = Image.open(file_path).convert("RGB")
|
||||
image = self.crop_and_resize(image, *self.get_height_width(image))
|
||||
frames = [image]
|
||||
return frames
|
||||
|
||||
|
||||
def is_image(self, file_path):
|
||||
file_ext_name = file_path.split(".")[-1]
|
||||
return file_ext_name.lower() in self.image_file_extension
|
||||
|
||||
|
||||
def is_video(self, file_path):
|
||||
file_ext_name = file_path.split(".")[-1]
|
||||
return file_ext_name.lower() in self.video_file_extension
|
||||
|
||||
|
||||
def load_data(self, file_path):
|
||||
if self.is_image(file_path):
|
||||
return self.load_image(file_path)
|
||||
elif self.is_video(file_path):
|
||||
return self.load_video(file_path)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def __getitem__(self, data_id):
|
||||
data = self.data[data_id % len(self.data)].copy()
|
||||
for key in self.data_file_keys:
|
||||
if key in data:
|
||||
path = os.path.join(self.base_path, data[key])
|
||||
data[key] = self.load_data(path)
|
||||
if data[key] is None:
|
||||
warnings.warn(f"cannot load file {data[key]}.")
|
||||
return None
|
||||
return data
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data) * self.repeat
|
||||
|
||||
|
||||
|
||||
class DiffusionTrainingModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
for name, model in self.named_children():
|
||||
model.to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
|
||||
def trainable_modules(self):
|
||||
trainable_modules = filter(lambda p: p.requires_grad, self.parameters())
|
||||
return trainable_modules
|
||||
|
||||
|
||||
def trainable_param_names(self):
|
||||
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.named_parameters()))
|
||||
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
|
||||
return trainable_param_names
|
||||
|
||||
|
||||
def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None):
|
||||
if lora_alpha is None:
|
||||
lora_alpha = lora_rank
|
||||
lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules)
|
||||
model = inject_adapter_in_model(lora_config, model)
|
||||
return model
|
||||
|
||||
|
||||
def export_trainable_state_dict(self, state_dict, remove_prefix=None):
|
||||
trainable_param_names = self.trainable_param_names()
|
||||
state_dict = {name: param for name, param in state_dict.items() if name in trainable_param_names}
|
||||
if remove_prefix is not None:
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name.startswith(remove_prefix):
|
||||
name = name[len(remove_prefix):]
|
||||
state_dict_[name] = param
|
||||
state_dict = state_dict_
|
||||
return state_dict
|
||||
|
||||
|
||||
|
||||
class ModelLogger:
|
||||
def __init__(self, output_path, remove_prefix_in_ckpt=None, state_dict_converter=lambda x:x):
|
||||
self.output_path = output_path
|
||||
self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
|
||||
self.state_dict_converter = state_dict_converter
|
||||
self.num_steps = 0
|
||||
|
||||
|
||||
def on_step_end(self, accelerator, model, save_steps=None):
|
||||
self.num_steps += 1
|
||||
if save_steps is not None and self.num_steps % save_steps == 0:
|
||||
self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
|
||||
|
||||
|
||||
def on_epoch_end(self, accelerator, model, epoch_id):
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
state_dict = accelerator.get_state_dict(model)
|
||||
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
|
||||
state_dict = self.state_dict_converter(state_dict)
|
||||
os.makedirs(self.output_path, exist_ok=True)
|
||||
path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
|
||||
accelerator.save(state_dict, path, safe_serialization=True)
|
||||
|
||||
|
||||
def on_training_end(self, accelerator, model, save_steps=None):
|
||||
if save_steps is not None and self.num_steps % save_steps != 0:
|
||||
self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
|
||||
|
||||
|
||||
def save_model(self, accelerator, model, file_name):
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
state_dict = accelerator.get_state_dict(model)
|
||||
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
|
||||
state_dict = self.state_dict_converter(state_dict)
|
||||
os.makedirs(self.output_path, exist_ok=True)
|
||||
path = os.path.join(self.output_path, file_name)
|
||||
accelerator.save(state_dict, path, safe_serialization=True)
|
||||
|
||||
|
||||
def launch_training_task(
|
||||
dataset: torch.utils.data.Dataset,
|
||||
model: DiffusionTrainingModule,
|
||||
model_logger: ModelLogger,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
scheduler: torch.optim.lr_scheduler.LRScheduler,
|
||||
num_workers: int = 8,
|
||||
save_steps: int = None,
|
||||
num_epochs: int = 1,
|
||||
gradient_accumulation_steps: int = 1,
|
||||
find_unused_parameters: bool = False,
|
||||
):
|
||||
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||
kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=find_unused_parameters)],
|
||||
)
|
||||
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
|
||||
|
||||
for epoch_id in range(num_epochs):
|
||||
for data in tqdm(dataloader):
|
||||
with accelerator.accumulate(model):
|
||||
optimizer.zero_grad()
|
||||
loss = model(data)
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
model_logger.on_step_end(accelerator, model, save_steps)
|
||||
scheduler.step()
|
||||
if save_steps is None:
|
||||
model_logger.on_epoch_end(accelerator, model, epoch_id)
|
||||
model_logger.on_training_end(accelerator, model, save_steps)
|
||||
|
||||
|
||||
def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_path="./models"):
|
||||
dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0])
|
||||
accelerator = Accelerator()
|
||||
model, dataloader = accelerator.prepare(model, dataloader)
|
||||
os.makedirs(os.path.join(output_path, "data_cache"), exist_ok=True)
|
||||
for data_id, data in enumerate(tqdm(dataloader)):
|
||||
with torch.no_grad():
|
||||
inputs = model.forward_preprocess(data)
|
||||
inputs = {key: inputs[key] for key in model.model_input_keys if key in inputs}
|
||||
torch.save(inputs, os.path.join(output_path, "data_cache", f"{data_id}.pth"))
|
||||
|
||||
|
||||
|
||||
def wan_parser():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.")
|
||||
parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.")
|
||||
parser.add_argument("--max_pixels", type=int, default=1280*720, help="Maximum number of pixels per frame, used for dynamic resolution..")
|
||||
parser.add_argument("--height", type=int, default=None, help="Height of images or videos. Leave `height` and `width` empty to enable dynamic resolution.")
|
||||
parser.add_argument("--width", type=int, default=None, help="Width of images or videos. Leave `height` and `width` empty to enable dynamic resolution.")
|
||||
parser.add_argument("--num_frames", type=int, default=81, help="Number of frames per video. Frames are sampled from the video prefix.")
|
||||
parser.add_argument("--data_file_keys", type=str, default="image,video", help="Data file keys in the metadata. Comma-separated.")
|
||||
parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.")
|
||||
parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.")
|
||||
parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.")
|
||||
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
|
||||
parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
|
||||
parser.add_argument("--output_path", type=str, default="./models", help="Output save path.")
|
||||
parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Remove prefix in ckpt.")
|
||||
parser.add_argument("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.")
|
||||
parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
|
||||
parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
|
||||
parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
|
||||
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
|
||||
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
|
||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
|
||||
parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
|
||||
parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
|
||||
parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
|
||||
parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.")
|
||||
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
def flux_parser():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.")
|
||||
parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.")
|
||||
parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution..")
|
||||
parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.")
|
||||
parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.")
|
||||
parser.add_argument("--data_file_keys", type=str, default="image", help="Data file keys in the metadata. Comma-separated.")
|
||||
parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.")
|
||||
parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.")
|
||||
parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.")
|
||||
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
|
||||
parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
|
||||
parser.add_argument("--output_path", type=str, default="./models", help="Output save path.")
|
||||
parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Remove prefix in ckpt.")
|
||||
parser.add_argument("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.")
|
||||
parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
|
||||
parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
|
||||
parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
|
||||
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
|
||||
parser.add_argument("--align_to_opensource_format", default=False, action="store_true", help="Whether to align the lora format to opensource format. Only for DiT's LoRA.")
|
||||
parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
|
||||
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
|
||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
|
||||
parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
|
||||
parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.")
|
||||
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
def qwen_image_parser():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.")
|
||||
parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.")
|
||||
parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution..")
|
||||
parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.")
|
||||
parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.")
|
||||
parser.add_argument("--data_file_keys", type=str, default="image", help="Data file keys in the metadata. Comma-separated.")
|
||||
parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.")
|
||||
parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.")
|
||||
parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.")
|
||||
parser.add_argument("--tokenizer_path", type=str, default=None, help="Paths to tokenizer.")
|
||||
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
|
||||
parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
|
||||
parser.add_argument("--output_path", type=str, default="./models", help="Output save path.")
|
||||
parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Remove prefix in ckpt.")
|
||||
parser.add_argument("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.")
|
||||
parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
|
||||
parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
|
||||
parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
|
||||
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
|
||||
parser.add_argument("--align_to_opensource_format", default=False, action="store_true", help="Whether to align the lora format to opensource format. Only for DiT's LoRA.")
|
||||
parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
|
||||
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
|
||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
|
||||
parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
|
||||
parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.")
|
||||
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
|
||||
return parser
|
||||
@@ -1,261 +0,0 @@
|
||||
import torch, warnings, glob, os
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from einops import repeat, reduce
|
||||
from typing import Optional, Union
|
||||
from dataclasses import dataclass
|
||||
from modelscope import snapshot_download
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class BasePipeline(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device="cuda", torch_dtype=torch.float16,
|
||||
height_division_factor=64, width_division_factor=64,
|
||||
time_division_factor=None, time_division_remainder=None,
|
||||
):
|
||||
super().__init__()
|
||||
# The device and torch_dtype is used for the storage of intermediate variables, not models.
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
# The following parameters are used for shape check.
|
||||
self.height_division_factor = height_division_factor
|
||||
self.width_division_factor = width_division_factor
|
||||
self.time_division_factor = time_division_factor
|
||||
self.time_division_remainder = time_division_remainder
|
||||
self.vram_management_enabled = False
|
||||
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
|
||||
if device is not None:
|
||||
self.device = device
|
||||
if dtype is not None:
|
||||
self.torch_dtype = dtype
|
||||
super().to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
|
||||
def check_resize_height_width(self, height, width, num_frames=None):
|
||||
# Shape check
|
||||
if height % self.height_division_factor != 0:
|
||||
height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
|
||||
print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
|
||||
if width % self.width_division_factor != 0:
|
||||
width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
|
||||
print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
|
||||
if num_frames is None:
|
||||
return height, width
|
||||
else:
|
||||
if num_frames % self.time_division_factor != self.time_division_remainder:
|
||||
num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder
|
||||
print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
|
||||
return height, width, num_frames
|
||||
|
||||
|
||||
def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1):
|
||||
# Transform a PIL.Image to torch.Tensor
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32))
|
||||
image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
|
||||
image = image * ((max_value - min_value) / 255) + min_value
|
||||
image = repeat(image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {}))
|
||||
return image
|
||||
|
||||
|
||||
def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1):
|
||||
# Transform a list of PIL.Image to torch.Tensor
|
||||
video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video]
|
||||
video = torch.stack(video, dim=pattern.index("T") // 2)
|
||||
return video
|
||||
|
||||
|
||||
def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1):
|
||||
# Transform a torch.Tensor to PIL.Image
|
||||
if pattern != "H W C":
|
||||
vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean")
|
||||
image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255)
|
||||
image = image.to(device="cpu", dtype=torch.uint8)
|
||||
image = Image.fromarray(image.numpy())
|
||||
return image
|
||||
|
||||
|
||||
def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1):
|
||||
# Transform a torch.Tensor to list of PIL.Image
|
||||
if pattern != "T H W C":
|
||||
vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean")
|
||||
video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output]
|
||||
return video
|
||||
|
||||
|
||||
def load_models_to_device(self, model_names=[]):
|
||||
if self.vram_management_enabled:
|
||||
# offload models
|
||||
for name, model in self.named_children():
|
||||
if name not in model_names:
|
||||
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
|
||||
for module in model.modules():
|
||||
if hasattr(module, "offload"):
|
||||
module.offload()
|
||||
else:
|
||||
model.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
# onload models
|
||||
for name, model in self.named_children():
|
||||
if name in model_names:
|
||||
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
|
||||
for module in model.modules():
|
||||
if hasattr(module, "onload"):
|
||||
module.onload()
|
||||
else:
|
||||
model.to(self.device)
|
||||
|
||||
|
||||
def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None):
|
||||
# Initialize Gaussian noise
|
||||
generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed)
|
||||
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype)
|
||||
noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
|
||||
return noise
|
||||
|
||||
|
||||
def enable_cpu_offload(self):
|
||||
warnings.warn("`enable_cpu_offload` will be deprecated. Please use `enable_vram_management`.")
|
||||
self.vram_management_enabled = True
|
||||
|
||||
|
||||
def get_vram(self):
|
||||
return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3)
|
||||
|
||||
|
||||
def freeze_except(self, model_names):
|
||||
for name, model in self.named_children():
|
||||
if name in model_names:
|
||||
model.train()
|
||||
model.requires_grad_(True)
|
||||
else:
|
||||
model.eval()
|
||||
model.requires_grad_(False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
path: Union[str, list[str]] = None
|
||||
model_id: str = None
|
||||
origin_file_pattern: Union[str, list[str]] = None
|
||||
download_resource: str = "ModelScope"
|
||||
offload_device: Optional[Union[str, torch.device]] = None
|
||||
offload_dtype: Optional[torch.dtype] = None
|
||||
local_model_path: str = None
|
||||
skip_download: bool = False
|
||||
|
||||
def download_if_necessary(self, use_usp=False):
|
||||
if self.path is None:
|
||||
# Check model_id and origin_file_pattern
|
||||
if self.model_id is None:
|
||||
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""")
|
||||
|
||||
# Skip if not in rank 0
|
||||
if use_usp:
|
||||
import torch.distributed as dist
|
||||
skip_download = self.skip_download or dist.get_rank() != 0
|
||||
else:
|
||||
skip_download = self.skip_download
|
||||
|
||||
# Check whether the origin path is a folder
|
||||
if self.origin_file_pattern is None or self.origin_file_pattern == "":
|
||||
self.origin_file_pattern = ""
|
||||
allow_file_pattern = None
|
||||
is_folder = True
|
||||
elif isinstance(self.origin_file_pattern, str) and self.origin_file_pattern.endswith("/"):
|
||||
allow_file_pattern = self.origin_file_pattern + "*"
|
||||
is_folder = True
|
||||
else:
|
||||
allow_file_pattern = self.origin_file_pattern
|
||||
is_folder = False
|
||||
|
||||
# Download
|
||||
if self.local_model_path is None:
|
||||
self.local_model_path = "./models"
|
||||
if not skip_download:
|
||||
downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
|
||||
snapshot_download(
|
||||
self.model_id,
|
||||
local_dir=os.path.join(self.local_model_path, self.model_id),
|
||||
allow_file_pattern=allow_file_pattern,
|
||||
ignore_file_pattern=downloaded_files,
|
||||
local_files_only=False
|
||||
)
|
||||
|
||||
# Let rank 1, 2, ... wait for rank 0
|
||||
if use_usp:
|
||||
import torch.distributed as dist
|
||||
dist.barrier(device_ids=[dist.get_rank()])
|
||||
|
||||
# Return downloaded files
|
||||
if is_folder:
|
||||
self.path = os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern)
|
||||
else:
|
||||
self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))
|
||||
if isinstance(self.path, list) and len(self.path) == 1:
|
||||
self.path = self.path[0]
|
||||
|
||||
|
||||
|
||||
class PipelineUnit:
|
||||
def __init__(
|
||||
self,
|
||||
seperate_cfg: bool = False,
|
||||
take_over: bool = False,
|
||||
input_params: tuple[str] = None,
|
||||
input_params_posi: dict[str, str] = None,
|
||||
input_params_nega: dict[str, str] = None,
|
||||
onload_model_names: tuple[str] = None
|
||||
):
|
||||
self.seperate_cfg = seperate_cfg
|
||||
self.take_over = take_over
|
||||
self.input_params = input_params
|
||||
self.input_params_posi = input_params_posi
|
||||
self.input_params_nega = input_params_nega
|
||||
self.onload_model_names = onload_model_names
|
||||
|
||||
|
||||
def process(self, pipe: BasePipeline, inputs: dict, positive=True, **kwargs) -> dict:
|
||||
raise NotImplementedError("`process` is not implemented.")
|
||||
|
||||
|
||||
|
||||
class PipelineUnitRunner:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, unit: PipelineUnit, pipe: BasePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict]:
|
||||
if unit.take_over:
|
||||
# Let the pipeline unit take over this function.
|
||||
inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega)
|
||||
elif unit.seperate_cfg:
|
||||
# Positive side
|
||||
processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()}
|
||||
if unit.input_params is not None:
|
||||
for name in unit.input_params:
|
||||
processor_inputs[name] = inputs_shared.get(name)
|
||||
processor_outputs = unit.process(pipe, **processor_inputs)
|
||||
inputs_posi.update(processor_outputs)
|
||||
# Negative side
|
||||
if inputs_shared["cfg_scale"] != 1:
|
||||
processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()}
|
||||
if unit.input_params is not None:
|
||||
for name in unit.input_params:
|
||||
processor_inputs[name] = inputs_shared.get(name)
|
||||
processor_outputs = unit.process(pipe, **processor_inputs)
|
||||
inputs_nega.update(processor_outputs)
|
||||
else:
|
||||
inputs_nega.update(processor_outputs)
|
||||
else:
|
||||
processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params}
|
||||
processor_outputs = unit.process(pipe, **processor_inputs)
|
||||
inputs_shared.update(processor_outputs)
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
@@ -1,2 +1 @@
|
||||
from .layers import *
|
||||
from .gradient_checkpointing import *
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
import torch
|
||||
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs, **kwargs):
|
||||
return module(*inputs, **kwargs)
|
||||
return custom_forward
|
||||
|
||||
|
||||
def gradient_checkpoint_forward(
|
||||
model,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
model_output = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(model),
|
||||
*args,
|
||||
**kwargs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
elif use_gradient_checkpointing:
|
||||
model_output = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(model),
|
||||
*args,
|
||||
**kwargs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
model_output = model(*args, **kwargs)
|
||||
return model_output
|
||||
@@ -8,33 +8,8 @@ def cast_to(weight, dtype, device):
|
||||
return r
|
||||
|
||||
|
||||
class AutoTorchModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def check_free_vram(self):
|
||||
gpu_mem_state = torch.cuda.mem_get_info(self.computation_device)
|
||||
used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024 ** 3)
|
||||
return used_memory < self.vram_limit
|
||||
|
||||
def offload(self):
|
||||
if self.state != 0:
|
||||
self.to(dtype=self.offload_dtype, device=self.offload_device)
|
||||
self.state = 0
|
||||
|
||||
def onload(self):
|
||||
if self.state != 1:
|
||||
self.to(dtype=self.onload_dtype, device=self.onload_device)
|
||||
self.state = 1
|
||||
|
||||
def keep(self):
|
||||
if self.state != 2:
|
||||
self.to(dtype=self.computation_dtype, device=self.computation_device)
|
||||
self.state = 2
|
||||
|
||||
|
||||
class AutoWrappedModule(AutoTorchModule):
|
||||
def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit, **kwargs):
|
||||
class AutoWrappedModule(torch.nn.Module):
|
||||
def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
|
||||
super().__init__()
|
||||
self.module = module.to(dtype=offload_dtype, device=offload_device)
|
||||
self.offload_dtype = offload_dtype
|
||||
@@ -43,57 +18,28 @@ class AutoWrappedModule(AutoTorchModule):
|
||||
self.onload_device = onload_device
|
||||
self.computation_dtype = computation_dtype
|
||||
self.computation_device = computation_device
|
||||
self.vram_limit = vram_limit
|
||||
self.state = 0
|
||||
|
||||
def offload(self):
|
||||
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
||||
self.module.to(dtype=self.offload_dtype, device=self.offload_device)
|
||||
self.state = 0
|
||||
|
||||
def onload(self):
|
||||
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
||||
self.module.to(dtype=self.onload_dtype, device=self.onload_device)
|
||||
self.state = 1
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.state == 2:
|
||||
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
|
||||
module = self.module
|
||||
else:
|
||||
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
|
||||
module = self.module
|
||||
elif self.vram_limit is not None and self.check_free_vram():
|
||||
self.keep()
|
||||
module = self.module
|
||||
else:
|
||||
module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
|
||||
module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
|
||||
return module(*args, **kwargs)
|
||||
|
||||
|
||||
class WanAutoCastLayerNorm(torch.nn.LayerNorm, AutoTorchModule):
|
||||
def __init__(self, module: torch.nn.LayerNorm, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit, **kwargs):
|
||||
with init_weights_on_device(device=torch.device("meta")):
|
||||
super().__init__(module.normalized_shape, eps=module.eps, elementwise_affine=module.elementwise_affine, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
|
||||
self.weight = module.weight
|
||||
self.bias = module.bias
|
||||
self.offload_dtype = offload_dtype
|
||||
self.offload_device = offload_device
|
||||
self.onload_dtype = onload_dtype
|
||||
self.onload_device = onload_device
|
||||
self.computation_dtype = computation_dtype
|
||||
self.computation_device = computation_device
|
||||
self.vram_limit = vram_limit
|
||||
self.state = 0
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
if self.state == 2:
|
||||
weight, bias = self.weight, self.bias
|
||||
else:
|
||||
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
|
||||
weight, bias = self.weight, self.bias
|
||||
elif self.vram_limit is not None and self.check_free_vram():
|
||||
self.keep()
|
||||
weight, bias = self.weight, self.bias
|
||||
else:
|
||||
weight = None if self.weight is None else cast_to(self.weight, self.computation_dtype, self.computation_device)
|
||||
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
|
||||
with torch.amp.autocast(device_type=x.device.type):
|
||||
x = torch.nn.functional.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps).type_as(x)
|
||||
return x
|
||||
|
||||
|
||||
class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
|
||||
def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit, name="", **kwargs):
|
||||
class AutoWrappedLinear(torch.nn.Linear):
|
||||
def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
|
||||
with init_weights_on_device(device=torch.device("meta")):
|
||||
super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
|
||||
self.weight = module.weight
|
||||
@@ -104,93 +50,29 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
|
||||
self.onload_device = onload_device
|
||||
self.computation_dtype = computation_dtype
|
||||
self.computation_device = computation_device
|
||||
self.vram_limit = vram_limit
|
||||
self.state = 0
|
||||
self.name = name
|
||||
self.lora_A_weights = []
|
||||
self.lora_B_weights = []
|
||||
self.lora_merger = None
|
||||
self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
|
||||
|
||||
def fp8_linear(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
device = input.device
|
||||
origin_dtype = input.dtype
|
||||
origin_shape = input.shape
|
||||
input = input.reshape(-1, origin_shape[-1])
|
||||
|
||||
x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values
|
||||
fp8_max = 448.0
|
||||
# For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn.
|
||||
# To avoid overflow and ensure numerical compatibility during FP8 computation,
|
||||
# we scale down the input by 2.0 in advance.
|
||||
# This scaling will be compensated later during the final result scaling.
|
||||
if self.computation_dtype == torch.float8_e4m3fnuz:
|
||||
fp8_max = fp8_max / 2.0
|
||||
scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device)
|
||||
scale_b = torch.ones((weight.shape[0], 1)).to(device=device)
|
||||
input = input / (scale_a + 1e-8)
|
||||
input = input.to(self.computation_dtype)
|
||||
weight = weight.to(self.computation_dtype)
|
||||
bias = bias.to(torch.bfloat16)
|
||||
def offload(self):
|
||||
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
||||
self.to(dtype=self.offload_dtype, device=self.offload_device)
|
||||
self.state = 0
|
||||
|
||||
result = torch._scaled_mm(
|
||||
input,
|
||||
weight.T,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b.T,
|
||||
bias=bias,
|
||||
out_dtype=origin_dtype,
|
||||
)
|
||||
new_shape = origin_shape[:-1] + result.shape[-1:]
|
||||
result = result.reshape(new_shape)
|
||||
return result
|
||||
def onload(self):
|
||||
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
||||
self.to(dtype=self.onload_dtype, device=self.onload_device)
|
||||
self.state = 1
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
# VRAM management
|
||||
if self.state == 2:
|
||||
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
|
||||
weight, bias = self.weight, self.bias
|
||||
else:
|
||||
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
|
||||
weight, bias = self.weight, self.bias
|
||||
elif self.vram_limit is not None and self.check_free_vram():
|
||||
self.keep()
|
||||
weight, bias = self.weight, self.bias
|
||||
else:
|
||||
weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
|
||||
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
|
||||
|
||||
# Linear forward
|
||||
if self.enable_fp8:
|
||||
out = self.fp8_linear(x, weight, bias)
|
||||
else:
|
||||
out = torch.nn.functional.linear(x, weight, bias)
|
||||
|
||||
# LoRA
|
||||
if len(self.lora_A_weights) == 0:
|
||||
# No LoRA
|
||||
return out
|
||||
elif self.lora_merger is None:
|
||||
# Native LoRA inference
|
||||
for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
|
||||
out = out + x @ lora_A.T @ lora_B.T
|
||||
else:
|
||||
# LoRA fusion
|
||||
lora_output = []
|
||||
for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
|
||||
lora_output.append(x @ lora_A.T @ lora_B.T)
|
||||
lora_output = torch.stack(lora_output)
|
||||
out = self.lora_merger(out, lora_output)
|
||||
return out
|
||||
weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
|
||||
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
|
||||
return torch.nn.functional.linear(x, weight, bias)
|
||||
|
||||
|
||||
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0, vram_limit=None, name_prefix=""):
|
||||
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0):
|
||||
for name, module in model.named_children():
|
||||
layer_name = name if name_prefix == "" else name_prefix + "." + name
|
||||
for source_module, target_module in module_map.items():
|
||||
if isinstance(module, source_module):
|
||||
num_param = sum(p.numel() for p in module.parameters())
|
||||
@@ -198,16 +80,16 @@ def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict,
|
||||
module_config_ = overflow_module_config
|
||||
else:
|
||||
module_config_ = module_config
|
||||
module_ = target_module(module, **module_config_, vram_limit=vram_limit, name=layer_name)
|
||||
module_ = target_module(module, **module_config_)
|
||||
setattr(model, name, module_)
|
||||
total_num_param += num_param
|
||||
break
|
||||
else:
|
||||
total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param, vram_limit=vram_limit, name_prefix=layer_name)
|
||||
total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param)
|
||||
return total_num_param
|
||||
|
||||
|
||||
def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, vram_limit=None):
|
||||
enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0, vram_limit=vram_limit)
|
||||
def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None):
|
||||
enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0)
|
||||
model.vram_management_enabled = True
|
||||
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
# CogVideoX
|
||||
|
||||
### Example: Text-to-Video using CogVideoX-5B (Experimental)
|
||||
|
||||
See [cogvideo_text_to_video.py](cogvideo_text_to_video.py).
|
||||
|
||||
First, we generate a video using prompt "an astronaut riding a horse on Mars".
|
||||
|
||||
https://github.com/user-attachments/assets/4c91c1cd-e4a0-471a-bd8d-24d761262941
|
||||
|
||||
Then, we convert the astronaut to a robot.
|
||||
|
||||
https://github.com/user-attachments/assets/225a00a4-2bc8-4740-8e86-a64b460a29ec
|
||||
|
||||
Upscale the video using the model itself.
|
||||
|
||||
https://github.com/user-attachments/assets/c02cb30c-de60-473c-8242-32c67b3155ad
|
||||
|
||||
Make the video look smoother by interpolating frames.
|
||||
|
||||
https://github.com/user-attachments/assets/f0e465b4-45df-4435-ab10-7a084ca2b0a0
|
||||
|
||||
Here is another example.
|
||||
|
||||
First, we generate a video using prompt "a dog is running".
|
||||
|
||||
https://github.com/user-attachments/assets/e3696297-99f5-4d0c-a5ca-1d1566db85b4
|
||||
|
||||
Then, we add a blue collar to the dog.
|
||||
|
||||
https://github.com/user-attachments/assets/7ff22be7-4390-4d33-ae6c-53f6f056e18d
|
||||
|
||||
Upscale the video using the model itself.
|
||||
|
||||
https://github.com/user-attachments/assets/a909c32c-0b7d-495c-a53c-d23a99a3d3e9
|
||||
|
||||
Make the video look smoother by interpolating frames.
|
||||
|
||||
https://github.com/user-attachments/assets/ea37c150-97a0-4858-8003-0c2e5eef3331
|
||||
@@ -1,393 +0,0 @@
|
||||
# FLUX
|
||||
|
||||
[切换到中文](./README_zh.md)
|
||||
|
||||
FLUX is a series of image generation models open-sourced by Black-Forest-Labs.
|
||||
|
||||
**DiffSynth-Studio has introduced a new inference and training framework. If you need to use the old version, please click [here](https://github.com/modelscope/DiffSynth-Studio/tree/3edf3583b1f08944cee837b94d9f84d669c2729c).**
|
||||
|
||||
## Installation
|
||||
|
||||
Before using these models, please install DiffSynth-Studio from source code:
|
||||
|
||||
```shell
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
You can quickly load the [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev ) model and run inference by executing the code below.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
image = pipe(prompt="a cat", seed=0)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
## Model Overview
|
||||
|
||||
|Model ID|Extra Args|Inference|Low VRAM Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
|
||||
|-|-|-|-|-|-|-|-|
|
||||
|[FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](./model_inference/FLUX.1-dev.py)|[code](./model_inference_low_vram/FLUX.1-dev.py)|[code](./model_training/full/FLUX.1-dev.sh)|[code](./model_training/validate_full/FLUX.1-dev.py)|[code](./model_training/lora/FLUX.1-dev.sh)|[code](./model_training/validate_lora/FLUX.1-dev.py)|
|
||||
|[FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](./model_inference/FLUX.1-Krea-dev.py)|[code](./model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](./model_training/full/FLUX.1-Krea-dev.sh)|[code](./model_training/validate_full/FLUX.1-Krea-dev.py)|[code](./model_training/lora/FLUX.1-Krea-dev.sh)|[code](./model_training/validate_lora/FLUX.1-Krea-dev.py)|
|
||||
|[FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](./model_inference/FLUX.1-Kontext-dev.py)|[code](./model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](./model_training/full/FLUX.1-Kontext-dev.sh)|[code](./model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](./model_training/lora/FLUX.1-Kontext-dev.sh)|[code](./model_training/validate_lora/FLUX.1-Kontext-dev.py)|
|
||||
|[FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](./model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](./model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](./model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|
|
||||
|[FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](./model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](./model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](./model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)|
|
||||
|[FLUX.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](./model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)|
|
||||
|[FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](./model_inference/FLUX.1-dev-IP-Adapter.py)|[code](./model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](./model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](./model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](./model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](./model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)|
|
||||
|[FLUX.1-dev-InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](./model_inference/FLUX.1-dev-InfiniteYou.py)|[code](./model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](./model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](./model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](./model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](./model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|
|
||||
|[FLUX.1-dev-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](./model_inference/FLUX.1-dev-EliGen.py)|[code](./model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](./model_training/lora/FLUX.1-dev-EliGen.sh)|[code](./model_training/validate_lora/FLUX.1-dev-EliGen.py)|
|
||||
|[FLUX.1-dev-LoRA-Encoder](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](./model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](./model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](./model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](./model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|
|
||||
|[FLUX.1-dev-LoRA-Fusion-Preview](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](./model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|
|
||||
|[Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](./model_inference/Step1X-Edit.py)|[code](./model_inference_low_vram/Step1X-Edit.py)|[code](./model_training/full/Step1X-Edit.sh)|[code](./model_training/validate_full/Step1X-Edit.py)|[code](./model_training/lora/Step1X-Edit.sh)|[code](./model_training/validate_lora/Step1X-Edit.py)|
|
||||
|[FLEX.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](./model_inference/FLEX.2-preview.py)|[code](./model_inference_low_vram/FLEX.2-preview.py)|[code](./model_training/full/FLEX.2-preview.sh)|[code](./model_training/validate_full/FLEX.2-preview.py)|[code](./model_training/lora/FLEX.2-preview.sh)|[code](./model_training/validate_lora/FLEX.2-preview.py)|
|
||||
|[Nexus-Gen](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](./model_inference/Nexus-Gen-Editing.py)|[code](./model_inference_low_vram/Nexus-Gen-Editing.py)|[code](./model_training/full/Nexus-Gen.sh)|[code](./model_training/validate_full/Nexus-Gen.py)|[code](./model_training/lora/Nexus-Gen.sh)|[code](./model_training/validate_lora/Nexus-Gen.py)|
|
||||
|
||||
## Model Inference
|
||||
|
||||
The following sections will help you understand our features and write inference code.
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Load Model</summary>
|
||||
|
||||
The model is loaded using `from_pretrained`:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
Here, `torch_dtype` and `device` set the computation precision and device. The `model_configs` can be used in different ways to specify model paths:
|
||||
|
||||
* Download the model from [ModelScope](https://modelscope.cn/ ) and load it. In this case, fill in `model_id` and `origin_file_pattern`, for example:
|
||||
|
||||
```python
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors")
|
||||
```
|
||||
|
||||
* Load the model from a local file path. In this case, fill in `path`, for example:
|
||||
|
||||
```python
|
||||
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors")
|
||||
```
|
||||
|
||||
For a single model that loads from multiple files, use a list, for example:
|
||||
|
||||
```python
|
||||
ModelConfig(path=[
|
||||
"models/xxx/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"models/xxx/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"models/xxx/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
])
|
||||
```
|
||||
|
||||
The `ModelConfig` method also provides extra arguments to control model loading behavior:
|
||||
|
||||
* `local_model_path`: Path to save downloaded models. Default is `"./models"`.
|
||||
* `skip_download`: Whether to skip downloading. Default is `False`. If your network cannot access [ModelScope](https://modelscope.cn/ ), download the required files manually and set this to `True`.
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>VRAM Management</summary>
|
||||
|
||||
DiffSynth-Studio provides fine-grained VRAM management for the FLUX model. This allows the model to run on devices with low VRAM. You can enable the offload feature using the code below. It moves some modules to CPU memory when GPU memory is limited.
|
||||
|
||||
```python
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
```
|
||||
|
||||
FP8 quantization is also supported:
|
||||
|
||||
```python
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
```
|
||||
|
||||
You can use FP8 quantization and offload at the same time:
|
||||
|
||||
```python
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
```
|
||||
|
||||
After enabling VRAM management, the framework will automatically decide the VRAM strategy based on available GPU memory. For most FLUX models, inference can run with as little as 8GB of VRAM. The `enable_vram_management` function has the following parameters to manually control the VRAM strategy:
|
||||
|
||||
* `vram_limit`: VRAM usage limit in GB. By default, it uses all free VRAM on the device. Note that this is not an absolute limit. If the set VRAM is not enough but more VRAM is actually available, the model will run with minimal VRAM usage. Setting it to 0 achieves the theoretical minimum VRAM usage.
|
||||
* `vram_buffer`: VRAM buffer size in GB. Default is 0.5GB. A buffer is needed because larger neural network layers may use more VRAM than expected during loading. The optimal value is the VRAM used by the largest layer in the model.
|
||||
* `num_persistent_param_in_dit`: Number of parameters in the DiT model that stay in VRAM. Default is no limit. We plan to remove this parameter in the future. Do not rely on it.
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Inference Acceleration</summary>
|
||||
|
||||
* TeaCache: Acceleration technique [TeaCache](https://github.com/ali-vilab/TeaCache ). Please refer to the [example code](./acceleration/teacache.py).
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Input Parameters</summary>
|
||||
|
||||
The pipeline supports the following input parameters during inference:
|
||||
|
||||
* `prompt`: Text prompt describing what should appear in the image.
|
||||
* `negative_prompt`: Negative prompt describing what should not appear in the image. Default is `""`.
|
||||
* `cfg_scale`: Parameter for classifier-free guidance. Default is 1. Takes effect when set to a value greater than 1.
|
||||
* `embedded_guidance`: Built-in guidance parameter for FLUX-dev. Default is 3.5.
|
||||
* `t5_sequence_length`: Sequence length of text embeddings from the T5 model. Default is 512.
|
||||
* `input_image`: Input image used for image-to-image generation. Used together with `denoising_strength`.
|
||||
* `denoising_strength`: Denoising strength, range from 0 to 1. Default is 1. When close to 0, the output image is similar to the input. When close to 1, the output differs more from the input. Do not set it to values other than 1 if `input_image` is not provided.
|
||||
* `height`: Image height. Must be a multiple of 16.
|
||||
* `width`: Image width. Must be a multiple of 16.
|
||||
* `seed`: Random seed. Default is `None`, meaning fully random.
|
||||
* `rand_device`: Device for generating random Gaussian noise. Default is `"cpu"`. Setting it to `"cuda"` may lead to different results on different GPUs.
|
||||
* `sigma_shift`: Parameter from Rectified Flow theory. Default is 3. A larger value means the model spends more steps at the start of denoising. Increasing this can improve image quality, but may cause differences between generated images and training data due to inconsistency with training.
|
||||
* `num_inference_steps`: Number of inference steps. Default is 30.
|
||||
* `kontext_images`: Input images for the Kontext model.
|
||||
* `controlnet_inputs`: Inputs for the ControlNet model.
|
||||
* `ipadapter_images`: Input images for the IP-Adapter model.
|
||||
* `ipadapter_scale`: Control strength for the IP-Adapter model.
|
||||
* `eligen_entity_prompts`: Local prompts for the EliGen model.
|
||||
* `eligen_entity_masks`: Mask regions for local prompts in the EliGen model. Matches one-to-one with `eligen_entity_prompts`.
|
||||
* `eligen_enable_on_negative`: Whether to enable EliGen on the negative prompt side. Only works when `cfg_scale > 1`.
|
||||
* `eligen_enable_inpaint`: Whether to enable EliGen for local inpainting.
|
||||
* `infinityou_id_image`: Face image for the InfiniteYou model.
|
||||
* `infinityou_guidance`: Control strength for the InfiniteYou model.
|
||||
* `flex_inpaint_image`: Image for FLEX model's inpainting.
|
||||
* `flex_inpaint_mask`: Mask region for FLEX model's inpainting.
|
||||
* `flex_control_image`: Image for FLEX model's structural control.
|
||||
* `flex_control_strength`: Strength for FLEX model's structural control.
|
||||
* `flex_control_stop`: End point for FLEX model's structural control. 1 means enabled throughout, 0.5 means enabled in the first half, 0 means disabled.
|
||||
* `step1x_reference_image`: Input image for Step1x-Edit model's image editing.
|
||||
* `lora_encoder_inputs`: Inputs for LoRA encoder. Can be ModelConfig or local path.
|
||||
* `lora_encoder_scale`: Activation strength for LoRA encoder. Default is 1. Smaller values mean weaker LoRA activation.
|
||||
* `tea_cache_l1_thresh`: Threshold for TeaCache. Larger values mean faster speed but lower image quality. Note that after enabling TeaCache, inference speed is not uniform, so the remaining time shown in the progress bar will be inaccurate.
|
||||
* `tiled`: Whether to enable tiled VAE inference. Default is `False`. Setting to `True` reduces VRAM usage during VAE encoding/decoding, with slight error and slightly longer inference time.
|
||||
* `tile_size`: Tile size during VAE encoding/decoding. Default is 128. Only takes effect when `tiled=True`.
|
||||
* `tile_stride`: Tile stride during VAE encoding/decoding. Default is 64. Only takes effect when `tiled=True`. Must be less than or equal to `tile_size`.
|
||||
* `progress_bar_cmd`: Progress bar display. Default is `tqdm.tqdm`. Set to `lambda x:x` to disable the progress bar.
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
## Model Training
|
||||
|
||||
Training for the FLUX series models is done using a unified script [`./model_training/train.py`](./model_training/train.py).
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Script Parameters</summary>
|
||||
|
||||
The script includes the following parameters:
|
||||
|
||||
* Dataset
|
||||
* `--dataset_base_path`: Root path of the dataset.
|
||||
* `--dataset_metadata_path`: Path to the dataset metadata file.
|
||||
* `--max_pixels`: Maximum pixel area. Default is 1024*1024. When dynamic resolution is enabled, any image with resolution higher than this will be downscaled.
|
||||
* `--height`: Height of the image or video. Leave `height` and `width` empty to enable dynamic resolution.
|
||||
* `--width`: Width of the image or video. Leave `height` and `width` empty to enable dynamic resolution.
|
||||
* `--data_file_keys`: Data file keys in the metadata. Separate with commas.
|
||||
* `--dataset_repeat`: Number of times the dataset repeats per epoch.
|
||||
* `--dataset_num_workers`: Number of workers for data loading.
|
||||
* Model
|
||||
* `--model_paths`: Paths to load models. In JSON format.
|
||||
* `--model_id_with_origin_paths`: Model ID with original paths, e.g., black-forest-labs/FLUX.1-dev:flux1-dev.safetensors. Separate with commas.
|
||||
* Training
|
||||
* `--learning_rate`: Learning rate.
|
||||
* `--num_epochs`: Number of epochs.
|
||||
* `--output_path`: Save path.
|
||||
* `--remove_prefix_in_ckpt`: Remove prefix in checkpoint.
|
||||
* `--save_steps`: Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.
|
||||
* `--find_unused_parameters`: Whether to find unused parameters in DDP.
|
||||
* Trainable Modules
|
||||
* `--trainable_models`: Models that can be trained, e.g., dit, vae, text_encoder.
|
||||
* `--lora_base_model`: Which model to add LoRA to.
|
||||
* `--lora_target_modules`: Which layers to add LoRA to.
|
||||
* `--lora_rank`: Rank of LoRA.
|
||||
* Extra Model Inputs
|
||||
* `--extra_inputs`: Extra model inputs, separated by commas.
|
||||
* VRAM Management
|
||||
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
|
||||
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
|
||||
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
|
||||
* Others
|
||||
* `--align_to_opensource_format`: Whether to align the FLUX DiT LoRA format with the open-source version. Only works for LoRA training.
|
||||
|
||||
In addition, the training framework is built on [`accelerate`](https://huggingface.co/docs/accelerate/index ). Run `accelerate config` before training to set GPU-related parameters. For some training scripts (e.g., full model training), we provide suggested `accelerate` config files. You can find them in the corresponding training scripts.
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Step 1: Prepare Dataset</summary>
|
||||
|
||||
A dataset contains a series of files. We suggest organizing your dataset like this:
|
||||
|
||||
```
|
||||
data/example_image_dataset/
|
||||
├── metadata.csv
|
||||
├── image1.jpg
|
||||
└── image2.jpg
|
||||
```
|
||||
|
||||
Here, `image1.jpg` and `image2.jpg` are training images, and `metadata.csv` is the metadata list, for example:
|
||||
|
||||
```
|
||||
image,prompt
|
||||
image1.jpg,"a cat is sleeping"
|
||||
image2.jpg,"a dog is running"
|
||||
```
|
||||
|
||||
We have built a sample image dataset to help you test. You can download it with the following command:
|
||||
|
||||
```shell
|
||||
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
|
||||
```
|
||||
|
||||
The dataset supports multiple image formats: `"jpg", "jpeg", "png", "webp"`.
|
||||
|
||||
Image size can be controlled by script arguments `--height` and `--width`. When `--height` and `--width` are left empty, dynamic resolution is enabled. The model will train using each image's actual width and height from the dataset.
|
||||
|
||||
**We strongly recommend using fixed resolution for training, because there can be load balancing issues in multi-GPU training.**
|
||||
|
||||
When the model needs extra inputs, for example, `kontext_images` required by controllable models like [`black-forest-labs/FLUX.1-Kontext-dev`](https://modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev ), add the corresponding column to your dataset, for example:
|
||||
|
||||
```
|
||||
image,prompt,kontext_images
|
||||
image1.jpg,"a cat is sleeping",image1_reference.jpg
|
||||
```
|
||||
|
||||
If an extra input includes image files, you must specify the column name in the `--data_file_keys` argument. Add column names as needed, for example `--data_file_keys "image,kontext_images"`, and also enable `--extra_inputs "kontext_images"`.
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Step 2: Load Model</summary>
|
||||
|
||||
Similar to model loading during inference, you can configure which models to load directly using model IDs. For example, during inference we load the model with this setting:
|
||||
|
||||
```python
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
]
|
||||
```
|
||||
|
||||
Then, during training, use the following parameter to load the same models:
|
||||
|
||||
```shell
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors"
|
||||
```
|
||||
|
||||
If you want to load models from local files, for example, during inference:
|
||||
|
||||
```python
|
||||
model_configs=[
|
||||
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors"),
|
||||
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/text_encoder/model.safetensors"),
|
||||
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/text_encoder_2/"),
|
||||
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/ae.safetensors"),
|
||||
]
|
||||
```
|
||||
|
||||
Then during training, set it as:
|
||||
|
||||
```shell
|
||||
--model_paths '[
|
||||
"models/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors",
|
||||
"models/black-forest-labs/FLUX.1-dev/text_encoder/model.safetensors",
|
||||
"models/black-forest-labs/FLUX.1-dev/text_encoder_2/",
|
||||
"models/black-forest-labs/FLUX.1-dev/ae.safetensors"
|
||||
]' \
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Step 3: Set Trainable Modules</summary>
|
||||
|
||||
The training framework supports training base models or LoRA models. Here are some examples:
|
||||
|
||||
* Full training of the DiT part: `--trainable_models dit`
|
||||
* Training a LoRA model on the DiT part: `--lora_base_model dit --lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" --lora_rank 32`
|
||||
|
||||
Also, because the training script loads multiple modules (text encoder, dit, vae), you need to remove prefixes when saving model files. For example, when fully training the DiT part or training a LoRA model on the DiT part, set `--remove_prefix_in_ckpt pipe.dit.`
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Step 4: Start Training</summary>
|
||||
|
||||
We have written training commands for each model. Please refer to the table at the beginning of this document.
|
||||
|
||||
</details>
|
||||
@@ -1,394 +0,0 @@
|
||||
# FLUX
|
||||
|
||||
[Switch to English](./README.md)
|
||||
|
||||
FLUX 是由 Black-Forest-Labs 开源的一系列图像生成模型。
|
||||
|
||||
**DiffSynth-Studio 启用了新的推理和训练框架,如需使用旧版本,请点击[这里](https://github.com/modelscope/DiffSynth-Studio/tree/3edf3583b1f08944cee837b94d9f84d669c2729c)。**
|
||||
|
||||
## 安装
|
||||
|
||||
在使用本系列模型之前,请通过源码安装 DiffSynth-Studio。
|
||||
|
||||
```shell
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
## 快速开始
|
||||
|
||||
通过运行以下代码可以快速加载 [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) 模型并进行推理。
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
image = pipe(prompt="a cat", seed=0)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
## 模型总览
|
||||
|
||||
|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|-|
|
||||
|[FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](./model_inference/FLUX.1-dev.py)|[code](./model_inference_low_vram/FLUX.1-dev.py)|[code](./model_training/full/FLUX.1-dev.sh)|[code](./model_training/validate_full/FLUX.1-dev.py)|[code](./model_training/lora/FLUX.1-dev.sh)|[code](./model_training/validate_lora/FLUX.1-dev.py)|
|
||||
|[FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](./model_inference/FLUX.1-Krea-dev.py)|[code](./model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](./model_training/full/FLUX.1-Krea-dev.sh)|[code](./model_training/validate_full/FLUX.1-Krea-dev.py)|[code](./model_training/lora/FLUX.1-Krea-dev.sh)|[code](./model_training/validate_lora/FLUX.1-Krea-dev.py)|
|
||||
|[FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](./model_inference/FLUX.1-Kontext-dev.py)|[code](./model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](./model_training/full/FLUX.1-Kontext-dev.sh)|[code](./model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](./model_training/lora/FLUX.1-Kontext-dev.sh)|[code](./model_training/validate_lora/FLUX.1-Kontext-dev.py)|
|
||||
|[FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](./model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](./model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](./model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|
|
||||
|[FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](./model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](./model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](./model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)|
|
||||
|[FLUX.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](./model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)|
|
||||
|[FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](./model_inference/FLUX.1-dev-IP-Adapter.py)|[code](./model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](./model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](./model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](./model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](./model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)|
|
||||
|[FLUX.1-dev-InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](./model_inference/FLUX.1-dev-InfiniteYou.py)|[code](./model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](./model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](./model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](./model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](./model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|
|
||||
|[FLUX.1-dev-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](./model_inference/FLUX.1-dev-EliGen.py)|[code](./model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](./model_training/lora/FLUX.1-dev-EliGen.sh)|[code](./model_training/validate_lora/FLUX.1-dev-EliGen.py)|
|
||||
|[FLUX.1-dev-LoRA-Encoder](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](./model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](./model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](./model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](./model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|
|
||||
|[FLUX.1-dev-LoRA-Fusion-Preview](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](./model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|
|
||||
|[Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](./model_inference/Step1X-Edit.py)|[code](./model_inference_low_vram/Step1X-Edit.py)|[code](./model_training/full/Step1X-Edit.sh)|[code](./model_training/validate_full/Step1X-Edit.py)|[code](./model_training/lora/Step1X-Edit.sh)|[code](./model_training/validate_lora/Step1X-Edit.py)|
|
||||
|[FLEX.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](./model_inference/FLEX.2-preview.py)|[code](./model_inference_low_vram/FLEX.2-preview.py)|[code](./model_training/full/FLEX.2-preview.sh)|[code](./model_training/validate_full/FLEX.2-preview.py)|[code](./model_training/lora/FLEX.2-preview.sh)|[code](./model_training/validate_lora/FLEX.2-preview.py)|
|
||||
|[Nexus-Gen](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](./model_inference/Nexus-Gen-Editing.py)|[code](./model_inference_low_vram/Nexus-Gen-Editing.py)|[code](./model_training/full/Nexus-Gen.sh)|[code](./model_training/validate_full/Nexus-Gen.py)|[code](./model_training/lora/Nexus-Gen.sh)|[code](./model_training/validate_lora/Nexus-Gen.py)|
|
||||
|
||||
## 模型推理
|
||||
|
||||
以下部分将会帮助您理解我们的功能并编写推理代码。
|
||||
|
||||
<details>
|
||||
|
||||
<summary>加载模型</summary>
|
||||
|
||||
模型通过 `from_pretrained` 加载:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
其中 `torch_dtype` 和 `device` 是计算精度和计算设备。`model_configs` 可通过多种方式配置模型路径:
|
||||
|
||||
* 从[魔搭社区](https://modelscope.cn/)下载模型并加载。此时需要填写 `model_id` 和 `origin_file_pattern`,例如
|
||||
|
||||
```python
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors")
|
||||
```
|
||||
|
||||
* 从本地文件路径加载模型。此时需要填写 `path`,例如
|
||||
|
||||
```python
|
||||
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors")
|
||||
```
|
||||
|
||||
对于从多个文件加载的单一模型,使用列表即可,例如
|
||||
|
||||
```python
|
||||
ModelConfig(path=[
|
||||
"models/xxx/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"models/xxx/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"models/xxx/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
])
|
||||
```
|
||||
|
||||
`ModelConfig` 还提供了额外的参数用于控制模型加载时的行为:
|
||||
|
||||
* `local_model_path`: 用于保存下载模型的路径,默认值为 `"./models"`。
|
||||
* `skip_download`: 是否跳过下载,默认值为 `False`。当您的网络无法访问[魔搭社区](https://modelscope.cn/)时,请手动下载必要的文件,并将其设置为 `True`。
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>显存管理</summary>
|
||||
|
||||
DiffSynth-Studio 为 FLUX 模型提供了细粒度的显存管理,让模型能够在低显存设备上进行推理,可通过以下代码开启 offload 功能,在显存有限的设备上将部分模块 offload 到内存中。
|
||||
|
||||
```python
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
```
|
||||
|
||||
FP8 量化功能也是支持的:
|
||||
|
||||
```python
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
```
|
||||
|
||||
FP8 量化和 offload 可同时开启:
|
||||
|
||||
```python
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
```
|
||||
|
||||
开启显存管理后,框架会自动根据设备上的剩余显存确定显存管理策略。对于大多数 FLUX 系列模型,最低 8GB 显存即可进行推理。`enable_vram_management` 函数提供了以下参数,用于手动控制显存管理策略:
|
||||
|
||||
* `vram_limit`: 显存占用量限制(GB),默认占用设备上的剩余显存。注意这不是一个绝对限制,当设置的显存不足以支持模型进行推理,但实际可用显存足够时,将会以最小化显存占用的形式进行推理。将其设置为0时,将会实现理论最小显存占用。
|
||||
* `vram_buffer`: 显存缓冲区大小(GB),默认为 0.5GB。由于部分较大的神经网络层在 onload 阶段会不可控地占用更多显存,因此一个显存缓冲区是必要的,理论上的最优值为模型中最大的层所占的显存。
|
||||
* `num_persistent_param_in_dit`: DiT 模型中常驻显存的参数数量(个),默认为无限制。我们将会在未来删除这个参数,请不要依赖这个参数。
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>推理加速</summary>
|
||||
|
||||
* TeaCache:加速技术 [TeaCache](https://github.com/ali-vilab/TeaCache),请参考[示例代码](./acceleration/teacache.py)。
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>输入参数</summary>
|
||||
|
||||
Pipeline 在推理阶段能够接收以下输入参数:
|
||||
|
||||
* `prompt`: 提示词,描述画面中出现的内容。
|
||||
* `negative_prompt`: 负向提示词,描述画面中不应该出现的内容,默认值为 `""`。
|
||||
* `cfg_scale`: Classifier-free guidance 的参数,默认值为 1,当设置为大于1的数值时生效。
|
||||
* `embedded_guidance`: FLUX-dev 的内嵌引导参数,默认值为 3.5。
|
||||
* `t5_sequence_length`: T5 模型的文本向量序列长度,默认值为 512。
|
||||
* `input_image`: 输入图像,用于图生图,该参数与 `denoising_strength` 配合使用。
|
||||
* `denoising_strength`: 去噪强度,范围是 0~1,默认值为 1,当数值接近 0 时,生成图像与输入图像相似;当数值接近 1 时,生成图像与输入图像相差更大。在不输入 `input_image` 参数时,请不要将其设置为非 1 的数值。
|
||||
* `height`: 图像高度,需保证高度为 16 的倍数。
|
||||
* `width`: 图像宽度,需保证宽度为 16 的倍数。
|
||||
* `seed`: 随机种子。默认为 `None`,即完全随机。
|
||||
* `rand_device`: 生成随机高斯噪声矩阵的计算设备,默认为 `"cpu"`。当设置为 `cuda` 时,在不同 GPU 上会导致不同的生成结果。
|
||||
* `sigma_shift`: Rectified Flow 理论中的参数,默认为 3。数值越大,模型在去噪的开始阶段停留的步骤数越多,可适当调大这个参数来提高画面质量,但会因生成过程与训练过程不一致导致生成的图像内容与训练数据存在差异。
|
||||
* `num_inference_steps`: 推理次数,默认值为 30。
|
||||
* `kontext_images`: Kontext 模型的输入图像。
|
||||
* `controlnet_inputs`: ControlNet 模型的输入。
|
||||
* `ipadapter_images`: IP-Adapter 模型的输入图像。
|
||||
* `ipadapter_scale`: IP-Adapter 模型的控制强度。
|
||||
* `eligen_entity_prompts`: EliGen 模型的图像局部提示词。
|
||||
* `eligen_entity_masks`: EliGen 模型的局部提示词控制区域,与 `eligen_entity_prompts` 一一对应。
|
||||
* `eligen_enable_on_negative`: 是否在负向提示词一侧启用 EliGen,仅在 `cfg_scale > 1` 时生效。
|
||||
* `eligen_enable_inpaint`: 是否启用 EliGen 局部重绘。
|
||||
* `infinityou_id_image`: InfiniteYou 模型的人脸图像。
|
||||
* `infinityou_guidance`: InfiniteYou 模型的控制强度。
|
||||
* `flex_inpaint_image`: FLEX 模型用于局部重绘的图像。
|
||||
* `flex_inpaint_mask`: FLEX 模型用于局部重绘的区域。
|
||||
* `flex_control_image`: FLEX 模型用于结构控制的图像。
|
||||
* `flex_control_strength`: FLEX 模型用于结构控制的强度。
|
||||
* `flex_control_stop`: FLEX 模型结构控制的结束点,1表示全程启用,0.5表示在前半段启用,0表示不启用。
|
||||
* `step1x_reference_image`: Step1x-Edit 模型用于图像编辑的输入图像。
|
||||
* `lora_encoder_inputs`: LoRA 编码器的输入,格式为 ModelConfig 或本地路径。
|
||||
* `lora_encoder_scale`: LoRA 编码器的激活强度,默认值为1,数值越小,LoRA 激活越弱。
|
||||
* `tea_cache_l1_thresh`: TeaCache 的阈值,数值越大,速度越快,画面质量越差。请注意,开启 TeaCache 后推理速度并非均匀,因此进度条上显示的剩余时间将会变得不准确。
|
||||
* `tiled`: 是否启用 VAE 分块推理,默认为 `False`。设置为 `True` 时可显著减少 VAE 编解码阶段的显存占用,会产生少许误差,以及少量推理时间延长。
|
||||
* `tile_size`: VAE 编解码阶段的分块大小,默认为 128,仅在 `tiled=True` 时生效。
|
||||
* `tile_stride`: VAE 编解码阶段的分块步长,默认为 64,仅在 `tiled=True` 时生效,需保证其数值小于或等于 `tile_size`。
|
||||
* `progress_bar_cmd`: 进度条,默认为 `tqdm.tqdm`。可通过设置为 `lambda x:x` 来屏蔽进度条。
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
## 模型训练
|
||||
|
||||
FLUX 系列模型训练通过统一的 [`./model_training/train.py`](./model_training/train.py) 脚本进行。
|
||||
|
||||
<details>
|
||||
|
||||
<summary>脚本参数</summary>
|
||||
|
||||
脚本包含以下参数:
|
||||
|
||||
* 数据集
|
||||
* `--dataset_base_path`: 数据集的根路径。
|
||||
* `--dataset_metadata_path`: 数据集的元数据文件路径。
|
||||
* `--max_pixels`: 最大像素面积,默认为 1024*1024,当启用动态分辨率时,任何分辨率大于这个数值的图片都会被缩小。
|
||||
* `--height`: 图像或视频的高度。将 `height` 和 `width` 留空以启用动态分辨率。
|
||||
* `--width`: 图像或视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。
|
||||
* `--data_file_keys`: 元数据中的数据文件键。用逗号分隔。
|
||||
* `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
|
||||
* `--dataset_num_workers`: 每个 Dataloder 的进程数量。
|
||||
* 模型
|
||||
* `--model_paths`: 要加载的模型路径。JSON 格式。
|
||||
* `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 black-forest-labs/FLUX.1-dev:flux1-dev.safetensors。用逗号分隔。
|
||||
* 训练
|
||||
* `--learning_rate`: 学习率。
|
||||
* `--num_epochs`: 轮数(Epoch)。
|
||||
* `--output_path`: 保存路径。
|
||||
* `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。
|
||||
* `--save_steps`: 保存模型的间隔 step 数量,如果设置为 None ,则每个 epoch 保存一次
|
||||
* `--find_unused_parameters`: DDP 训练中是否存在未使用的参数
|
||||
* 可训练模块
|
||||
* `--trainable_models`: 可训练的模型,例如 dit、vae、text_encoder。
|
||||
* `--lora_base_model`: LoRA 添加到哪个模型上。
|
||||
* `--lora_target_modules`: LoRA 添加到哪一层上。
|
||||
* `--lora_rank`: LoRA 的秩(Rank)。
|
||||
* 额外模型输入
|
||||
* `--extra_inputs`: 额外的模型输入,以逗号分隔。
|
||||
* 显存管理
|
||||
* `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。
|
||||
* `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。
|
||||
* `--gradient_accumulation_steps`: 梯度累积步数。
|
||||
* 其他
|
||||
* `--align_to_opensource_format`: 是否将 FLUX DiT LoRA 的格式与开源版本对齐,仅对 LoRA 训练生效。
|
||||
|
||||
|
||||
此外,训练框架基于 [`accelerate`](https://huggingface.co/docs/accelerate/index) 构建,在开始训练前运行 `accelerate config` 可配置 GPU 的相关参数。对于部分模型训练(例如模型的全量训练)脚本,我们提供了建议的 `accelerate` 配置文件,可在对应的训练脚本中查看。
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Step 1: 准备数据集</summary>
|
||||
|
||||
数据集包含一系列文件,我们建议您这样组织数据集文件:
|
||||
|
||||
```
|
||||
data/example_image_dataset/
|
||||
├── metadata.csv
|
||||
├── image1.jpg
|
||||
└── image2.jpg
|
||||
```
|
||||
|
||||
其中 `image1.jpg`、`image2.jpg` 为训练用图像数据,`metadata.csv` 为元数据列表,例如
|
||||
|
||||
```
|
||||
image,prompt
|
||||
image1.jpg,"a cat is sleeping"
|
||||
image2.jpg,"a dog is running"
|
||||
```
|
||||
|
||||
我们构建了一个样例图像数据集,以方便您进行测试,通过以下命令可以下载这个数据集:
|
||||
|
||||
```shell
|
||||
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
|
||||
```
|
||||
|
||||
数据集支持多种图片格式,`"jpg", "jpeg", "png", "webp"`。
|
||||
|
||||
图片的尺寸可通过脚本参数 `--height`、`--width` 控制。当 `--height` 和 `--width` 为空时将会开启动态分辨率,按照数据集中每个图像的实际宽高训练。
|
||||
|
||||
**我们强烈建议使用固定分辨率训练,因为在多卡训练中存在负载均衡问题。**
|
||||
|
||||
当模型需要额外输入时,例如具备控制能力的模型 [`black-forest-labs/FLUX.1-Kontext-dev`](https://modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev) 所需的 `kontext_images`,请在数据集中补充相应的列,例如:
|
||||
|
||||
```
|
||||
image,prompt,kontext_images
|
||||
image1.jpg,"a cat is sleeping",image1_reference.jpg
|
||||
```
|
||||
|
||||
额外输入若包含图像文件,则需要在 `--data_file_keys` 参数中指定要解析的列名。可根据额外输入增加相应的列名,例如 `--data_file_keys "image,kontext_images"`,同时启用 `--extra_inputs "kontext_images"`。
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Step 2: 加载模型</summary>
|
||||
|
||||
类似于推理时的模型加载逻辑,可直接通过模型 ID 配置要加载的模型。例如,推理时我们通过以下设置加载模型
|
||||
|
||||
```python
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
]
|
||||
```
|
||||
|
||||
那么在训练时,填入以下参数即可加载对应的模型。
|
||||
|
||||
```shell
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors"
|
||||
```
|
||||
|
||||
如果您希望从本地文件加载模型,例如推理时
|
||||
|
||||
```python
|
||||
model_configs=[
|
||||
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors"),
|
||||
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/text_encoder/model.safetensors"),
|
||||
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/text_encoder_2/"),
|
||||
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/ae.safetensors"),
|
||||
]
|
||||
```
|
||||
|
||||
那么训练时需设置为
|
||||
|
||||
```shell
|
||||
--model_paths '[
|
||||
"models/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors",
|
||||
"models/black-forest-labs/FLUX.1-dev/text_encoder/model.safetensors",
|
||||
"models/black-forest-labs/FLUX.1-dev/text_encoder_2/",
|
||||
"models/black-forest-labs/FLUX.1-dev/ae.safetensors"
|
||||
]' \
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Step 3: 设置可训练模块</summary>
|
||||
|
||||
训练框架支持训练基础模型,或 LoRA 模型。以下是几个例子:
|
||||
|
||||
* 全量训练 DiT 部分:`--trainable_models dit`
|
||||
* 训练 DiT 部分的 LoRA 模型:`--lora_base_model dit --lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" --lora_rank 32`
|
||||
|
||||
此外,由于训练脚本中加载了多个模块(text encoder、dit、vae),保存模型文件时需要移除前缀,例如在全量训练 DiT 部分或者训练 DiT 部分的 LoRA 模型时,请设置 `--remove_prefix_in_ckpt pipe.dit.`
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Step 4: 启动训练程序</summary>
|
||||
|
||||
我们为每一个模型编写了训练命令,请参考本文档开头的表格。
|
||||
|
||||
</details>
|
||||
@@ -1,24 +0,0 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
|
||||
|
||||
for tea_cache_l1_thresh in [None, 0.2, 0.4, 0.6, 0.8]:
|
||||
image = pipe(
|
||||
prompt=prompt, embedded_guidance=3.5, seed=0,
|
||||
num_inference_steps=50, tea_cache_l1_thresh=tea_cache_l1_thresh
|
||||
)
|
||||
image.save(f"image_{tea_cache_l1_thresh}.png")
|
||||
@@ -1,50 +0,0 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from diffsynth.controlnets.processors import Annotator
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
image = pipe(
|
||||
prompt="portrait of a beautiful Asian girl, long hair, red t-shirt, sunshine, beach",
|
||||
num_inference_steps=50, embedded_guidance=3.5,
|
||||
seed=0
|
||||
)
|
||||
image.save(f"image_1.jpg")
|
||||
|
||||
mask = np.zeros((1024, 1024, 3), dtype=np.uint8)
|
||||
mask[200:400, 400:700] = 255
|
||||
mask = Image.fromarray(mask)
|
||||
mask.save(f"image_mask.jpg")
|
||||
|
||||
inpaint_image = image
|
||||
|
||||
image = pipe(
|
||||
prompt="portrait of a beautiful Asian girl with sunglasses, long hair, red t-shirt, sunshine, beach",
|
||||
num_inference_steps=50, embedded_guidance=3.5,
|
||||
flex_inpaint_image=inpaint_image, flex_inpaint_mask=mask,
|
||||
seed=4
|
||||
)
|
||||
image.save(f"image_2_new.jpg")
|
||||
|
||||
control_image = Annotator("canny")(image)
|
||||
control_image.save("image_control.jpg")
|
||||
|
||||
image = pipe(
|
||||
prompt="portrait of a beautiful Asian girl with sunglasses, long hair, yellow t-shirt, sunshine, beach",
|
||||
num_inference_steps=50, embedded_guidance=3.5,
|
||||
flex_control_image=control_image,
|
||||
seed=4
|
||||
)
|
||||
image.save(f"image_3_new.jpg")
|
||||
@@ -1,54 +0,0 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from PIL import Image
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-Kontext-dev", origin_file_pattern="flux1-kontext-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
image_1 = pipe(
|
||||
prompt="a beautiful Asian long-haired female college student.",
|
||||
embedded_guidance=2.5,
|
||||
seed=1,
|
||||
)
|
||||
image_1.save("image_1.jpg")
|
||||
|
||||
image_2 = pipe(
|
||||
prompt="transform the style to anime style.",
|
||||
kontext_images=image_1,
|
||||
embedded_guidance=2.5,
|
||||
seed=2,
|
||||
)
|
||||
image_2.save("image_2.jpg")
|
||||
|
||||
image_3 = pipe(
|
||||
prompt="let her smile.",
|
||||
kontext_images=image_1,
|
||||
embedded_guidance=2.5,
|
||||
seed=3,
|
||||
)
|
||||
image_3.save("image_3.jpg")
|
||||
|
||||
image_4 = pipe(
|
||||
prompt="let the girl play basketball.",
|
||||
kontext_images=image_1,
|
||||
embedded_guidance=2.5,
|
||||
seed=4,
|
||||
)
|
||||
image_4.save("image_4.jpg")
|
||||
|
||||
image_5 = pipe(
|
||||
prompt="move the girl to a park, let her sit on a chair.",
|
||||
kontext_images=image_1,
|
||||
embedded_guidance=2.5,
|
||||
seed=5,
|
||||
)
|
||||
image_5.save("image_5.jpg")
|
||||
@@ -1,27 +0,0 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-Krea-dev", origin_file_pattern="flux1-krea-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
prompt = "An beautiful woman is riding a bicycle in a park, wearing a red dress"
|
||||
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
|
||||
|
||||
image = pipe(prompt=prompt, seed=0, embedded_guidance=4.5)
|
||||
image.save("flux_krea.jpg")
|
||||
|
||||
image = pipe(
|
||||
prompt=prompt, negative_prompt=negative_prompt,
|
||||
seed=0, cfg_scale=2, num_inference_steps=50,
|
||||
embedded_guidance=4.5
|
||||
)
|
||||
image.save("flux_krea_cfg.jpg")
|
||||
@@ -1,19 +0,0 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
ModelConfig(model_id="DiffSynth-Studio/AttriCtrl-FLUX.1-Dev", origin_file_pattern="models/brightness.safetensors")
|
||||
],
|
||||
)
|
||||
|
||||
for i in [0.1, 0.3, 0.5, 0.7, 0.9]:
|
||||
image = pipe(prompt="a cat on the beach", seed=2, value_controller_inputs=[i])
|
||||
image.save(f"value_control_{i}.jpg")
|
||||
@@ -1,37 +0,0 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
ModelConfig(model_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", origin_file_pattern="diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
image_1 = pipe(
|
||||
prompt="a cat sitting on a chair",
|
||||
height=1024, width=1024,
|
||||
seed=8, rand_device="cuda",
|
||||
)
|
||||
image_1.save("image_1.jpg")
|
||||
|
||||
mask = np.zeros((1024, 1024, 3), dtype=np.uint8)
|
||||
mask[100:350, 350: -300] = 255
|
||||
mask = Image.fromarray(mask)
|
||||
mask.save("mask.jpg")
|
||||
|
||||
image_2 = pipe(
|
||||
prompt="a cat sitting on a chair, wearing sunglasses",
|
||||
controlnet_inputs=[ControlNetInput(image=image_1, inpaint_mask=mask, scale=0.9)],
|
||||
height=1024, width=1024,
|
||||
seed=9, rand_device="cuda",
|
||||
)
|
||||
image_2.save("image_2.jpg")
|
||||
@@ -1,40 +0,0 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
|
||||
from diffsynth.controlnets.processors import Annotator
|
||||
from diffsynth import download_models
|
||||
|
||||
|
||||
|
||||
download_models(["Annotators:Depth"])
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
ModelConfig(model_id="InstantX/FLUX.1-dev-Controlnet-Union-alpha", origin_file_pattern="diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
image_1 = pipe(
|
||||
prompt="a beautiful Asian girl, full body, red dress, summer",
|
||||
height=1024, width=1024,
|
||||
seed=6, rand_device="cuda",
|
||||
)
|
||||
image_1.save("image_1.jpg")
|
||||
|
||||
image_canny = Annotator("canny")(image_1)
|
||||
image_depth = Annotator("depth")(image_1)
|
||||
|
||||
image_2 = pipe(
|
||||
prompt="a beautiful Asian girl, full body, red dress, winter",
|
||||
controlnet_inputs=[
|
||||
ControlNetInput(image=image_canny, scale=0.3, processor_id="canny"),
|
||||
ControlNetInput(image=image_depth, scale=0.3, processor_id="depth"),
|
||||
],
|
||||
height=1024, width=1024,
|
||||
seed=7, rand_device="cuda",
|
||||
)
|
||||
image_2.save("image_2.jpg")
|
||||
@@ -1,33 +0,0 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
ModelConfig(model_id="jasperai/Flux.1-dev-Controlnet-Upscaler", origin_file_pattern="diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
image_1 = pipe(
|
||||
prompt="a photo of a cat, highly detailed",
|
||||
height=768, width=768,
|
||||
seed=0, rand_device="cuda",
|
||||
)
|
||||
image_1.save("image_1.jpg")
|
||||
|
||||
image_1 = image_1.resize((2048, 2048))
|
||||
image_2 = pipe(
|
||||
prompt="a photo of a cat, highly detailed",
|
||||
controlnet_inputs=[ControlNetInput(image=image_1, scale=0.7)],
|
||||
input_image=image_1,
|
||||
denoising_strength=0.99,
|
||||
height=2048, width=2048, tiled=True,
|
||||
seed=1, rand_device="cuda",
|
||||
)
|
||||
image_2.save("image_2.jpg")
|
||||
@@ -1,147 +0,0 @@
|
||||
import random
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from diffsynth import download_customized_models
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False):
|
||||
# Create a blank image for overlays
|
||||
overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
|
||||
|
||||
colors = [
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
]
|
||||
# Generate random colors for each mask
|
||||
if use_random_colors:
|
||||
colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
|
||||
|
||||
# Font settings
|
||||
try:
|
||||
font = ImageFont.truetype("arial", font_size) # Adjust as needed
|
||||
except IOError:
|
||||
font = ImageFont.load_default(font_size)
|
||||
|
||||
# Overlay each mask onto the overlay image
|
||||
for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
|
||||
# Convert mask to RGBA mode
|
||||
mask_rgba = mask.convert('RGBA')
|
||||
mask_data = mask_rgba.getdata()
|
||||
new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
|
||||
mask_rgba.putdata(new_data)
|
||||
|
||||
# Draw the mask prompt text on the mask
|
||||
draw = ImageDraw.Draw(mask_rgba)
|
||||
mask_bbox = mask.getbbox() # Get the bounding box of the mask
|
||||
text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position
|
||||
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
|
||||
|
||||
# Alpha composite the overlay with this mask
|
||||
overlay = Image.alpha_composite(overlay, mask_rgba)
|
||||
|
||||
# Composite the overlay onto the original image
|
||||
result = Image.alpha_composite(image.convert('RGBA'), overlay)
|
||||
|
||||
# Save or display the resulting image
|
||||
result.save(output_path)
|
||||
|
||||
return result
|
||||
|
||||
def example(pipe, seeds, example_id, global_prompt, entity_prompts):
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/entity_control/example_{example_id}/*.png")
|
||||
masks = [Image.open(f"./data/examples/eligen/entity_control/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
|
||||
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
|
||||
for seed in seeds:
|
||||
# generate image
|
||||
image = pipe(
|
||||
prompt=global_prompt,
|
||||
cfg_scale=3.0,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=50,
|
||||
embedded_guidance=3.5,
|
||||
seed=seed,
|
||||
height=1024,
|
||||
width=1024,
|
||||
eligen_entity_prompts=entity_prompts,
|
||||
eligen_entity_masks=masks,
|
||||
)
|
||||
image.save(f"eligen_example_{example_id}_{seed}.png")
|
||||
visualize_masks(image, masks, entity_prompts, f"eligen_example_{example_id}_mask_{seed}.png")
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
download_from_modelscope = True
|
||||
if download_from_modelscope:
|
||||
model_id = "DiffSynth-Studio/Eligen"
|
||||
downloading_priority = ["ModelScope"]
|
||||
else:
|
||||
model_id = "modelscope/EliGen"
|
||||
downloading_priority = ["HuggingFace"]
|
||||
EliGen_path = download_customized_models(
|
||||
model_id=model_id,
|
||||
origin_file_path="model_bf16.safetensors",
|
||||
local_dir="models/lora/entity_control",
|
||||
downloading_priority=downloading_priority)[0]
|
||||
pipe.load_lora(pipe.dit, EliGen_path, alpha=1)
|
||||
|
||||
# example 1
|
||||
global_prompt = "A breathtaking beauty of Raja Ampat by the late-night moonlight , one beautiful woman from behind wearing a pale blue long dress with soft glow, sitting at the top of a cliff looking towards the beach,pastell light colors, a group of small distant birds flying in far sky, a boat sailing on the sea, best quality, realistic, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, maximalist style, photorealistic, concept art, sharp focus, harmony, serenity, tranquility, soft pastell colors,ambient occlusion, cozy ambient lighting, masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning, view from above\n"
|
||||
entity_prompts = ["cliff", "sea", "moon", "sailing boat", "a seated beautiful woman", "pale blue long dress with soft glow"]
|
||||
example(pipe, [0], 1, global_prompt, entity_prompts)
|
||||
|
||||
# example 2
|
||||
global_prompt = "samurai girl wearing a kimono, she's holding a sword glowing with red flame, her long hair is flowing in the wind, she is looking at a small bird perched on the back of her hand. ultra realist style. maximum image detail. maximum realistic render."
|
||||
entity_prompts = ["flowing hair", "sword glowing with red flame", "A cute bird", "blue belt"]
|
||||
example(pipe, [0], 2, global_prompt, entity_prompts)
|
||||
|
||||
# example 3
|
||||
global_prompt = "Image of a neverending staircase up to a mysterious palace in the sky, The ancient palace stood majestically atop a mist-shrouded mountain, sunrise, two traditional monk walk in the stair looking at the sunrise, fog,see-through, best quality, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, photorealistic, concept art, harmony, serenity, tranquility, ambient occlusion, halation, cozy ambient lighting, dynamic lighting,masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning,"
|
||||
entity_prompts = ["ancient palace", "stone staircase with railings", "a traditional monk", "a traditional monk"]
|
||||
example(pipe, [27], 3, global_prompt, entity_prompts)
|
||||
|
||||
# example 4
|
||||
global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'"
|
||||
entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"]
|
||||
example(pipe, [21], 4, global_prompt, entity_prompts)
|
||||
|
||||
# example 5
|
||||
global_prompt = "A captivating, dramatic scene in a painting that exudes mystery and foreboding. A white sky, swirling blue clouds, and a crescent yellow moon illuminate a solitary woman standing near the water's edge. Her long dress flows in the wind, silhouetted against the eerie glow. The water mirrors the fiery sky and moonlight, amplifying the uneasy atmosphere."
|
||||
entity_prompts = ["crescent yellow moon", "a solitary woman", "water", "swirling blue clouds"]
|
||||
example(pipe, [0], 5, global_prompt, entity_prompts)
|
||||
|
||||
# example 6
|
||||
global_prompt = "Snow White and the 6 Dwarfs."
|
||||
entity_prompts = ["Dwarf 1", "Dwarf 2", "Dwarf 3", "Snow White", "Dwarf 4", "Dwarf 5", "Dwarf 6"]
|
||||
example(pipe, [8], 6, global_prompt, entity_prompts)
|
||||
|
||||
# example 7, same prompt with different seeds
|
||||
seeds = range(5, 9)
|
||||
global_prompt = "A beautiful woman wearing white dress, holding a mirror, with a warm light background;"
|
||||
entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"]
|
||||
example(pipe, seeds, 7, global_prompt, entity_prompts)
|
||||
@@ -1,24 +0,0 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin"),
|
||||
ModelConfig(model_id="google/siglip-so400m-patch14-384"),
|
||||
],
|
||||
)
|
||||
|
||||
origin_prompt = "a rabbit in a garden, colorful flowers"
|
||||
image = pipe(prompt=origin_prompt, height=1280, width=960, seed=42)
|
||||
image.save("style image.jpg")
|
||||
|
||||
image = pipe(prompt="A piggy", height=1280, width=960, seed=42,
|
||||
ipadapter_images=[image], ipadapter_scale=0.7)
|
||||
image.save("A piggy.jpg")
|
||||
@@ -1,59 +0,0 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
|
||||
from modelscope import dataset_snapshot_download
|
||||
from modelscope import snapshot_download
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
snapshot_download(
|
||||
"ByteDance/InfiniteYou",
|
||||
allow_file_pattern="supports/insightface/models/antelopev2/*",
|
||||
local_dir="models/ByteDance/InfiniteYou",
|
||||
)
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/image_proj_model.bin"),
|
||||
ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=f"data/examples/infiniteyou/*",
|
||||
)
|
||||
|
||||
height, width = 1024, 1024
|
||||
controlnet_image = Image.fromarray(np.zeros([height, width, 3]).astype(np.uint8))
|
||||
controlnet_inputs = [ControlNetInput(image=controlnet_image, scale=1.0, processor_id="None")]
|
||||
|
||||
prompt = "A man, portrait, cinematic"
|
||||
id_image = "data/examples/infiniteyou/man.jpg"
|
||||
id_image = Image.open(id_image).convert('RGB')
|
||||
image = pipe(
|
||||
prompt=prompt, seed=1,
|
||||
infinityou_id_image=id_image, infinityou_guidance=1.0,
|
||||
controlnet_inputs=controlnet_inputs,
|
||||
num_inference_steps=50, embedded_guidance=3.5,
|
||||
height=height, width=width,
|
||||
)
|
||||
image.save("man.jpg")
|
||||
|
||||
prompt = "A woman, portrait, cinematic"
|
||||
id_image = "data/examples/infiniteyou/woman.jpg"
|
||||
id_image = Image.open(id_image).convert('RGB')
|
||||
image = pipe(
|
||||
prompt=prompt, seed=1,
|
||||
infinityou_id_image=id_image, infinityou_guidance=1.0,
|
||||
controlnet_inputs=controlnet_inputs,
|
||||
num_inference_steps=50, embedded_guidance=3.5,
|
||||
height=height, width=width,
|
||||
)
|
||||
image.save("woman.jpg")
|
||||
@@ -1,40 +0,0 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors"),
|
||||
],
|
||||
)
|
||||
pipe.enable_lora_magic()
|
||||
|
||||
lora = ModelConfig(model_id="VoidOc/flux_animal_forest1", origin_file_pattern="20.safetensors")
|
||||
pipe.load_lora(pipe.dit, lora, hotload=True) # Use `pipe.clear_lora()` to drop the loaded LoRA.
|
||||
|
||||
# Empty prompt can automatically activate LoRA capabilities.
|
||||
image = pipe(prompt="", seed=0, lora_encoder_inputs=lora)
|
||||
image.save("image_1.jpg")
|
||||
|
||||
image = pipe(prompt="", seed=0)
|
||||
image.save("image_1_origin.jpg")
|
||||
|
||||
# Prompt without trigger words can also activate LoRA capabilities.
|
||||
image = pipe(prompt="a car", seed=0, lora_encoder_inputs=lora)
|
||||
image.save("image_2.jpg")
|
||||
|
||||
image = pipe(prompt="a car", seed=0,)
|
||||
image.save("image_2_origin.jpg")
|
||||
|
||||
# Adjust the activation intensity through the scale parameter.
|
||||
image = pipe(prompt="a cat", seed=0, lora_encoder_inputs=lora, lora_encoder_scale=1.0)
|
||||
image.save("image_3.jpg")
|
||||
|
||||
image = pipe(prompt="a cat", seed=0, lora_encoder_inputs=lora, lora_encoder_scale=0.5)
|
||||
image.save("image_3_scale.jpg")
|
||||
@@ -1,29 +0,0 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev", origin_file_pattern="model.safetensors"),
|
||||
],
|
||||
)
|
||||
pipe.enable_lora_magic()
|
||||
|
||||
pipe.load_lora(
|
||||
pipe.dit,
|
||||
ModelConfig(model_id="cancel13/cxsk", origin_file_pattern="30.safetensors"),
|
||||
hotload=True,
|
||||
)
|
||||
pipe.load_lora(
|
||||
pipe.dit,
|
||||
ModelConfig(model_id="DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1", origin_file_pattern="merged_lora.safetensors"),
|
||||
hotload=True,
|
||||
)
|
||||
image = pipe(prompt="a cat", seed=0)
|
||||
image.save("image_fused.jpg")
|
||||
@@ -1,26 +0,0 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
|
||||
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
|
||||
|
||||
image = pipe(prompt=prompt, seed=0)
|
||||
image.save("flux.jpg")
|
||||
|
||||
image = pipe(
|
||||
prompt=prompt, negative_prompt=negative_prompt,
|
||||
seed=0, cfg_scale=2, num_inference_steps=50,
|
||||
)
|
||||
image.save("flux_cfg.jpg")
|
||||
@@ -1,37 +0,0 @@
|
||||
import importlib
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
if importlib.util.find_spec("transformers") is None:
|
||||
raise ImportError("You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.")
|
||||
else:
|
||||
import transformers
|
||||
assert transformers.__version__ == "4.49.0", "Nexus-GenV2 requires transformers==4.49.0, please install it with `pip install transformers==4.49.0`."
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors"),
|
||||
ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
nexus_gen_processor_config=ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor/"),
|
||||
)
|
||||
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/nexusgen/cat.jpg")
|
||||
ref_image = Image.open("data/examples/nexusgen/cat.jpg").convert("RGB")
|
||||
prompt = "Add a crown."
|
||||
image = pipe(
|
||||
prompt=prompt, negative_prompt="",
|
||||
seed=42, cfg_scale=2.0, num_inference_steps=50,
|
||||
nexus_gen_reference_image=ref_image,
|
||||
height=512, width=512,
|
||||
)
|
||||
image.save("cat_crown.jpg")
|
||||
@@ -1,32 +0,0 @@
|
||||
import importlib
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
|
||||
if importlib.util.find_spec("transformers") is None:
|
||||
raise ImportError("You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.")
|
||||
else:
|
||||
import transformers
|
||||
assert transformers.__version__ == "4.49.0", "Nexus-GenV2 requires transformers==4.49.0, please install it with `pip install transformers==4.49.0`."
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors"),
|
||||
ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
nexus_gen_processor_config=ModelConfig("DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor"),
|
||||
)
|
||||
|
||||
prompt = "一只可爱的猫咪"
|
||||
image = pipe(
|
||||
prompt=prompt, negative_prompt="",
|
||||
seed=0, cfg_scale=3, num_inference_steps=50,
|
||||
height=1024, width=1024,
|
||||
)
|
||||
image.save("cat.jpg")
|
||||
@@ -1,32 +0,0 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen2.5-VL-7B-Instruct"),
|
||||
ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
|
||||
ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="vae.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
image = Image.fromarray(np.zeros((1248, 832, 3), dtype=np.uint8) + 255)
|
||||
image = pipe(
|
||||
prompt="draw red flowers in Chinese ink painting style",
|
||||
step1x_reference_image=image,
|
||||
width=832, height=1248, cfg_scale=6,
|
||||
seed=1, rand_device='cuda'
|
||||
)
|
||||
image.save("image_1.jpg")
|
||||
|
||||
image = pipe(
|
||||
prompt="add more flowers in Chinese ink painting style",
|
||||
step1x_reference_image=image,
|
||||
width=832, height=1248, cfg_scale=6,
|
||||
seed=2, rand_device='cuda'
|
||||
)
|
||||
image.save("image_2.jpg")
|
||||
@@ -1,51 +0,0 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from diffsynth.controlnets.processors import Annotator
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
image = pipe(
|
||||
prompt="portrait of a beautiful Asian girl, long hair, red t-shirt, sunshine, beach",
|
||||
num_inference_steps=50, embedded_guidance=3.5,
|
||||
seed=0
|
||||
)
|
||||
image.save(f"image_1.jpg")
|
||||
|
||||
mask = np.zeros((1024, 1024, 3), dtype=np.uint8)
|
||||
mask[200:400, 400:700] = 255
|
||||
mask = Image.fromarray(mask)
|
||||
mask.save(f"image_mask.jpg")
|
||||
|
||||
inpaint_image = image
|
||||
|
||||
image = pipe(
|
||||
prompt="portrait of a beautiful Asian girl with sunglasses, long hair, red t-shirt, sunshine, beach",
|
||||
num_inference_steps=50, embedded_guidance=3.5,
|
||||
flex_inpaint_image=inpaint_image, flex_inpaint_mask=mask,
|
||||
seed=4
|
||||
)
|
||||
image.save(f"image_2_new.jpg")
|
||||
|
||||
control_image = Annotator("canny")(image)
|
||||
control_image.save("image_control.jpg")
|
||||
|
||||
image = pipe(
|
||||
prompt="portrait of a beautiful Asian girl with sunglasses, long hair, yellow t-shirt, sunshine, beach",
|
||||
num_inference_steps=50, embedded_guidance=3.5,
|
||||
flex_control_image=control_image,
|
||||
seed=4
|
||||
)
|
||||
image.save(f"image_3_new.jpg")
|
||||
@@ -1,55 +0,0 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from PIL import Image
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-Kontext-dev", origin_file_pattern="flux1-kontext-dev.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
image_1 = pipe(
|
||||
prompt="a beautiful Asian long-haired female college student.",
|
||||
embedded_guidance=2.5,
|
||||
seed=1,
|
||||
)
|
||||
image_1.save("image_1.jpg")
|
||||
|
||||
image_2 = pipe(
|
||||
prompt="transform the style to anime style.",
|
||||
kontext_images=image_1,
|
||||
embedded_guidance=2.5,
|
||||
seed=2,
|
||||
)
|
||||
image_2.save("image_2.jpg")
|
||||
|
||||
image_3 = pipe(
|
||||
prompt="let her smile.",
|
||||
kontext_images=image_1,
|
||||
embedded_guidance=2.5,
|
||||
seed=3,
|
||||
)
|
||||
image_3.save("image_3.jpg")
|
||||
|
||||
image_4 = pipe(
|
||||
prompt="let the girl play basketball.",
|
||||
kontext_images=image_1,
|
||||
embedded_guidance=2.5,
|
||||
seed=4,
|
||||
)
|
||||
image_4.save("image_4.jpg")
|
||||
|
||||
image_5 = pipe(
|
||||
prompt="move the girl to a park, let her sit on a chair.",
|
||||
kontext_images=image_1,
|
||||
embedded_guidance=2.5,
|
||||
seed=5,
|
||||
)
|
||||
image_5.save("image_5.jpg")
|
||||
@@ -1,28 +0,0 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-Krea-dev", origin_file_pattern="flux1-krea-dev.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
prompt = "An beautiful woman is riding a bicycle in a park, wearing a red dress"
|
||||
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
|
||||
|
||||
image = pipe(prompt=prompt, seed=0, embedded_guidance=4.5)
|
||||
image.save("flux_krea.jpg")
|
||||
|
||||
image = pipe(
|
||||
prompt=prompt, negative_prompt=negative_prompt,
|
||||
seed=0, cfg_scale=2, num_inference_steps=50,
|
||||
embedded_guidance=4.5
|
||||
)
|
||||
image.save("flux_krea_cfg.jpg")
|
||||
@@ -1,20 +0,0 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="DiffSynth-Studio/AttriCtrl-FLUX.1-Dev", origin_file_pattern="models/brightness.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn)
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
for i in [0.1, 0.3, 0.5, 0.7, 0.9]:
|
||||
image = pipe(prompt="a cat on the beach", seed=2, value_controller_inputs=[i])
|
||||
image.save(f"value_control_{i}.jpg")
|
||||
@@ -1,38 +0,0 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", origin_file_pattern="diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
image_1 = pipe(
|
||||
prompt="a cat sitting on a chair",
|
||||
height=1024, width=1024,
|
||||
seed=8, rand_device="cuda",
|
||||
)
|
||||
image_1.save("image_1.jpg")
|
||||
|
||||
mask = np.zeros((1024, 1024, 3), dtype=np.uint8)
|
||||
mask[100:350, 350: -300] = 255
|
||||
mask = Image.fromarray(mask)
|
||||
mask.save("mask.jpg")
|
||||
|
||||
image_2 = pipe(
|
||||
prompt="a cat sitting on a chair, wearing sunglasses",
|
||||
controlnet_inputs=[ControlNetInput(image=image_1, inpaint_mask=mask, scale=0.9)],
|
||||
height=1024, width=1024,
|
||||
seed=9, rand_device="cuda",
|
||||
)
|
||||
image_2.save("image_2.jpg")
|
||||
@@ -1,41 +0,0 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
|
||||
from diffsynth.controlnets.processors import Annotator
|
||||
from diffsynth import download_models
|
||||
|
||||
|
||||
|
||||
download_models(["Annotators:Depth"])
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="InstantX/FLUX.1-dev-Controlnet-Union-alpha", origin_file_pattern="diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
image_1 = pipe(
|
||||
prompt="a beautiful Asian girl, full body, red dress, summer",
|
||||
height=1024, width=1024,
|
||||
seed=6, rand_device="cuda",
|
||||
)
|
||||
image_1.save("image_1.jpg")
|
||||
|
||||
image_canny = Annotator("canny")(image_1)
|
||||
image_depth = Annotator("depth")(image_1)
|
||||
|
||||
image_2 = pipe(
|
||||
prompt="a beautiful Asian girl, full body, red dress, winter",
|
||||
controlnet_inputs=[
|
||||
ControlNetInput(image=image_canny, scale=0.3, processor_id="canny"),
|
||||
ControlNetInput(image=image_depth, scale=0.3, processor_id="depth"),
|
||||
],
|
||||
height=1024, width=1024,
|
||||
seed=7, rand_device="cuda",
|
||||
)
|
||||
image_2.save("image_2.jpg")
|
||||
@@ -1,34 +0,0 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="jasperai/Flux.1-dev-Controlnet-Upscaler", origin_file_pattern="diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
image_1 = pipe(
|
||||
prompt="a photo of a cat, highly detailed",
|
||||
height=768, width=768,
|
||||
seed=0, rand_device="cuda",
|
||||
)
|
||||
image_1.save("image_1.jpg")
|
||||
|
||||
image_1 = image_1.resize((2048, 2048))
|
||||
image_2 = pipe(
|
||||
prompt="a photo of a cat, highly detailed",
|
||||
controlnet_inputs=[ControlNetInput(image=image_1, scale=0.7)],
|
||||
input_image=image_1,
|
||||
denoising_strength=0.99,
|
||||
height=2048, width=2048, tiled=True,
|
||||
seed=1, rand_device="cuda",
|
||||
)
|
||||
image_2.save("image_2.jpg")
|
||||
@@ -1,148 +0,0 @@
|
||||
import random
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from diffsynth import download_customized_models
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False):
|
||||
# Create a blank image for overlays
|
||||
overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
|
||||
|
||||
colors = [
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
]
|
||||
# Generate random colors for each mask
|
||||
if use_random_colors:
|
||||
colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
|
||||
|
||||
# Font settings
|
||||
try:
|
||||
font = ImageFont.truetype("arial", font_size) # Adjust as needed
|
||||
except IOError:
|
||||
font = ImageFont.load_default(font_size)
|
||||
|
||||
# Overlay each mask onto the overlay image
|
||||
for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
|
||||
# Convert mask to RGBA mode
|
||||
mask_rgba = mask.convert('RGBA')
|
||||
mask_data = mask_rgba.getdata()
|
||||
new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
|
||||
mask_rgba.putdata(new_data)
|
||||
|
||||
# Draw the mask prompt text on the mask
|
||||
draw = ImageDraw.Draw(mask_rgba)
|
||||
mask_bbox = mask.getbbox() # Get the bounding box of the mask
|
||||
text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position
|
||||
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
|
||||
|
||||
# Alpha composite the overlay with this mask
|
||||
overlay = Image.alpha_composite(overlay, mask_rgba)
|
||||
|
||||
# Composite the overlay onto the original image
|
||||
result = Image.alpha_composite(image.convert('RGBA'), overlay)
|
||||
|
||||
# Save or display the resulting image
|
||||
result.save(output_path)
|
||||
|
||||
return result
|
||||
|
||||
def example(pipe, seeds, example_id, global_prompt, entity_prompts):
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/entity_control/example_{example_id}/*.png")
|
||||
masks = [Image.open(f"./data/examples/eligen/entity_control/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
|
||||
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
|
||||
for seed in seeds:
|
||||
# generate image
|
||||
image = pipe(
|
||||
prompt=global_prompt,
|
||||
cfg_scale=3.0,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=50,
|
||||
embedded_guidance=3.5,
|
||||
seed=seed,
|
||||
height=1024,
|
||||
width=1024,
|
||||
eligen_entity_prompts=entity_prompts,
|
||||
eligen_entity_masks=masks,
|
||||
)
|
||||
image.save(f"eligen_example_{example_id}_{seed}.png")
|
||||
visualize_masks(image, masks, entity_prompts, f"eligen_example_{example_id}_mask_{seed}.png")
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
download_from_modelscope = True
|
||||
if download_from_modelscope:
|
||||
model_id = "DiffSynth-Studio/Eligen"
|
||||
downloading_priority = ["ModelScope"]
|
||||
else:
|
||||
model_id = "modelscope/EliGen"
|
||||
downloading_priority = ["HuggingFace"]
|
||||
EliGen_path = download_customized_models(
|
||||
model_id=model_id,
|
||||
origin_file_path="model_bf16.safetensors",
|
||||
local_dir="models/lora/entity_control",
|
||||
downloading_priority=downloading_priority)[0]
|
||||
pipe.load_lora(pipe.dit, EliGen_path, alpha=1)
|
||||
|
||||
# example 1
|
||||
global_prompt = "A breathtaking beauty of Raja Ampat by the late-night moonlight , one beautiful woman from behind wearing a pale blue long dress with soft glow, sitting at the top of a cliff looking towards the beach,pastell light colors, a group of small distant birds flying in far sky, a boat sailing on the sea, best quality, realistic, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, maximalist style, photorealistic, concept art, sharp focus, harmony, serenity, tranquility, soft pastell colors,ambient occlusion, cozy ambient lighting, masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning, view from above\n"
|
||||
entity_prompts = ["cliff", "sea", "moon", "sailing boat", "a seated beautiful woman", "pale blue long dress with soft glow"]
|
||||
example(pipe, [0], 1, global_prompt, entity_prompts)
|
||||
|
||||
# example 2
|
||||
global_prompt = "samurai girl wearing a kimono, she's holding a sword glowing with red flame, her long hair is flowing in the wind, she is looking at a small bird perched on the back of her hand. ultra realist style. maximum image detail. maximum realistic render."
|
||||
entity_prompts = ["flowing hair", "sword glowing with red flame", "A cute bird", "blue belt"]
|
||||
example(pipe, [0], 2, global_prompt, entity_prompts)
|
||||
|
||||
# example 3
|
||||
global_prompt = "Image of a neverending staircase up to a mysterious palace in the sky, The ancient palace stood majestically atop a mist-shrouded mountain, sunrise, two traditional monk walk in the stair looking at the sunrise, fog,see-through, best quality, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, photorealistic, concept art, harmony, serenity, tranquility, ambient occlusion, halation, cozy ambient lighting, dynamic lighting,masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning,"
|
||||
entity_prompts = ["ancient palace", "stone staircase with railings", "a traditional monk", "a traditional monk"]
|
||||
example(pipe, [27], 3, global_prompt, entity_prompts)
|
||||
|
||||
# example 4
|
||||
global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'"
|
||||
entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"]
|
||||
example(pipe, [21], 4, global_prompt, entity_prompts)
|
||||
|
||||
# example 5
|
||||
global_prompt = "A captivating, dramatic scene in a painting that exudes mystery and foreboding. A white sky, swirling blue clouds, and a crescent yellow moon illuminate a solitary woman standing near the water's edge. Her long dress flows in the wind, silhouetted against the eerie glow. The water mirrors the fiery sky and moonlight, amplifying the uneasy atmosphere."
|
||||
entity_prompts = ["crescent yellow moon", "a solitary woman", "water", "swirling blue clouds"]
|
||||
example(pipe, [0], 5, global_prompt, entity_prompts)
|
||||
|
||||
# example 6
|
||||
global_prompt = "Snow White and the 6 Dwarfs."
|
||||
entity_prompts = ["Dwarf 1", "Dwarf 2", "Dwarf 3", "Snow White", "Dwarf 4", "Dwarf 5", "Dwarf 6"]
|
||||
example(pipe, [8], 6, global_prompt, entity_prompts)
|
||||
|
||||
# example 7, same prompt with different seeds
|
||||
seeds = range(5, 9)
|
||||
global_prompt = "A beautiful woman wearing white dress, holding a mirror, with a warm light background;"
|
||||
entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"]
|
||||
example(pipe, seeds, 7, global_prompt, entity_prompts)
|
||||
@@ -1,25 +0,0 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="google/siglip-so400m-patch14-384", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
origin_prompt = "a rabbit in a garden, colorful flowers"
|
||||
image = pipe(prompt=origin_prompt, height=1280, width=960, seed=42)
|
||||
image.save("style image.jpg")
|
||||
|
||||
image = pipe(prompt="A piggy", height=1280, width=960, seed=42,
|
||||
ipadapter_images=[image], ipadapter_scale=0.7)
|
||||
image.save("A piggy.jpg")
|
||||
@@ -1,60 +0,0 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
|
||||
from modelscope import dataset_snapshot_download
|
||||
from modelscope import snapshot_download
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
snapshot_download(
|
||||
"ByteDance/InfiniteYou",
|
||||
allow_file_pattern="supports/insightface/models/antelopev2/*",
|
||||
local_dir="models/ByteDance/InfiniteYou",
|
||||
)
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/image_proj_model.bin", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=f"data/examples/infiniteyou/*",
|
||||
)
|
||||
|
||||
height, width = 1024, 1024
|
||||
controlnet_image = Image.fromarray(np.zeros([height, width, 3]).astype(np.uint8))
|
||||
controlnet_inputs = [ControlNetInput(image=controlnet_image, scale=1.0, processor_id="None")]
|
||||
|
||||
prompt = "A man, portrait, cinematic"
|
||||
id_image = "data/examples/infiniteyou/man.jpg"
|
||||
id_image = Image.open(id_image).convert('RGB')
|
||||
image = pipe(
|
||||
prompt=prompt, seed=1,
|
||||
infinityou_id_image=id_image, infinityou_guidance=1.0,
|
||||
controlnet_inputs=controlnet_inputs,
|
||||
num_inference_steps=50, embedded_guidance=3.5,
|
||||
height=height, width=width,
|
||||
)
|
||||
image.save("man.jpg")
|
||||
|
||||
prompt = "A woman, portrait, cinematic"
|
||||
id_image = "data/examples/infiniteyou/woman.jpg"
|
||||
id_image = Image.open(id_image).convert('RGB')
|
||||
image = pipe(
|
||||
prompt=prompt, seed=1,
|
||||
infinityou_id_image=id_image, infinityou_guidance=1.0,
|
||||
controlnet_inputs=controlnet_inputs,
|
||||
num_inference_steps=50, embedded_guidance=3.5,
|
||||
height=height, width=width,
|
||||
)
|
||||
image.save("woman.jpg")
|
||||
@@ -1,41 +0,0 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
pipe.enable_lora_magic()
|
||||
|
||||
lora = ModelConfig(model_id="VoidOc/flux_animal_forest1", origin_file_pattern="20.safetensors")
|
||||
pipe.load_lora(pipe.dit, lora, hotload=True) # Use `pipe.clear_lora()` to drop the loaded LoRA.
|
||||
|
||||
# Empty prompt can automatically activate LoRA capabilities.
|
||||
image = pipe(prompt="", seed=0, lora_encoder_inputs=lora)
|
||||
image.save("image_1.jpg")
|
||||
|
||||
image = pipe(prompt="", seed=0)
|
||||
image.save("image_1_origin.jpg")
|
||||
|
||||
# Prompt without trigger words can also activate LoRA capabilities.
|
||||
image = pipe(prompt="a car", seed=0, lora_encoder_inputs=lora)
|
||||
image.save("image_2.jpg")
|
||||
|
||||
image = pipe(prompt="a car", seed=0,)
|
||||
image.save("image_2_origin.jpg")
|
||||
|
||||
# Adjust the activation intensity through the scale parameter.
|
||||
image = pipe(prompt="a cat", seed=0, lora_encoder_inputs=lora, lora_encoder_scale=1.0)
|
||||
image.save("image_3.jpg")
|
||||
|
||||
image = pipe(prompt="a cat", seed=0, lora_encoder_inputs=lora, lora_encoder_scale=0.5)
|
||||
image.save("image_3_scale.jpg")
|
||||
@@ -1,35 +0,0 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="DiffSynth-Studio/FLUX.1-dev-LoRAFusion", origin_file_pattern="model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn)
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
pipe.enable_lora_patcher()
|
||||
pipe.load_lora(
|
||||
pipe.dit,
|
||||
ModelConfig(model_id="yangyufeng/fgao", origin_file_pattern="30.safetensors"),
|
||||
hotload=True
|
||||
)
|
||||
pipe.load_lora(
|
||||
pipe.dit,
|
||||
ModelConfig(model_id="bobooblue/LoRA-bling-mai", origin_file_pattern="10.safetensors"),
|
||||
hotload=True
|
||||
)
|
||||
pipe.load_lora(
|
||||
pipe.dit,
|
||||
ModelConfig(model_id="JIETANGAB/E", origin_file_pattern="17.safetensors"),
|
||||
hotload=True
|
||||
)
|
||||
|
||||
image = pipe(prompt="This is a digital painting in a soft, ethereal style. a beautiful Asian girl Shine like a diamond. Everywhere is shining with bling bling luster.The background is a textured blue with visible brushstrokes, giving the image an impressionistic style reminiscent of Vincent van Gogh's work", seed=0)
|
||||
image.save("flux.jpg")
|
||||
@@ -1,27 +0,0 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
|
||||
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
|
||||
|
||||
image = pipe(prompt=prompt, seed=0)
|
||||
image.save("flux.jpg")
|
||||
|
||||
image = pipe(
|
||||
prompt=prompt, negative_prompt=negative_prompt,
|
||||
seed=0, cfg_scale=2, num_inference_steps=50,
|
||||
)
|
||||
image.save("flux_cfg.jpg")
|
||||
@@ -1,38 +0,0 @@
|
||||
import importlib
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
if importlib.util.find_spec("transformers") is None:
|
||||
raise ImportError("You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.")
|
||||
else:
|
||||
import transformers
|
||||
assert transformers.__version__ == "4.49.0", "Nexus-GenV2 requires transformers==4.49.0, please install it with `pip install transformers==4.49.0`."
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin", offload_device="cpu"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu"),
|
||||
],
|
||||
nexus_gen_processor_config=ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor/"),
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/nexusgen/cat.jpg")
|
||||
ref_image = Image.open("data/examples/nexusgen/cat.jpg").convert("RGB")
|
||||
prompt = "Add a crown."
|
||||
image = pipe(
|
||||
prompt=prompt, negative_prompt="",
|
||||
seed=42, cfg_scale=2.0, num_inference_steps=50,
|
||||
nexus_gen_reference_image=ref_image,
|
||||
height=512, width=512,
|
||||
)
|
||||
image.save("cat_crown.jpg")
|
||||
@@ -1,33 +0,0 @@
|
||||
import importlib
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
|
||||
|
||||
if importlib.util.find_spec("transformers") is None:
|
||||
raise ImportError("You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.")
|
||||
else:
|
||||
import transformers
|
||||
assert transformers.__version__ == "4.49.0", "Nexus-GenV2 requires transformers==4.49.0, please install it with `pip install transformers==4.49.0`."
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin", offload_device="cpu"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu"),
|
||||
],
|
||||
nexus_gen_processor_config=ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor/"),
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
prompt = "一只可爱的猫咪"
|
||||
image = pipe(
|
||||
prompt=prompt, negative_prompt="",
|
||||
seed=0, cfg_scale=3, num_inference_steps=50,
|
||||
height=1024, width=1024,
|
||||
)
|
||||
image.save("cat.jpg")
|
||||
@@ -1,33 +0,0 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen2.5-VL-7B-Instruct", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="vae.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
image = Image.fromarray(np.zeros((1248, 832, 3), dtype=np.uint8) + 255)
|
||||
image = pipe(
|
||||
prompt="draw red flowers in Chinese ink painting style",
|
||||
step1x_reference_image=image,
|
||||
width=832, height=1248, cfg_scale=6,
|
||||
seed=1, rand_device='cuda'
|
||||
)
|
||||
image.save("image_1.jpg")
|
||||
|
||||
image = pipe(
|
||||
prompt="add more flowers in Chinese ink painting style",
|
||||
step1x_reference_image=image,
|
||||
width=832, height=1248, cfg_scale=6,
|
||||
seed=2, rand_device='cuda'
|
||||
)
|
||||
image.save("image_2.jpg")
|
||||
@@ -1,12 +0,0 @@
|
||||
accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 200 \
|
||||
--model_id_with_origin_paths "ostris/Flex.2-preview:Flex.2-preview.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 1 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLEX.2-preview_full" \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing
|
||||
@@ -1,14 +0,0 @@
|
||||
accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_kontext.csv \
|
||||
--data_file_keys "image,kontext_images" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 400 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-Kontext-dev:flux1-kontext-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 1 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.1-Kontext-dev_full" \
|
||||
--trainable_models "dit" \
|
||||
--extra_inputs "kontext_images" \
|
||||
--use_gradient_checkpointing
|
||||
@@ -1,12 +0,0 @@
|
||||
accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 400 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-Krea-dev:flux1-krea-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 1 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.1-Krea-dev_full" \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing
|
||||
@@ -1,14 +0,0 @@
|
||||
accelerate launch examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_attrictrl.csv \
|
||||
--data_file_keys "image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors,DiffSynth-Studio/AttriCtrl-FLUX.1-Dev:models/brightness.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 1 \
|
||||
--remove_prefix_in_ckpt "pipe.value_controller.encoders.0." \
|
||||
--output_path "./models/train/FLUX.1-dev-AttriCtrl_full" \
|
||||
--trainable_models "value_controller" \
|
||||
--extra_inputs "value_controller_inputs" \
|
||||
--use_gradient_checkpointing
|
||||
@@ -1,14 +0,0 @@
|
||||
accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_controlnet_inpaint.csv \
|
||||
--data_file_keys "image,controlnet_image,controlnet_inpaint_mask" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 400 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors,alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta:diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 1 \
|
||||
--remove_prefix_in_ckpt "pipe.controlnet.models.0." \
|
||||
--output_path "./models/train/FLUX.1-dev-Controlnet-Inpainting-Beta_full" \
|
||||
--trainable_models "controlnet" \
|
||||
--extra_inputs "controlnet_image,controlnet_inpaint_mask" \
|
||||
--use_gradient_checkpointing
|
||||
@@ -1,14 +0,0 @@
|
||||
accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_controlnet_canny.csv \
|
||||
--data_file_keys "image,controlnet_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 400 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors,InstantX/FLUX.1-dev-Controlnet-Union-alpha:diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 1 \
|
||||
--remove_prefix_in_ckpt "pipe.controlnet.models.0." \
|
||||
--output_path "./models/train/FLUX.1-dev-Controlnet-Union-alpha_full" \
|
||||
--trainable_models "controlnet" \
|
||||
--extra_inputs "controlnet_image,controlnet_processor_id" \
|
||||
--use_gradient_checkpointing
|
||||
@@ -1,14 +0,0 @@
|
||||
accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_controlnet_upscale.csv \
|
||||
--data_file_keys "image,controlnet_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 400 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors,jasperai/Flux.1-dev-Controlnet-Upscaler:diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 1 \
|
||||
--remove_prefix_in_ckpt "pipe.controlnet.models.0." \
|
||||
--output_path "./models/train/FLUX.1-dev-Controlnet-Upscaler_full" \
|
||||
--trainable_models "controlnet" \
|
||||
--extra_inputs "controlnet_image" \
|
||||
--use_gradient_checkpointing
|
||||
@@ -1,14 +0,0 @@
|
||||
accelerate launch examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_ipadapter.csv \
|
||||
--data_file_keys "image,ipadapter_images" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors,InstantX/FLUX.1-dev-IP-Adapter:ip-adapter.bin,google/siglip-so400m-patch14-384:" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 1 \
|
||||
--remove_prefix_in_ckpt "pipe.ipadapter." \
|
||||
--output_path "./models/train/FLUX.1-dev-IP-Adapter_full" \
|
||||
--trainable_models "ipadapter" \
|
||||
--extra_inputs "ipadapter_images" \
|
||||
--use_gradient_checkpointing
|
||||
@@ -1,14 +0,0 @@
|
||||
accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_infiniteyou.csv \
|
||||
--data_file_keys "image,controlnet_image,infinityou_id_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 400 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors,ByteDance/InfiniteYou:infu_flux_v1.0/aes_stage2/image_proj_model.bin,ByteDance/InfiniteYou:infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 1 \
|
||||
--remove_prefix_in_ckpt "pipe." \
|
||||
--output_path "./models/train/FLUX.1-dev-InfiniteYou_full" \
|
||||
--trainable_models "controlnet,image_proj_model" \
|
||||
--extra_inputs "controlnet_image,infinityou_id_image,infinityou_guidance" \
|
||||
--use_gradient_checkpointing
|
||||
@@ -1,14 +0,0 @@
|
||||
accelerate launch examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_lora_encoder.csv \
|
||||
--data_file_keys "image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors,DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev:model.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 1 \
|
||||
--remove_prefix_in_ckpt "pipe.lora_encoder." \
|
||||
--output_path "./models/train/FLUX.1-dev-LoRA-Encoder_full" \
|
||||
--trainable_models "lora_encoder" \
|
||||
--extra_inputs "lora_encoder_inputs" \
|
||||
--use_gradient_checkpointing
|
||||
@@ -1,12 +0,0 @@
|
||||
accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 400 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 1 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.1-dev_full" \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing
|
||||
@@ -1,14 +0,0 @@
|
||||
accelerate launch --config_file examples/flux/model_training/full/accelerate_config_zero2offload.yaml examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_nexusgen_edit.csv \
|
||||
--data_file_keys "image,nexus_gen_reference_image" \
|
||||
--max_pixels 262144 \
|
||||
--dataset_repeat 400 \
|
||||
--model_id_with_origin_paths "DiffSynth-Studio/Nexus-GenV2:model*.safetensors,DiffSynth-Studio/Nexus-GenV2:edit_decoder.bin,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 1 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.1-NexusGen-Edit_full" \
|
||||
--trainable_models "dit" \
|
||||
--extra_inputs "nexus_gen_reference_image" \
|
||||
--use_gradient_checkpointing_offload
|
||||
@@ -1,14 +0,0 @@
|
||||
accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_step1x.csv \
|
||||
--data_file_keys "image,step1x_reference_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 400 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen2.5-VL-7B-Instruct:,stepfun-ai/Step1X-Edit:step1x-edit-i1258.safetensors,stepfun-ai/Step1X-Edit:vae.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 1 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Step1X-Edit_full" \
|
||||
--trainable_models "dit" \
|
||||
--extra_inputs "step1x_reference_image" \
|
||||
--use_gradient_checkpointing_offload
|
||||
@@ -1,22 +0,0 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
deepspeed_config:
|
||||
gradient_accumulation_steps: 1
|
||||
offload_optimizer_device: none
|
||||
offload_param_device: none
|
||||
zero3_init_flag: false
|
||||
zero_stage: 2
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
enable_cpu_affinity: false
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
@@ -1,22 +0,0 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
deepspeed_config:
|
||||
gradient_accumulation_steps: 1
|
||||
offload_optimizer_device: 'cpu'
|
||||
offload_param_device: 'cpu'
|
||||
zero3_init_flag: false
|
||||
zero_stage: 2
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
enable_cpu_affinity: false
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
@@ -1,15 +0,0 @@
|
||||
accelerate launch examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "ostris/Flex.2-preview:Flex.2-preview.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLEX.2-preview_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \
|
||||
--lora_rank 32 \
|
||||
--align_to_opensource_format \
|
||||
--use_gradient_checkpointing
|
||||
@@ -1,17 +0,0 @@
|
||||
accelerate launch examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_kontext.csv \
|
||||
--data_file_keys "image,kontext_images" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 400 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-Kontext-dev:flux1-kontext-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.1-Kontext-dev_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \
|
||||
--lora_rank 32 \
|
||||
--align_to_opensource_format \
|
||||
--extra_inputs "kontext_images" \
|
||||
--use_gradient_checkpointing
|
||||
@@ -1,15 +0,0 @@
|
||||
accelerate launch examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-Krea-dev:flux1-krea-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.1-Krea-dev_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \
|
||||
--lora_rank 32 \
|
||||
--align_to_opensource_format \
|
||||
--use_gradient_checkpointing
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user