mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-24 23:26:15 +00:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d5934719f8 | ||
|
|
54345f8678 | ||
|
|
2d7d5137ea | ||
|
|
3799bdc23a | ||
|
|
5cdab9ed01 | ||
|
|
a8a0f082bb | ||
|
|
9453700a30 | ||
|
|
82e482286c |
125
README.md
125
README.md
@@ -34,6 +34,8 @@ We believe that a well-developed open-source code framework can lower the thresh
|
|||||||
|
|
||||||
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher) and [mi804](https://github.com/mi804). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
|
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher) and [mi804](https://github.com/mi804). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
|
||||||
|
|
||||||
|
- **April 24, 2026** We add support for Stable Diffusion v1.5 and SDXL, including inference, low VRAM inference, and training capabilities. For details, please refer to the [documentation](/docs/en/Model_Details/Stable-Diffusion.md), [documentation](/docs/en/Model_Details/Stable-Diffusion-XL.md) and [example code](/examples/stable_diffusion/).
|
||||||
|
|
||||||
- **April 14, 2026** JoyAI-Image open-sourced, welcome a new member to the image editing model family! Support includes instruction-guided image editing, low VRAM inference, and training capabilities. For details, please refer to the [documentation](/docs/en/Model_Details/JoyAI-Image.md) and [example code](/examples/joyai_image/).
|
- **April 14, 2026** JoyAI-Image open-sourced, welcome a new member to the image editing model family! Support includes instruction-guided image editing, low VRAM inference, and training capabilities. For details, please refer to the [documentation](/docs/en/Model_Details/JoyAI-Image.md) and [example code](/examples/joyai_image/).
|
||||||
|
|
||||||
- **March 19, 2026**: Added support for [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) and [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/Wan.md) and [example code](/examples/mova/) are now available.
|
- **March 19, 2026**: Added support for [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) and [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/Wan.md) and [example code](/examples/mova/) are now available.
|
||||||
@@ -299,6 +301,129 @@ Example code for Z-Image is available at: [/examples/z_image/](/examples/z_image
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
#### Stable Diffusion: [/docs/en/Model_Details/Stable-Diffusion.md](/docs/en/Model_Details/Stable-Diffusion.md)
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>Quick Start</summary>
|
||||||
|
|
||||||
|
Running the following code will quickly load the [AI-ModelScope/stable-diffusion-v1-5](https://www.modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5) model for inference. VRAM management is enabled, the framework automatically controls parameter loading based on available VRAM, requiring a minimum of 2GB VRAM.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from diffsynth.core import ModelConfig
|
||||||
|
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.float32,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.float32,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float32,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.float32,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt="a photo of an astronaut riding a horse on mars, high quality, detailed",
|
||||||
|
negative_prompt="blurry, low quality, deformed",
|
||||||
|
cfg_scale=7.5,
|
||||||
|
height=512,
|
||||||
|
width=512,
|
||||||
|
seed=42,
|
||||||
|
rand_device="cuda",
|
||||||
|
num_inference_steps=50,
|
||||||
|
)
|
||||||
|
image.save("image.jpg")
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>Examples</summary>
|
||||||
|
|
||||||
|
Example code for Stable Diffusion is available at: [/examples/stable_diffusion/](/examples/stable_diffusion/)
|
||||||
|
|
||||||
|
|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|
||||||
|
|-|-|-|-|-|-|-|
|
||||||
|
|[AI-ModelScope/stable-diffusion-v1-5](https://www.modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5)|[code](/examples/stable_diffusion/model_inference/stable-diffusion-v1-5.py)|[code](/examples/stable_diffusion/model_inference_low_vram/stable-diffusion-v1-5.py)|[code](/examples/stable_diffusion/model_training/full/stable-diffusion-v1-5.sh)|[code](/examples/stable_diffusion/model_training/validate_full/stable-diffusion-v1-5.py)|[code](/examples/stable_diffusion/model_training/lora/stable-diffusion-v1-5.sh)|[code](/examples/stable_diffusion/model_training/validate_lora/stable-diffusion-v1-5.py)|
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
#### Stable Diffusion XL: [/docs/en/Model_Details/Stable-Diffusion-XL.md](/docs/en/Model_Details/Stable-Diffusion-XL.md)
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>Quick Start</summary>
|
||||||
|
|
||||||
|
Running the following code will quickly load the [stabilityai/stable-diffusion-xl-base-1.0](https://www.modelscope.cn/models/stabilityai/stable-diffusion-xl-base-1.0) model for inference. VRAM management is enabled, the framework automatically controls parameter loading based on available VRAM, requiring a minimum of 6GB VRAM.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from diffsynth.core import ModelConfig
|
||||||
|
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.float32,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.float32,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float32,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.float32,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
|
||||||
|
tokenizer_2_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt="a photo of an astronaut riding a horse on mars",
|
||||||
|
negative_prompt="",
|
||||||
|
cfg_scale=5.0,
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
seed=42,
|
||||||
|
num_inference_steps=50,
|
||||||
|
)
|
||||||
|
image.save("image.jpg")
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>Examples</summary>
|
||||||
|
|
||||||
|
Example code for Stable Diffusion XL is available at: [/examples/stable_diffusion_xl/](/examples/stable_diffusion_xl/)
|
||||||
|
|
||||||
|
|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|
||||||
|
|-|-|-|-|-|-|-|
|
||||||
|
|[stabilityai/stable-diffusion-xl-base-1.0](https://www.modelscope.cn/models/stabilityai/stable-diffusion-xl-base-1.0)|[code](/examples/stable_diffusion_xl/model_inference/stable-diffusion-xl-base-1.0.py)|[code](/examples/stable_diffusion_xl/model_inference_low_vram/stable-diffusion-xl-base-1.0.py)|[code](/examples/stable_diffusion_xl/model_training/full/stable-diffusion-xl-base-1.0.sh)|[code](/examples/stable_diffusion_xl/model_training/validate_full/stable-diffusion-xl-base-1.0.py)|[code](/examples/stable_diffusion_xl/model_training/lora/stable-diffusion-xl-base-1.0.sh)|[code](/examples/stable_diffusion_xl/model_training/validate_lora/stable-diffusion-xl-base-1.0.py)|
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
#### FLUX.2: [/docs/en/Model_Details/FLUX2.md](/docs/en/Model_Details/FLUX2.md)
|
#### FLUX.2: [/docs/en/Model_Details/FLUX2.md](/docs/en/Model_Details/FLUX2.md)
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
|
|||||||
125
README_zh.md
125
README_zh.md
@@ -34,6 +34,8 @@ DiffSynth 目前包括两个开源项目:
|
|||||||
|
|
||||||
> 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 和 [mi804](https://github.com/mi804) 负责,因此新功能的开发进展会比较缓慢,issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。
|
> 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 和 [mi804](https://github.com/mi804) 负责,因此新功能的开发进展会比较缓慢,issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。
|
||||||
|
|
||||||
|
- **2026年4月24日** 我们新增对 Stable Diffusion v1.5 和 SDXL 的支持,包括推理、低显存推理和训练能力。详情请参考[文档](/docs/zh/Model_Details/Stable-Diffusion.md)和[示例代码](/examples/stable_diffusion/)。
|
||||||
|
|
||||||
- **2026年4月14日** JoyAI-Image 开源,欢迎加入图像编辑模型家族!支持指令引导的图像编辑推理、低显存推理和训练能力。详情请参考[文档](/docs/zh/Model_Details/JoyAI-Image.md)和[示例代码](/examples/joyai_image/)。
|
- **2026年4月14日** JoyAI-Image 开源,欢迎加入图像编辑模型家族!支持指令引导的图像编辑推理、低显存推理和训练能力。详情请参考[文档](/docs/zh/Model_Details/JoyAI-Image.md)和[示例代码](/examples/joyai_image/)。
|
||||||
|
|
||||||
- **2026年3月19日** 新增对 [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) 和 [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) 模型的支持,包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/Wan.md)和[示例代码](/examples/mova/)现已可用。
|
- **2026年3月19日** 新增对 [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) 和 [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) 模型的支持,包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/Wan.md)和[示例代码](/examples/mova/)现已可用。
|
||||||
@@ -299,6 +301,129 @@ Z-Image 的示例代码位于:[/examples/z_image/](/examples/z_image/)
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
#### Stable Diffusion:[/docs/zh/Model_Details/Stable-Diffusion.md](/docs/zh/Model_Details/Stable-Diffusion.md)
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>快速开始</summary>
|
||||||
|
|
||||||
|
运行以下代码可以快速加载 [AI-ModelScope/stable-diffusion-v1-5](https://www.modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5) 模型并进行推理。显存管理已启用,框架会自动根据剩余显存控制模型参数的加载,最低 2GB 显存即可运行。
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from diffsynth.core import ModelConfig
|
||||||
|
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.float32,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.float32,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float32,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.float32,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt="a photo of an astronaut riding a horse on mars, high quality, detailed",
|
||||||
|
negative_prompt="blurry, low quality, deformed",
|
||||||
|
cfg_scale=7.5,
|
||||||
|
height=512,
|
||||||
|
width=512,
|
||||||
|
seed=42,
|
||||||
|
rand_device="cuda",
|
||||||
|
num_inference_steps=50,
|
||||||
|
)
|
||||||
|
image.save("image.jpg")
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>示例代码</summary>
|
||||||
|
|
||||||
|
Stable Diffusion 的示例代码位于:[/examples/stable_diffusion/](/examples/stable_diffusion/)
|
||||||
|
|
||||||
|
| 模型 ID | 推理 | 低显存推理 | 全量训练 | 全量训练后验证 | LoRA 训练 | LoRA 训练后验证 |
|
||||||
|
|-|-|-|-|-|-|-|
|
||||||
|
|[AI-ModelScope/stable-diffusion-v1-5](https://www.modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5)|[code](/examples/stable_diffusion/model_inference/stable-diffusion-v1-5.py)|[code](/examples/stable_diffusion/model_inference_low_vram/stable-diffusion-v1-5.py)|[code](/examples/stable_diffusion/model_training/full/stable-diffusion-v1-5.sh)|[code](/examples/stable_diffusion/model_training/validate_full/stable-diffusion-v1-5.py)|[code](/examples/stable_diffusion/model_training/lora/stable-diffusion-v1-5.sh)|[code](/examples/stable_diffusion/model_training/validate_lora/stable-diffusion-v1-5.py)|
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
#### Stable Diffusion XL:[/docs/zh/Model_Details/Stable-Diffusion-XL.md](/docs/zh/Model_Details/Stable-Diffusion-XL.md)
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>快速开始</summary>
|
||||||
|
|
||||||
|
运行以下代码可以快速加载 [stabilityai/stable-diffusion-xl-base-1.0](https://www.modelscope.cn/models/stabilityai/stable-diffusion-xl-base-1.0) 模型并进行推理。显存管理已启用,框架会自动根据剩余显存控制模型参数的加载,最低 6GB 显存即可运行。
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from diffsynth.core import ModelConfig
|
||||||
|
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.float32,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.float32,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float32,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.float32,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
|
||||||
|
tokenizer_2_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt="a photo of an astronaut riding a horse on mars",
|
||||||
|
negative_prompt="",
|
||||||
|
cfg_scale=5.0,
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
seed=42,
|
||||||
|
num_inference_steps=50,
|
||||||
|
)
|
||||||
|
image.save("image.jpg")
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>示例代码</summary>
|
||||||
|
|
||||||
|
Stable Diffusion XL 的示例代码位于:[/examples/stable_diffusion_xl/](/examples/stable_diffusion_xl/)
|
||||||
|
|
||||||
|
| 模型 ID | 推理 | 低显存推理 | 全量训练 | 全量训练后验证 | LoRA 训练 | LoRA 训练后验证 |
|
||||||
|
|-|-|-|-|-|-|-|
|
||||||
|
|[stabilityai/stable-diffusion-xl-base-1.0](https://www.modelscope.cn/models/stabilityai/stable-diffusion-xl-base-1.0)|[code](/examples/stable_diffusion_xl/model_inference/stable-diffusion-xl-base-1.0.py)|[code](/examples/stable_diffusion_xl/model_inference_low_vram/stable-diffusion-xl-base-1.0.py)|[code](/examples/stable_diffusion_xl/model_training/full/stable-diffusion-xl-base-1.0.sh)|[code](/examples/stable_diffusion_xl/model_training/validate_full/stable-diffusion-xl-base-1.0.py)|[code](/examples/stable_diffusion_xl/model_training/lora/stable-diffusion-xl-base-1.0.sh)|[code](/examples/stable_diffusion_xl/model_training/validate_lora/stable-diffusion-xl-base-1.0.py)|
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
#### FLUX.2: [/docs/zh/Model_Details/FLUX2.md](/docs/zh/Model_Details/FLUX2.md)
|
#### FLUX.2: [/docs/zh/Model_Details/FLUX2.md](/docs/zh/Model_Details/FLUX2.md)
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
|
|||||||
@@ -42,7 +42,6 @@ qwen_image_series = [
|
|||||||
"model_hash": "5722b5c873720009de96422993b15682",
|
"model_hash": "5722b5c873720009de96422993b15682",
|
||||||
"model_name": "dinov3_image_encoder",
|
"model_name": "dinov3_image_encoder",
|
||||||
"model_class": "diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder",
|
"model_class": "diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder",
|
||||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.dino_v3.DINOv3StateDictConverter",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
# Example:
|
# Example:
|
||||||
@@ -901,6 +900,61 @@ mova_series = [
|
|||||||
"model_class": "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge",
|
"model_class": "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge",
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
stable_diffusion_xl_series = [
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors")
|
||||||
|
"model_hash": "142b114f67f5ab3a6d83fb5788f12ded",
|
||||||
|
"model_name": "stable_diffusion_xl_unet",
|
||||||
|
"model_class": "diffsynth.models.stable_diffusion_xl_unet.SDXLUNet2DConditionModel",
|
||||||
|
"extra_kwargs": {"attention_head_dim": [5, 10, 20], "transformer_layers_per_block": [1, 2, 10], "use_linear_projection": True, "addition_embed_type": "text_time", "addition_time_embed_dim": 256, "projection_class_embeddings_input_dim": 2816},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors")
|
||||||
|
"model_hash": "98cc34ccc5b54ae0e56bdea8688dcd5a",
|
||||||
|
"model_name": "stable_diffusion_xl_text_encoder",
|
||||||
|
"model_class": "diffsynth.models.stable_diffusion_xl_text_encoder.SDXLTextEncoder2",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_xl_text_encoder.SDXLTextEncoder2StateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors")
|
||||||
|
"model_hash": "94eefa3dac9cec93cb1ebaf1747d7b78",
|
||||||
|
"model_name": "stable_diffusion_text_encoder",
|
||||||
|
"model_class": "diffsynth.models.stable_diffusion_text_encoder.SDTextEncoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_text_encoder.SDTextEncoderStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
|
||||||
|
"model_hash": "13115dd45a6e1c39860f91ab073b8a78",
|
||||||
|
"model_name": "stable_diffusion_xl_vae",
|
||||||
|
"model_class": "diffsynth.models.stable_diffusion_vae.StableDiffusionVAE",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_vae.SDVAEStateDictConverter",
|
||||||
|
"extra_kwargs": {"scaling_factor": 0.13025, "sample_size": 1024, "force_upcast": True},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
stable_diffusion_series = [
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors")
|
||||||
|
"model_hash": "ffd1737ae9df7fd43f5fbed653bdad67",
|
||||||
|
"model_name": "stable_diffusion_text_encoder",
|
||||||
|
"model_class": "diffsynth.models.stable_diffusion_text_encoder.SDTextEncoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_text_encoder.SDTextEncoderStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
|
||||||
|
"model_hash": "f86d5683ed32433be8ca69969c67ba69",
|
||||||
|
"model_name": "stable_diffusion_vae",
|
||||||
|
"model_class": "diffsynth.models.stable_diffusion_vae.StableDiffusionVAE",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_vae.SDVAEStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors")
|
||||||
|
"model_hash": "025a4b86a84829399d89f613e580757b",
|
||||||
|
"model_name": "stable_diffusion_unet",
|
||||||
|
"model_class": "diffsynth.models.stable_diffusion_unet.UNet2DConditionModel",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
joyai_image_series = [
|
joyai_image_series = [
|
||||||
{
|
{
|
||||||
# Example: ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth")
|
# Example: ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth")
|
||||||
@@ -917,4 +971,4 @@ joyai_image_series = [
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series + z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series
|
MODEL_CONFIGS = stable_diffusion_xl_series + stable_diffusion_series + qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series + z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series
|
||||||
|
|||||||
@@ -295,6 +295,45 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
|
|||||||
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
},
|
},
|
||||||
|
"diffsynth.models.stable_diffusion_unet.UNet2DConditionModel": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.stable_diffusion_vae.StableDiffusionVAE": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"diffsynth.models.stable_diffusion_vae.Upsample2D": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"diffsynth.models.stable_diffusion_vae.Downsample2D": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.stable_diffusion_text_encoder.SDTextEncoder": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"transformers.models.clip.modeling_clip.CLIPTextTransformer": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"transformers.models.clip.modeling_clip.CLIPEncoderLayer": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"transformers.models.clip.modeling_clip.CLIPAttention": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.stable_diffusion_xl_unet.SDXLUNet2DConditionModel": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.stable_diffusion_xl_text_encoder.SDXLTextEncoder2": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"transformers.models.clip.modeling_clip.CLIPTextTransformer": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"transformers.models.clip.modeling_clip.CLIPEncoderLayer": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"transformers.models.clip.modeling_clip.CLIPAttention": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def QwenImageTextEncoder_Module_Map_Updater():
|
def QwenImageTextEncoder_Module_Map_Updater():
|
||||||
|
|||||||
107
diffsynth/diffusion/ddim_scheduler.py
Normal file
107
diffsynth/diffusion/ddim_scheduler.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
import torch, math
|
||||||
|
|
||||||
|
|
||||||
|
class DDIMScheduler():
|
||||||
|
|
||||||
|
def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", prediction_type="epsilon", rescale_zero_terminal_snr=False):
|
||||||
|
self.num_train_timesteps = num_train_timesteps
|
||||||
|
if beta_schedule == "scaled_linear":
|
||||||
|
betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32))
|
||||||
|
elif beta_schedule == "linear":
|
||||||
|
betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"{beta_schedule} is not implemented")
|
||||||
|
self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
|
||||||
|
if rescale_zero_terminal_snr:
|
||||||
|
self.alphas_cumprod = self.rescale_zero_terminal_snr(self.alphas_cumprod)
|
||||||
|
self.alphas_cumprod = self.alphas_cumprod.tolist()
|
||||||
|
self.set_timesteps(10)
|
||||||
|
self.prediction_type = prediction_type
|
||||||
|
self.training = False
|
||||||
|
|
||||||
|
|
||||||
|
def rescale_zero_terminal_snr(self, alphas_cumprod):
|
||||||
|
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
||||||
|
|
||||||
|
# Store old values.
|
||||||
|
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
||||||
|
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
||||||
|
|
||||||
|
# Shift so the last timestep is zero.
|
||||||
|
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
||||||
|
|
||||||
|
# Scale so the first timestep is back to the old value.
|
||||||
|
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
||||||
|
|
||||||
|
# Convert alphas_bar_sqrt to betas
|
||||||
|
alphas_bar = alphas_bar_sqrt.square() # Revert sqrt
|
||||||
|
|
||||||
|
return alphas_bar
|
||||||
|
|
||||||
|
|
||||||
|
def set_timesteps(self, num_inference_steps, denoising_strength=1.0, training=False, **kwargs):
|
||||||
|
# The timesteps are aligned to 999...0, which is different from other implementations,
|
||||||
|
# but I think this implementation is more reasonable in theory.
|
||||||
|
max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0)
|
||||||
|
num_inference_steps = min(num_inference_steps, max_timestep + 1)
|
||||||
|
if num_inference_steps == 1:
|
||||||
|
self.timesteps = torch.Tensor([max_timestep])
|
||||||
|
else:
|
||||||
|
step_length = max_timestep / (num_inference_steps - 1)
|
||||||
|
self.timesteps = torch.Tensor([round(max_timestep - i*step_length) for i in range(num_inference_steps)])
|
||||||
|
self.training = training
|
||||||
|
|
||||||
|
|
||||||
|
def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev):
|
||||||
|
if self.prediction_type == "epsilon":
|
||||||
|
weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t)
|
||||||
|
weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t)
|
||||||
|
prev_sample = sample * weight_x + model_output * weight_e
|
||||||
|
elif self.prediction_type == "v_prediction":
|
||||||
|
weight_e = -math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t)) + math.sqrt(alpha_prod_t * (1 - alpha_prod_t_prev))
|
||||||
|
weight_x = math.sqrt(alpha_prod_t * alpha_prod_t_prev) + math.sqrt((1 - alpha_prod_t) * (1 - alpha_prod_t_prev))
|
||||||
|
prev_sample = sample * weight_x + model_output * weight_e
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"{self.prediction_type} is not implemented")
|
||||||
|
return prev_sample
|
||||||
|
|
||||||
|
|
||||||
|
def step(self, model_output, timestep, sample, to_final=False):
|
||||||
|
alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
|
||||||
|
if isinstance(timestep, torch.Tensor):
|
||||||
|
timestep = timestep.cpu()
|
||||||
|
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
||||||
|
if to_final or timestep_id + 1 >= len(self.timesteps):
|
||||||
|
alpha_prod_t_prev = 1.0
|
||||||
|
else:
|
||||||
|
timestep_prev = int(self.timesteps[timestep_id + 1])
|
||||||
|
alpha_prod_t_prev = self.alphas_cumprod[timestep_prev]
|
||||||
|
|
||||||
|
return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev)
|
||||||
|
|
||||||
|
|
||||||
|
def return_to_timestep(self, timestep, sample, sample_stablized):
|
||||||
|
alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
|
||||||
|
noise_pred = (sample - math.sqrt(alpha_prod_t) * sample_stablized) / math.sqrt(1 - alpha_prod_t)
|
||||||
|
return noise_pred
|
||||||
|
|
||||||
|
|
||||||
|
def add_noise(self, original_samples, noise, timestep):
|
||||||
|
sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
|
||||||
|
sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
|
||||||
|
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||||
|
return noisy_samples
|
||||||
|
|
||||||
|
|
||||||
|
def training_target(self, sample, noise, timestep):
|
||||||
|
if self.prediction_type == "epsilon":
|
||||||
|
return noise
|
||||||
|
else:
|
||||||
|
sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
|
||||||
|
sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
|
||||||
|
target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
|
||||||
|
return target
|
||||||
|
|
||||||
|
|
||||||
|
def training_weight(self, timestep):
|
||||||
|
return 1.0
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTModel, DINOv3ViTConfig
|
from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast
|
||||||
from transformers import DINOv3ViTImageProcessor
|
from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..core.device.npu_compatible_device import get_device_type
|
from ..core.device.npu_compatible_device import get_device_type
|
||||||
@@ -40,7 +40,7 @@ class DINOv3ImageEncoder(DINOv3ViTModel):
|
|||||||
value_bias = False
|
value_bias = False
|
||||||
)
|
)
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.processor = DINOv3ViTImageProcessor(
|
self.processor = DINOv3ViTImageProcessorFast(
|
||||||
crop_size = None,
|
crop_size = None,
|
||||||
data_format = "channels_first",
|
data_format = "channels_first",
|
||||||
default_to_square = True,
|
default_to_square = True,
|
||||||
@@ -56,7 +56,7 @@ class DINOv3ImageEncoder(DINOv3ViTModel):
|
|||||||
0.456,
|
0.456,
|
||||||
0.406
|
0.406
|
||||||
],
|
],
|
||||||
image_processor_type = "DINOv3ViTImageProcessor",
|
image_processor_type = "DINOv3ViTImageProcessorFast",
|
||||||
image_std = [
|
image_std = [
|
||||||
0.229,
|
0.229,
|
||||||
0.224,
|
0.224,
|
||||||
@@ -82,7 +82,7 @@ class DINOv3ImageEncoder(DINOv3ViTModel):
|
|||||||
hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
|
hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
|
||||||
position_embeddings = self.rope_embeddings(pixel_values)
|
position_embeddings = self.rope_embeddings(pixel_values)
|
||||||
|
|
||||||
for i, layer_module in enumerate(self.model.layer):
|
for i, layer_module in enumerate(self.layer):
|
||||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||||
hidden_states = layer_module(
|
hidden_states = layer_module(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
from transformers.models.siglip.modeling_siglip import SiglipVisionModel, SiglipVisionConfig
|
from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer, SiglipVisionConfig
|
||||||
from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessor
|
from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from diffsynth.core.device.npu_compatible_device import get_device_type
|
from diffsynth.core.device.npu_compatible_device import get_device_type
|
||||||
|
|
||||||
|
|
||||||
class Siglip2ImageEncoder(SiglipVisionModel):
|
class Siglip2ImageEncoder(SiglipVisionTransformer):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
config = SiglipVisionConfig(
|
config = SiglipVisionConfig(
|
||||||
attention_dropout = 0.0,
|
attention_dropout = 0.0,
|
||||||
@@ -90,7 +90,7 @@ class Siglip2ImageEncoder428M(Siglip2VisionModel):
|
|||||||
transformers_version = "4.57.1"
|
transformers_version = "4.57.1"
|
||||||
)
|
)
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.processor = Siglip2ImageProcessor(
|
self.processor = Siglip2ImageProcessorFast(
|
||||||
**{
|
**{
|
||||||
"data_format": "channels_first",
|
"data_format": "channels_first",
|
||||||
"default_to_square": True,
|
"default_to_square": True,
|
||||||
@@ -106,7 +106,7 @@ class Siglip2ImageEncoder428M(Siglip2VisionModel):
|
|||||||
0.5,
|
0.5,
|
||||||
0.5
|
0.5
|
||||||
],
|
],
|
||||||
"image_processor_type": "Siglip2ImageProcessor",
|
"image_processor_type": "Siglip2ImageProcessorFast",
|
||||||
"image_std": [
|
"image_std": [
|
||||||
0.5,
|
0.5,
|
||||||
0.5,
|
0.5,
|
||||||
|
|||||||
78
diffsynth/models/stable_diffusion_text_encoder.py
Normal file
78
diffsynth/models/stable_diffusion_text_encoder.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class SDTextEncoder(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size=768,
|
||||||
|
intermediate_size=3072,
|
||||||
|
num_hidden_layers=12,
|
||||||
|
num_attention_heads=12,
|
||||||
|
max_position_embeddings=77,
|
||||||
|
vocab_size=49408,
|
||||||
|
layer_norm_eps=1e-05,
|
||||||
|
hidden_act="quick_gelu",
|
||||||
|
initializer_factor=1.0,
|
||||||
|
initializer_range=0.02,
|
||||||
|
bos_token_id=0,
|
||||||
|
eos_token_id=2,
|
||||||
|
pad_token_id=1,
|
||||||
|
projection_dim=768,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
from transformers import CLIPConfig, CLIPTextModel
|
||||||
|
|
||||||
|
config = CLIPConfig(
|
||||||
|
text_config={
|
||||||
|
"hidden_size": hidden_size,
|
||||||
|
"intermediate_size": intermediate_size,
|
||||||
|
"num_hidden_layers": num_hidden_layers,
|
||||||
|
"num_attention_heads": num_attention_heads,
|
||||||
|
"max_position_embeddings": max_position_embeddings,
|
||||||
|
"vocab_size": vocab_size,
|
||||||
|
"layer_norm_eps": layer_norm_eps,
|
||||||
|
"hidden_act": hidden_act,
|
||||||
|
"initializer_factor": initializer_factor,
|
||||||
|
"initializer_range": initializer_range,
|
||||||
|
"bos_token_id": bos_token_id,
|
||||||
|
"eos_token_id": eos_token_id,
|
||||||
|
"pad_token_id": pad_token_id,
|
||||||
|
"projection_dim": projection_dim,
|
||||||
|
"dropout": 0.0,
|
||||||
|
},
|
||||||
|
vision_config={
|
||||||
|
"hidden_size": hidden_size,
|
||||||
|
"intermediate_size": intermediate_size,
|
||||||
|
"num_hidden_layers": num_hidden_layers,
|
||||||
|
"num_attention_heads": num_attention_heads,
|
||||||
|
"max_position_embeddings": max_position_embeddings,
|
||||||
|
"layer_norm_eps": layer_norm_eps,
|
||||||
|
"hidden_act": hidden_act,
|
||||||
|
"initializer_factor": initializer_factor,
|
||||||
|
"initializer_range": initializer_range,
|
||||||
|
"projection_dim": projection_dim,
|
||||||
|
},
|
||||||
|
projection_dim=projection_dim,
|
||||||
|
)
|
||||||
|
self.model = CLIPTextModel(config.text_config)
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
position_ids=None,
|
||||||
|
output_hidden_states=True,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=True,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
if output_hidden_states:
|
||||||
|
return outputs.last_hidden_state, outputs.hidden_states
|
||||||
|
return outputs.last_hidden_state
|
||||||
912
diffsynth/models/stable_diffusion_unet.py
Normal file
912
diffsynth/models/stable_diffusion_unet.py
Normal file
@@ -0,0 +1,912 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import math
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
# ===== Time Embedding =====
|
||||||
|
|
||||||
|
class Timesteps(nn.Module):
|
||||||
|
def __init__(self, num_channels, flip_sin_to_cos=True, freq_shift=0):
|
||||||
|
super().__init__()
|
||||||
|
self.num_channels = num_channels
|
||||||
|
self.flip_sin_to_cos = flip_sin_to_cos
|
||||||
|
self.freq_shift = freq_shift
|
||||||
|
|
||||||
|
def forward(self, timesteps):
|
||||||
|
half_dim = self.num_channels // 2
|
||||||
|
exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
|
||||||
|
exponent = exponent / half_dim + self.freq_shift
|
||||||
|
emb = torch.exp(exponent)
|
||||||
|
emb = timesteps[:, None].float() * emb[None, :]
|
||||||
|
sin_emb = torch.sin(emb)
|
||||||
|
cos_emb = torch.cos(emb)
|
||||||
|
if self.flip_sin_to_cos:
|
||||||
|
emb = torch.cat([cos_emb, sin_emb], dim=-1)
|
||||||
|
else:
|
||||||
|
emb = torch.cat([sin_emb, cos_emb], dim=-1)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEmbedding(nn.Module):
|
||||||
|
def __init__(self, in_channels, time_embed_dim, act_fn="silu", out_dim=None):
|
||||||
|
super().__init__()
|
||||||
|
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
||||||
|
self.act = nn.SiLU() if act_fn == "silu" else nn.GELU()
|
||||||
|
out_dim = out_dim if out_dim is not None else time_embed_dim
|
||||||
|
self.linear_2 = nn.Linear(time_embed_dim, out_dim)
|
||||||
|
|
||||||
|
def forward(self, sample):
|
||||||
|
sample = self.linear_1(sample)
|
||||||
|
sample = self.act(sample)
|
||||||
|
sample = self.linear_2(sample)
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
# ===== ResNet Blocks =====
|
||||||
|
|
||||||
|
class ResnetBlock2D(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels=None,
|
||||||
|
conv_shortcut=False,
|
||||||
|
dropout=0.0,
|
||||||
|
temb_channels=512,
|
||||||
|
groups=32,
|
||||||
|
groups_out=None,
|
||||||
|
pre_norm=True,
|
||||||
|
eps=1e-6,
|
||||||
|
non_linearity="swish",
|
||||||
|
time_embedding_norm="default",
|
||||||
|
output_scale_factor=1.0,
|
||||||
|
use_in_shortcut=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.pre_norm = pre_norm
|
||||||
|
self.time_embedding_norm = time_embedding_norm
|
||||||
|
self.output_scale_factor = output_scale_factor
|
||||||
|
|
||||||
|
if groups_out is None:
|
||||||
|
groups_out = groups
|
||||||
|
|
||||||
|
self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps)
|
||||||
|
self.conv1 = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
if temb_channels is not None:
|
||||||
|
if self.time_embedding_norm == "default":
|
||||||
|
self.time_emb_proj = nn.Linear(temb_channels, out_channels or in_channels)
|
||||||
|
elif self.time_embedding_norm == "scale_shift":
|
||||||
|
self.time_emb_proj = nn.Linear(temb_channels, 2 * (out_channels or in_channels))
|
||||||
|
|
||||||
|
self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels or in_channels, eps=eps)
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
self.conv2 = nn.Conv2d(out_channels or in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
if non_linearity == "swish":
|
||||||
|
self.nonlinearity = nn.SiLU()
|
||||||
|
elif non_linearity == "silu":
|
||||||
|
self.nonlinearity = nn.SiLU()
|
||||||
|
elif non_linearity == "gelu":
|
||||||
|
self.nonlinearity = nn.GELU()
|
||||||
|
elif non_linearity == "relu":
|
||||||
|
self.nonlinearity = nn.ReLU()
|
||||||
|
|
||||||
|
self.use_conv_shortcut = conv_shortcut
|
||||||
|
self.conv_shortcut = None
|
||||||
|
if conv_shortcut:
|
||||||
|
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
|
else:
|
||||||
|
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0) if in_channels != (out_channels or in_channels) else None
|
||||||
|
|
||||||
|
def forward(self, input_tensor, temb=None):
|
||||||
|
hidden_states = input_tensor
|
||||||
|
hidden_states = self.norm1(hidden_states)
|
||||||
|
hidden_states = self.nonlinearity(hidden_states)
|
||||||
|
hidden_states = self.conv1(hidden_states)
|
||||||
|
|
||||||
|
if temb is not None:
|
||||||
|
temb = self.nonlinearity(temb)
|
||||||
|
temb = self.time_emb_proj(temb).unsqueeze(-1).unsqueeze(-1)
|
||||||
|
|
||||||
|
if temb is not None and self.time_embedding_norm == "default":
|
||||||
|
hidden_states = hidden_states + temb
|
||||||
|
|
||||||
|
hidden_states = self.norm2(hidden_states)
|
||||||
|
|
||||||
|
if temb is not None and self.time_embedding_norm == "scale_shift":
|
||||||
|
scale, shift = torch.chunk(temb, 2, dim=1)
|
||||||
|
hidden_states = hidden_states * (1 + scale) + shift
|
||||||
|
|
||||||
|
hidden_states = self.nonlinearity(hidden_states)
|
||||||
|
hidden_states = self.dropout(hidden_states)
|
||||||
|
hidden_states = self.conv2(hidden_states)
|
||||||
|
|
||||||
|
if self.conv_shortcut is not None:
|
||||||
|
input_tensor = self.conv_shortcut(input_tensor)
|
||||||
|
|
||||||
|
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
||||||
|
return output_tensor
|
||||||
|
|
||||||
|
|
||||||
|
# ===== Transformer Blocks =====
|
||||||
|
|
||||||
|
class GEGLU(nn.Module):
|
||||||
|
def __init__(self, dim_in, dim_out):
|
||||||
|
super().__init__()
|
||||||
|
self.proj = nn.Linear(dim_in, dim_out * 2)
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
|
||||||
|
return hidden_states * F.gelu(gate)
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
|
||||||
|
def __init__(self, dim, dim_out=None, dropout=0.0):
|
||||||
|
super().__init__()
|
||||||
|
self.net = nn.ModuleList([
|
||||||
|
GEGLU(dim, dim * 4),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
nn.Linear(dim * 4, dim if dim_out is None else dim_out),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
for module in self.net:
|
||||||
|
hidden_states = module(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
"""Attention block matching diffusers checkpoint key format.
|
||||||
|
Keys: to_q.weight, to_k.weight, to_v.weight, to_out.0.weight, to_out.0.bias
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
query_dim,
|
||||||
|
heads=8,
|
||||||
|
dim_head=64,
|
||||||
|
dropout=0.0,
|
||||||
|
bias=False,
|
||||||
|
upcast_attention=False,
|
||||||
|
cross_attention_dim=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = dim_head * heads
|
||||||
|
self.heads = heads
|
||||||
|
self.inner_dim = inner_dim
|
||||||
|
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
||||||
|
|
||||||
|
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
|
||||||
|
self.to_k = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias)
|
||||||
|
self.to_v = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias)
|
||||||
|
self.to_out = nn.ModuleList([
|
||||||
|
nn.Linear(inner_dim, query_dim, bias=True),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||||
|
# Query
|
||||||
|
query = self.to_q(hidden_states)
|
||||||
|
batch_size, seq_len, _ = query.shape
|
||||||
|
|
||||||
|
# Key/Value
|
||||||
|
if encoder_hidden_states is None:
|
||||||
|
encoder_hidden_states = hidden_states
|
||||||
|
key = self.to_k(encoder_hidden_states)
|
||||||
|
value = self.to_v(encoder_hidden_states)
|
||||||
|
|
||||||
|
# Reshape for multi-head attention
|
||||||
|
head_dim = self.inner_dim // self.heads
|
||||||
|
query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
|
||||||
|
key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
|
||||||
|
value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
# Scaled dot-product attention
|
||||||
|
hidden_states = F.scaled_dot_product_attention(
|
||||||
|
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reshape back
|
||||||
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.inner_dim)
|
||||||
|
hidden_states = hidden_states.to(query.dtype)
|
||||||
|
|
||||||
|
# Output projection
|
||||||
|
hidden_states = self.to_out[0](hidden_states)
|
||||||
|
hidden_states = self.to_out[1](hidden_states)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BasicTransformerBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
n_heads,
|
||||||
|
d_head,
|
||||||
|
dropout=0.0,
|
||||||
|
cross_attention_dim=None,
|
||||||
|
upcast_attention=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.norm1 = nn.LayerNorm(dim)
|
||||||
|
self.attn1 = Attention(
|
||||||
|
query_dim=dim,
|
||||||
|
heads=n_heads,
|
||||||
|
dim_head=d_head,
|
||||||
|
dropout=dropout,
|
||||||
|
bias=False,
|
||||||
|
upcast_attention=upcast_attention,
|
||||||
|
)
|
||||||
|
self.norm2 = nn.LayerNorm(dim)
|
||||||
|
self.attn2 = Attention(
|
||||||
|
query_dim=dim,
|
||||||
|
heads=n_heads,
|
||||||
|
dim_head=d_head,
|
||||||
|
dropout=dropout,
|
||||||
|
bias=False,
|
||||||
|
upcast_attention=upcast_attention,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
)
|
||||||
|
self.norm3 = nn.LayerNorm(dim)
|
||||||
|
self.ff = FeedForward(dim, dropout=dropout)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||||
|
# Self-attention
|
||||||
|
attn_output = self.attn1(self.norm1(hidden_states))
|
||||||
|
hidden_states = attn_output + hidden_states
|
||||||
|
# Cross-attention
|
||||||
|
attn_output = self.attn2(self.norm2(hidden_states), encoder_hidden_states=encoder_hidden_states)
|
||||||
|
hidden_states = attn_output + hidden_states
|
||||||
|
# Feed-forward
|
||||||
|
ff_output = self.ff(self.norm3(hidden_states))
|
||||||
|
hidden_states = ff_output + hidden_states
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Transformer2DModel(nn.Module):
|
||||||
|
"""2D Transformer block wrapper matching diffusers checkpoint structure.
|
||||||
|
Keys: norm.weight/bias, proj_in.weight/bias, transformer_blocks.X.*, proj_out.weight/bias
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_attention_heads=16,
|
||||||
|
attention_head_dim=64,
|
||||||
|
in_channels=320,
|
||||||
|
num_layers=1,
|
||||||
|
dropout=0.0,
|
||||||
|
norm_num_groups=32,
|
||||||
|
cross_attention_dim=768,
|
||||||
|
upcast_attention=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6)
|
||||||
|
self.proj_in = nn.Conv2d(in_channels, num_attention_heads * attention_head_dim, kernel_size=1, bias=True)
|
||||||
|
|
||||||
|
self.transformer_blocks = nn.ModuleList([
|
||||||
|
BasicTransformerBlock(
|
||||||
|
dim=num_attention_heads * attention_head_dim,
|
||||||
|
n_heads=num_attention_heads,
|
||||||
|
d_head=attention_head_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
upcast_attention=upcast_attention,
|
||||||
|
)
|
||||||
|
for _ in range(num_layers)
|
||||||
|
])
|
||||||
|
|
||||||
|
self.proj_out = nn.Conv2d(num_attention_heads * attention_head_dim, in_channels, kernel_size=1, bias=True)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||||
|
batch, channel, height, width = hidden_states.shape
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
# Normalize and project to sequence
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
hidden_states = self.proj_in(hidden_states)
|
||||||
|
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, -1, channel)
|
||||||
|
|
||||||
|
# Transformer blocks
|
||||||
|
for block in self.transformer_blocks:
|
||||||
|
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||||
|
|
||||||
|
# Project back to 2D
|
||||||
|
hidden_states = hidden_states.reshape(batch, height, width, channel).permute(0, 3, 1, 2).contiguous()
|
||||||
|
hidden_states = self.proj_out(hidden_states)
|
||||||
|
hidden_states = hidden_states + residual
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
# ===== Down/Up Blocks =====
|
||||||
|
|
||||||
|
class CrossAttnDownBlock2D(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
temb_channels=1280,
|
||||||
|
dropout=0.0,
|
||||||
|
num_layers=1,
|
||||||
|
transformer_layers_per_block=1,
|
||||||
|
resnet_eps=1e-6,
|
||||||
|
resnet_time_scale_shift="default",
|
||||||
|
resnet_act_fn="swish",
|
||||||
|
resnet_groups=32,
|
||||||
|
resnet_pre_norm=True,
|
||||||
|
cross_attention_dim=768,
|
||||||
|
attention_head_dim=1,
|
||||||
|
downsample=True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.has_cross_attention = True
|
||||||
|
|
||||||
|
resnets = []
|
||||||
|
attentions = []
|
||||||
|
|
||||||
|
for i in range(num_layers):
|
||||||
|
in_channels_i = in_channels if i == 0 else out_channels
|
||||||
|
resnets.append(
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=in_channels_i,
|
||||||
|
out_channels=out_channels,
|
||||||
|
temb_channels=temb_channels,
|
||||||
|
eps=resnet_eps,
|
||||||
|
groups=resnet_groups,
|
||||||
|
dropout=dropout,
|
||||||
|
time_embedding_norm=resnet_time_scale_shift,
|
||||||
|
non_linearity=resnet_act_fn,
|
||||||
|
output_scale_factor=1.0,
|
||||||
|
pre_norm=resnet_pre_norm,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
attentions.append(
|
||||||
|
Transformer2DModel(
|
||||||
|
num_attention_heads=attention_head_dim,
|
||||||
|
attention_head_dim=out_channels // attention_head_dim,
|
||||||
|
in_channels=out_channels,
|
||||||
|
num_layers=transformer_layers_per_block,
|
||||||
|
dropout=dropout,
|
||||||
|
norm_num_groups=resnet_groups,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.attentions = nn.ModuleList(attentions)
|
||||||
|
self.resnets = nn.ModuleList(resnets)
|
||||||
|
|
||||||
|
if downsample:
|
||||||
|
self.downsamplers = nn.ModuleList([
|
||||||
|
Downsample2D(out_channels, out_channels, padding=1)
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
self.downsamplers = None
|
||||||
|
|
||||||
|
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
||||||
|
output_states = []
|
||||||
|
|
||||||
|
for resnet, attn in zip(self.resnets, self.attentions):
|
||||||
|
hidden_states = resnet(hidden_states, temb)
|
||||||
|
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||||
|
output_states.append(hidden_states)
|
||||||
|
|
||||||
|
if self.downsamplers is not None:
|
||||||
|
for downsampler in self.downsamplers:
|
||||||
|
hidden_states = downsampler(hidden_states)
|
||||||
|
output_states.append(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states, tuple(output_states)
|
||||||
|
|
||||||
|
|
||||||
|
class DownBlock2D(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
temb_channels=1280,
|
||||||
|
dropout=0.0,
|
||||||
|
num_layers=1,
|
||||||
|
resnet_eps=1e-6,
|
||||||
|
resnet_time_scale_shift="default",
|
||||||
|
resnet_act_fn="swish",
|
||||||
|
resnet_groups=32,
|
||||||
|
resnet_pre_norm=True,
|
||||||
|
downsample=True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.has_cross_attention = False
|
||||||
|
|
||||||
|
resnets = []
|
||||||
|
for i in range(num_layers):
|
||||||
|
in_channels_i = in_channels if i == 0 else out_channels
|
||||||
|
resnets.append(
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=in_channels_i,
|
||||||
|
out_channels=out_channels,
|
||||||
|
temb_channels=temb_channels,
|
||||||
|
eps=resnet_eps,
|
||||||
|
groups=resnet_groups,
|
||||||
|
dropout=dropout,
|
||||||
|
time_embedding_norm=resnet_time_scale_shift,
|
||||||
|
non_linearity=resnet_act_fn,
|
||||||
|
output_scale_factor=1.0,
|
||||||
|
pre_norm=resnet_pre_norm,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.resnets = nn.ModuleList(resnets)
|
||||||
|
|
||||||
|
if downsample:
|
||||||
|
self.downsamplers = nn.ModuleList([
|
||||||
|
Downsample2D(out_channels, out_channels, padding=1)
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
self.downsamplers = None
|
||||||
|
|
||||||
|
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
||||||
|
output_states = []
|
||||||
|
for resnet in self.resnets:
|
||||||
|
hidden_states = resnet(hidden_states, temb)
|
||||||
|
output_states.append(hidden_states)
|
||||||
|
|
||||||
|
if self.downsamplers is not None:
|
||||||
|
for downsampler in self.downsamplers:
|
||||||
|
hidden_states = downsampler(hidden_states)
|
||||||
|
output_states.append(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states, tuple(output_states)
|
||||||
|
|
||||||
|
|
||||||
|
class CrossAttnUpBlock2D(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
prev_output_channel,
|
||||||
|
temb_channels=1280,
|
||||||
|
dropout=0.0,
|
||||||
|
num_layers=1,
|
||||||
|
transformer_layers_per_block=1,
|
||||||
|
resnet_eps=1e-6,
|
||||||
|
resnet_time_scale_shift="default",
|
||||||
|
resnet_act_fn="swish",
|
||||||
|
resnet_groups=32,
|
||||||
|
resnet_pre_norm=True,
|
||||||
|
cross_attention_dim=768,
|
||||||
|
attention_head_dim=1,
|
||||||
|
upsample=True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.has_cross_attention = True
|
||||||
|
|
||||||
|
resnets = []
|
||||||
|
attentions = []
|
||||||
|
|
||||||
|
for i in range(num_layers):
|
||||||
|
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
||||||
|
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
||||||
|
|
||||||
|
resnets.append(
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=resnet_in_channels + res_skip_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
temb_channels=temb_channels,
|
||||||
|
eps=resnet_eps,
|
||||||
|
groups=resnet_groups,
|
||||||
|
dropout=dropout,
|
||||||
|
time_embedding_norm=resnet_time_scale_shift,
|
||||||
|
non_linearity=resnet_act_fn,
|
||||||
|
output_scale_factor=1.0,
|
||||||
|
pre_norm=resnet_pre_norm,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
attentions.append(
|
||||||
|
Transformer2DModel(
|
||||||
|
num_attention_heads=attention_head_dim,
|
||||||
|
attention_head_dim=out_channels // attention_head_dim,
|
||||||
|
in_channels=out_channels,
|
||||||
|
num_layers=transformer_layers_per_block,
|
||||||
|
dropout=dropout,
|
||||||
|
norm_num_groups=resnet_groups,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.attentions = nn.ModuleList(attentions)
|
||||||
|
self.resnets = nn.ModuleList(resnets)
|
||||||
|
|
||||||
|
if upsample:
|
||||||
|
self.upsamplers = nn.ModuleList([
|
||||||
|
Upsample2D(out_channels, out_channels)
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
self.upsamplers = None
|
||||||
|
|
||||||
|
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None):
|
||||||
|
for resnet, attn in zip(self.resnets, self.attentions):
|
||||||
|
# Pop res hidden states
|
||||||
|
res_hidden_states = res_hidden_states_tuple[-1]
|
||||||
|
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||||
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||||
|
hidden_states = resnet(hidden_states, temb)
|
||||||
|
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||||
|
|
||||||
|
if self.upsamplers is not None:
|
||||||
|
for upsampler in self.upsamplers:
|
||||||
|
hidden_states = upsampler(hidden_states, upsample_size=upsample_size)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class UpBlock2D(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
prev_output_channel,
|
||||||
|
temb_channels=1280,
|
||||||
|
dropout=0.0,
|
||||||
|
num_layers=1,
|
||||||
|
resnet_eps=1e-6,
|
||||||
|
resnet_time_scale_shift="default",
|
||||||
|
resnet_act_fn="swish",
|
||||||
|
resnet_groups=32,
|
||||||
|
resnet_pre_norm=True,
|
||||||
|
upsample=True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.has_cross_attention = False
|
||||||
|
|
||||||
|
resnets = []
|
||||||
|
for i in range(num_layers):
|
||||||
|
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
||||||
|
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
||||||
|
|
||||||
|
resnets.append(
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=resnet_in_channels + res_skip_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
temb_channels=temb_channels,
|
||||||
|
eps=resnet_eps,
|
||||||
|
groups=resnet_groups,
|
||||||
|
dropout=dropout,
|
||||||
|
time_embedding_norm=resnet_time_scale_shift,
|
||||||
|
non_linearity=resnet_act_fn,
|
||||||
|
output_scale_factor=1.0,
|
||||||
|
pre_norm=resnet_pre_norm,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.resnets = nn.ModuleList(resnets)
|
||||||
|
|
||||||
|
if upsample:
|
||||||
|
self.upsamplers = nn.ModuleList([
|
||||||
|
Upsample2D(out_channels, out_channels)
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
self.upsamplers = None
|
||||||
|
|
||||||
|
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None):
|
||||||
|
for resnet in self.resnets:
|
||||||
|
res_hidden_states = res_hidden_states_tuple[-1]
|
||||||
|
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||||
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||||
|
hidden_states = resnet(hidden_states, temb)
|
||||||
|
|
||||||
|
if self.upsamplers is not None:
|
||||||
|
for upsampler in self.upsamplers:
|
||||||
|
hidden_states = upsampler(hidden_states, upsample_size=upsample_size)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
# ===== UNet Mid Block =====
|
||||||
|
|
||||||
|
class UNetMidBlock2DCrossAttn(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
temb_channels=1280,
|
||||||
|
dropout=0.0,
|
||||||
|
num_layers=1,
|
||||||
|
transformer_layers_per_block=1,
|
||||||
|
resnet_eps=1e-6,
|
||||||
|
resnet_time_scale_shift="default",
|
||||||
|
resnet_act_fn="swish",
|
||||||
|
resnet_groups=32,
|
||||||
|
resnet_pre_norm=True,
|
||||||
|
cross_attention_dim=768,
|
||||||
|
attention_head_dim=1,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||||
|
|
||||||
|
# There is always at least one resnet
|
||||||
|
resnets = [
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=in_channels,
|
||||||
|
temb_channels=temb_channels,
|
||||||
|
eps=resnet_eps,
|
||||||
|
groups=resnet_groups,
|
||||||
|
dropout=dropout,
|
||||||
|
time_embedding_norm=resnet_time_scale_shift,
|
||||||
|
non_linearity=resnet_act_fn,
|
||||||
|
output_scale_factor=1.0,
|
||||||
|
pre_norm=resnet_pre_norm,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
attentions = []
|
||||||
|
|
||||||
|
for _ in range(num_layers):
|
||||||
|
attentions.append(
|
||||||
|
Transformer2DModel(
|
||||||
|
num_attention_heads=attention_head_dim,
|
||||||
|
attention_head_dim=in_channels // attention_head_dim,
|
||||||
|
in_channels=in_channels,
|
||||||
|
num_layers=transformer_layers_per_block,
|
||||||
|
dropout=dropout,
|
||||||
|
norm_num_groups=resnet_groups,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
resnets.append(
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=in_channels,
|
||||||
|
temb_channels=temb_channels,
|
||||||
|
eps=resnet_eps,
|
||||||
|
groups=resnet_groups,
|
||||||
|
dropout=dropout,
|
||||||
|
time_embedding_norm=resnet_time_scale_shift,
|
||||||
|
non_linearity=resnet_act_fn,
|
||||||
|
output_scale_factor=1.0,
|
||||||
|
pre_norm=resnet_pre_norm,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.attentions = nn.ModuleList(attentions)
|
||||||
|
self.resnets = nn.ModuleList(resnets)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
||||||
|
hidden_states = self.resnets[0](hidden_states, temb)
|
||||||
|
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||||
|
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||||
|
hidden_states = resnet(hidden_states, temb)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
# ===== Downsample / Upsample =====
|
||||||
|
|
||||||
|
class Downsample2D(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, padding=1):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=padding)
|
||||||
|
self.padding = padding
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
if self.padding == 0:
|
||||||
|
hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
|
||||||
|
return self.conv(hidden_states)
|
||||||
|
|
||||||
|
|
||||||
|
class Upsample2D(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, upsample_size=None):
|
||||||
|
if upsample_size is not None:
|
||||||
|
hidden_states = F.interpolate(hidden_states, size=upsample_size, mode="nearest")
|
||||||
|
else:
|
||||||
|
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
|
||||||
|
return self.conv(hidden_states)
|
||||||
|
|
||||||
|
|
||||||
|
# ===== UNet2DConditionModel =====
|
||||||
|
|
||||||
|
class UNet2DConditionModel(nn.Module):
|
||||||
|
"""Stable Diffusion UNet with cross-attention conditioning.
|
||||||
|
state_dict keys match the diffusers UNet2DConditionModel checkpoint format.
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
sample_size=64,
|
||||||
|
in_channels=4,
|
||||||
|
out_channels=4,
|
||||||
|
down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
|
||||||
|
up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
||||||
|
block_out_channels=(320, 640, 1280, 1280),
|
||||||
|
layers_per_block=2,
|
||||||
|
cross_attention_dim=768,
|
||||||
|
attention_head_dim=8,
|
||||||
|
norm_num_groups=32,
|
||||||
|
norm_eps=1e-5,
|
||||||
|
dropout=0.0,
|
||||||
|
act_fn="silu",
|
||||||
|
time_embedding_type="positional",
|
||||||
|
flip_sin_to_cos=True,
|
||||||
|
freq_shift=0,
|
||||||
|
time_embedding_dim=None,
|
||||||
|
resnet_time_scale_shift="default",
|
||||||
|
upcast_attention=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.sample_size = sample_size
|
||||||
|
|
||||||
|
# Time embedding
|
||||||
|
timestep_embedding_dim = time_embedding_dim or block_out_channels[0]
|
||||||
|
self.time_proj = Timesteps(timestep_embedding_dim, flip_sin_to_cos=flip_sin_to_cos, freq_shift=freq_shift)
|
||||||
|
time_embed_dim = block_out_channels[0] * 4
|
||||||
|
self.time_embedding = TimestepEmbedding(timestep_embedding_dim, time_embed_dim)
|
||||||
|
|
||||||
|
# Input
|
||||||
|
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1)
|
||||||
|
|
||||||
|
# Down blocks
|
||||||
|
self.down_blocks = nn.ModuleList()
|
||||||
|
output_channel = block_out_channels[0]
|
||||||
|
for i, down_block_type in enumerate(down_block_types):
|
||||||
|
input_channel = output_channel
|
||||||
|
output_channel = block_out_channels[i]
|
||||||
|
is_final_block = i == len(block_out_channels) - 1
|
||||||
|
|
||||||
|
if "CrossAttn" in down_block_type:
|
||||||
|
down_block = CrossAttnDownBlock2D(
|
||||||
|
in_channels=input_channel,
|
||||||
|
out_channels=output_channel,
|
||||||
|
temb_channels=time_embed_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
num_layers=layers_per_block,
|
||||||
|
transformer_layers_per_block=1,
|
||||||
|
resnet_eps=norm_eps,
|
||||||
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
attention_head_dim=attention_head_dim,
|
||||||
|
downsample=not is_final_block,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
down_block = DownBlock2D(
|
||||||
|
in_channels=input_channel,
|
||||||
|
out_channels=output_channel,
|
||||||
|
temb_channels=time_embed_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
num_layers=layers_per_block,
|
||||||
|
resnet_eps=norm_eps,
|
||||||
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
downsample=not is_final_block,
|
||||||
|
)
|
||||||
|
self.down_blocks.append(down_block)
|
||||||
|
|
||||||
|
# Mid block
|
||||||
|
self.mid_block = UNetMidBlock2DCrossAttn(
|
||||||
|
in_channels=block_out_channels[-1],
|
||||||
|
temb_channels=time_embed_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
num_layers=1,
|
||||||
|
transformer_layers_per_block=1,
|
||||||
|
resnet_eps=norm_eps,
|
||||||
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
attention_head_dim=attention_head_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Up blocks
|
||||||
|
self.up_blocks = nn.ModuleList()
|
||||||
|
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||||
|
output_channel = reversed_block_out_channels[0]
|
||||||
|
|
||||||
|
for i, up_block_type in enumerate(up_block_types):
|
||||||
|
prev_output_channel = output_channel
|
||||||
|
output_channel = reversed_block_out_channels[i]
|
||||||
|
is_final_block = i == len(block_out_channels) - 1
|
||||||
|
|
||||||
|
# in_channels for up blocks: diffusers uses reversed_block_out_channels[min(i+1, len-1)]
|
||||||
|
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
||||||
|
|
||||||
|
if "CrossAttn" in up_block_type:
|
||||||
|
up_block = CrossAttnUpBlock2D(
|
||||||
|
in_channels=input_channel,
|
||||||
|
out_channels=output_channel,
|
||||||
|
prev_output_channel=prev_output_channel,
|
||||||
|
temb_channels=time_embed_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
num_layers=layers_per_block + 1,
|
||||||
|
transformer_layers_per_block=1,
|
||||||
|
resnet_eps=norm_eps,
|
||||||
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
attention_head_dim=attention_head_dim,
|
||||||
|
upsample=not is_final_block,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
up_block = UpBlock2D(
|
||||||
|
in_channels=input_channel,
|
||||||
|
out_channels=output_channel,
|
||||||
|
prev_output_channel=prev_output_channel,
|
||||||
|
temb_channels=time_embed_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
num_layers=layers_per_block + 1,
|
||||||
|
resnet_eps=norm_eps,
|
||||||
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
upsample=not is_final_block,
|
||||||
|
)
|
||||||
|
self.up_blocks.append(up_block)
|
||||||
|
|
||||||
|
# Output
|
||||||
|
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
|
||||||
|
self.conv_act = nn.SiLU()
|
||||||
|
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
|
||||||
|
|
||||||
|
def forward(self, sample, timestep, encoder_hidden_states, cross_attention_kwargs=None, timestep_cond=None, added_cond_kwargs=None, return_dict=True):
|
||||||
|
# 1. Time embedding
|
||||||
|
timesteps = timestep
|
||||||
|
if not torch.is_tensor(timesteps):
|
||||||
|
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||||
|
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||||
|
timesteps = timesteps[None].to(sample.device)
|
||||||
|
|
||||||
|
t_emb = self.time_proj(timesteps)
|
||||||
|
t_emb = t_emb.to(dtype=sample.dtype)
|
||||||
|
emb = self.time_embedding(t_emb)
|
||||||
|
|
||||||
|
# 2. Pre-process
|
||||||
|
sample = self.conv_in(sample)
|
||||||
|
|
||||||
|
# 3. Down
|
||||||
|
down_block_res_samples = (sample,)
|
||||||
|
for down_block in self.down_blocks:
|
||||||
|
sample, res_samples = down_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
)
|
||||||
|
down_block_res_samples += res_samples
|
||||||
|
|
||||||
|
# 4. Mid
|
||||||
|
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
|
||||||
|
|
||||||
|
# 5. Up
|
||||||
|
for up_block in self.up_blocks:
|
||||||
|
res_samples = down_block_res_samples[-len(up_block.resnets):]
|
||||||
|
down_block_res_samples = down_block_res_samples[:-len(up_block.resnets)]
|
||||||
|
|
||||||
|
upsample_size = down_block_res_samples[-1].shape[2:] if down_block_res_samples else None
|
||||||
|
sample = up_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
res_hidden_states_tuple=res_samples,
|
||||||
|
upsample_size=upsample_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6. Post-process
|
||||||
|
sample = self.conv_norm_out(sample)
|
||||||
|
sample = self.conv_act(sample)
|
||||||
|
sample = self.conv_out(sample)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (sample,)
|
||||||
|
return sample
|
||||||
642
diffsynth/models/stable_diffusion_vae.py
Normal file
642
diffsynth/models/stable_diffusion_vae.py
Normal file
@@ -0,0 +1,642 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class DiagonalGaussianDistribution:
|
||||||
|
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
|
||||||
|
self.parameters = parameters
|
||||||
|
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
||||||
|
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
||||||
|
self.deterministic = deterministic
|
||||||
|
self.std = torch.exp(0.5 * self.logvar)
|
||||||
|
self.var = torch.exp(self.logvar)
|
||||||
|
if self.deterministic:
|
||||||
|
self.var = self.std = torch.zeros_like(
|
||||||
|
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
|
||||||
|
# randn_like doesn't accept generator on all torch versions
|
||||||
|
sample = torch.randn(self.mean.shape, generator=generator,
|
||||||
|
device=self.parameters.device, dtype=self.parameters.dtype)
|
||||||
|
return self.mean + self.std * sample
|
||||||
|
|
||||||
|
def kl(self, other: Optional["DiagonalGaussianDistribution"] = None) -> torch.Tensor:
|
||||||
|
if self.deterministic:
|
||||||
|
return torch.tensor([0.0])
|
||||||
|
if other is None:
|
||||||
|
return 0.5 * torch.sum(
|
||||||
|
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
|
||||||
|
dim=[1, 2, 3],
|
||||||
|
)
|
||||||
|
return 0.5 * torch.sum(
|
||||||
|
torch.pow(self.mean - other.mean, 2) / other.var
|
||||||
|
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
|
||||||
|
dim=[1, 2, 3],
|
||||||
|
)
|
||||||
|
|
||||||
|
def mode(self) -> torch.Tensor:
|
||||||
|
return self.mean
|
||||||
|
|
||||||
|
|
||||||
|
class ResnetBlock2D(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels=None,
|
||||||
|
conv_shortcut=False,
|
||||||
|
dropout=0.0,
|
||||||
|
temb_channels=512,
|
||||||
|
groups=32,
|
||||||
|
groups_out=None,
|
||||||
|
pre_norm=True,
|
||||||
|
eps=1e-6,
|
||||||
|
non_linearity="swish",
|
||||||
|
time_embedding_norm="default",
|
||||||
|
output_scale_factor=1.0,
|
||||||
|
use_in_shortcut=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.pre_norm = pre_norm
|
||||||
|
self.time_embedding_norm = time_embedding_norm
|
||||||
|
self.output_scale_factor = output_scale_factor
|
||||||
|
|
||||||
|
if groups_out is None:
|
||||||
|
groups_out = groups
|
||||||
|
|
||||||
|
self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps)
|
||||||
|
self.conv1 = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
if temb_channels is not None:
|
||||||
|
if self.time_embedding_norm == "default":
|
||||||
|
self.time_emb_proj = nn.Linear(temb_channels, out_channels or in_channels)
|
||||||
|
elif self.time_embedding_norm == "scale_shift":
|
||||||
|
self.time_emb_proj = nn.Linear(temb_channels, 2 * (out_channels or in_channels))
|
||||||
|
|
||||||
|
self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels or in_channels, eps=eps)
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
self.conv2 = nn.Conv2d(out_channels or in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
if non_linearity == "swish":
|
||||||
|
self.nonlinearity = nn.SiLU()
|
||||||
|
elif non_linearity == "silu":
|
||||||
|
self.nonlinearity = nn.SiLU()
|
||||||
|
elif non_linearity == "gelu":
|
||||||
|
self.nonlinearity = nn.GELU()
|
||||||
|
elif non_linearity == "relu":
|
||||||
|
self.nonlinearity = nn.ReLU()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported non_linearity: {non_linearity}")
|
||||||
|
|
||||||
|
self.use_conv_shortcut = conv_shortcut
|
||||||
|
self.conv_shortcut = None
|
||||||
|
if conv_shortcut:
|
||||||
|
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
|
else:
|
||||||
|
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0) if in_channels != (out_channels or in_channels) else None
|
||||||
|
|
||||||
|
def forward(self, input_tensor, temb=None):
|
||||||
|
hidden_states = input_tensor
|
||||||
|
hidden_states = self.norm1(hidden_states)
|
||||||
|
hidden_states = self.nonlinearity(hidden_states)
|
||||||
|
|
||||||
|
hidden_states = self.conv1(hidden_states)
|
||||||
|
|
||||||
|
if temb is not None:
|
||||||
|
temb = self.nonlinearity(temb)
|
||||||
|
temb = self.time_emb_proj(temb).unsqueeze(-1).unsqueeze(-1)
|
||||||
|
|
||||||
|
if temb is not None and self.time_embedding_norm == "default":
|
||||||
|
hidden_states = hidden_states + temb
|
||||||
|
|
||||||
|
hidden_states = self.norm2(hidden_states)
|
||||||
|
|
||||||
|
if temb is not None and self.time_embedding_norm == "scale_shift":
|
||||||
|
scale, shift = torch.chunk(temb, 2, dim=1)
|
||||||
|
hidden_states = hidden_states * (1 + scale) + shift
|
||||||
|
|
||||||
|
hidden_states = self.nonlinearity(hidden_states)
|
||||||
|
hidden_states = self.dropout(hidden_states)
|
||||||
|
hidden_states = self.conv2(hidden_states)
|
||||||
|
|
||||||
|
if self.conv_shortcut is not None:
|
||||||
|
input_tensor = self.conv_shortcut(input_tensor)
|
||||||
|
|
||||||
|
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
||||||
|
return output_tensor
|
||||||
|
|
||||||
|
|
||||||
|
class DownEncoderBlock2D(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
dropout=0.0,
|
||||||
|
num_layers=1,
|
||||||
|
resnet_eps=1e-6,
|
||||||
|
resnet_time_scale_shift="default",
|
||||||
|
resnet_act_fn="swish",
|
||||||
|
resnet_groups=32,
|
||||||
|
resnet_pre_norm=True,
|
||||||
|
output_scale_factor=1.0,
|
||||||
|
add_downsample=True,
|
||||||
|
downsample_padding=1,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
resnets = []
|
||||||
|
for i in range(num_layers):
|
||||||
|
in_channels_i = in_channels if i == 0 else out_channels
|
||||||
|
resnets.append(
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=in_channels_i,
|
||||||
|
out_channels=out_channels,
|
||||||
|
temb_channels=None,
|
||||||
|
eps=resnet_eps,
|
||||||
|
groups=resnet_groups,
|
||||||
|
dropout=dropout,
|
||||||
|
time_embedding_norm=resnet_time_scale_shift,
|
||||||
|
non_linearity=resnet_act_fn,
|
||||||
|
output_scale_factor=output_scale_factor,
|
||||||
|
pre_norm=resnet_pre_norm,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.resnets = nn.ModuleList(resnets)
|
||||||
|
|
||||||
|
if add_downsample:
|
||||||
|
self.downsamplers = nn.ModuleList([
|
||||||
|
Downsample2D(out_channels, out_channels, padding=downsample_padding)
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
self.downsamplers = None
|
||||||
|
|
||||||
|
def forward(self, hidden_states, *args, **kwargs):
|
||||||
|
for resnet in self.resnets:
|
||||||
|
hidden_states = resnet(hidden_states, temb=None)
|
||||||
|
if self.downsamplers is not None:
|
||||||
|
for downsampler in self.downsamplers:
|
||||||
|
hidden_states = downsampler(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class UpDecoderBlock2D(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
dropout=0.0,
|
||||||
|
num_layers=1,
|
||||||
|
resnet_eps=1e-6,
|
||||||
|
resnet_time_scale_shift="default",
|
||||||
|
resnet_act_fn="swish",
|
||||||
|
resnet_groups=32,
|
||||||
|
resnet_pre_norm=True,
|
||||||
|
output_scale_factor=1.0,
|
||||||
|
add_upsample=True,
|
||||||
|
temb_channels=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
resnets = []
|
||||||
|
for i in range(num_layers):
|
||||||
|
in_channels_i = in_channels if i == 0 else out_channels
|
||||||
|
resnets.append(
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=in_channels_i,
|
||||||
|
out_channels=out_channels,
|
||||||
|
temb_channels=temb_channels,
|
||||||
|
eps=resnet_eps,
|
||||||
|
groups=resnet_groups,
|
||||||
|
dropout=dropout,
|
||||||
|
time_embedding_norm=resnet_time_scale_shift,
|
||||||
|
non_linearity=resnet_act_fn,
|
||||||
|
output_scale_factor=output_scale_factor,
|
||||||
|
pre_norm=resnet_pre_norm,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.resnets = nn.ModuleList(resnets)
|
||||||
|
|
||||||
|
if add_upsample:
|
||||||
|
self.upsamplers = nn.ModuleList([
|
||||||
|
Upsample2D(out_channels, out_channels)
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
self.upsamplers = None
|
||||||
|
|
||||||
|
def forward(self, hidden_states, temb=None):
|
||||||
|
for resnet in self.resnets:
|
||||||
|
hidden_states = resnet(hidden_states, temb=temb)
|
||||||
|
if self.upsamplers is not None:
|
||||||
|
for upsampler in self.upsamplers:
|
||||||
|
hidden_states = upsampler(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class UNetMidBlock2D(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
temb_channels=None,
|
||||||
|
dropout=0.0,
|
||||||
|
num_layers=1,
|
||||||
|
resnet_eps=1e-6,
|
||||||
|
resnet_time_scale_shift="default",
|
||||||
|
resnet_act_fn="swish",
|
||||||
|
resnet_groups=32,
|
||||||
|
resnet_pre_norm=True,
|
||||||
|
add_attention=True,
|
||||||
|
attention_head_dim=1,
|
||||||
|
output_scale_factor=1.0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||||
|
self.add_attention = add_attention
|
||||||
|
|
||||||
|
# there is always at least one resnet
|
||||||
|
resnets = [
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=in_channels,
|
||||||
|
temb_channels=temb_channels,
|
||||||
|
eps=resnet_eps,
|
||||||
|
groups=resnet_groups,
|
||||||
|
dropout=dropout,
|
||||||
|
time_embedding_norm=resnet_time_scale_shift,
|
||||||
|
non_linearity=resnet_act_fn,
|
||||||
|
output_scale_factor=output_scale_factor,
|
||||||
|
pre_norm=resnet_pre_norm,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
attentions = []
|
||||||
|
|
||||||
|
if attention_head_dim is None:
|
||||||
|
attention_head_dim = in_channels
|
||||||
|
|
||||||
|
for _ in range(num_layers):
|
||||||
|
if self.add_attention:
|
||||||
|
attentions.append(
|
||||||
|
AttentionBlock(
|
||||||
|
in_channels,
|
||||||
|
num_groups=resnet_groups,
|
||||||
|
eps=resnet_eps,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attentions.append(None)
|
||||||
|
|
||||||
|
resnets.append(
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=in_channels,
|
||||||
|
temb_channels=temb_channels,
|
||||||
|
eps=resnet_eps,
|
||||||
|
groups=resnet_groups,
|
||||||
|
dropout=dropout,
|
||||||
|
time_embedding_norm=resnet_time_scale_shift,
|
||||||
|
non_linearity=resnet_act_fn,
|
||||||
|
output_scale_factor=output_scale_factor,
|
||||||
|
pre_norm=resnet_pre_norm,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.attentions = nn.ModuleList(attentions)
|
||||||
|
self.resnets = nn.ModuleList(resnets)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, temb=None):
|
||||||
|
hidden_states = self.resnets[0](hidden_states, temb)
|
||||||
|
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||||
|
if attn is not None:
|
||||||
|
hidden_states = attn(hidden_states)
|
||||||
|
hidden_states = resnet(hidden_states, temb)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionBlock(nn.Module):
|
||||||
|
"""Simple attention block for VAE mid block.
|
||||||
|
Mirrors diffusers Attention class with AttnProcessor2_0 for VAE use case.
|
||||||
|
Uses modern key names (to_q, to_k, to_v, to_out) matching in-memory diffusers structure.
|
||||||
|
Checkpoint uses deprecated keys (query, key, value, proj_attn) — mapped via converter.
|
||||||
|
"""
|
||||||
|
def __init__(self, channels, num_groups=32, eps=1e-6):
|
||||||
|
super().__init__()
|
||||||
|
self.channels = channels
|
||||||
|
self.eps = eps
|
||||||
|
self.heads = 1
|
||||||
|
self.rescale_output_factor = 1.0
|
||||||
|
|
||||||
|
self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=eps, affine=True)
|
||||||
|
self.to_q = nn.Linear(channels, channels, bias=True)
|
||||||
|
self.to_k = nn.Linear(channels, channels, bias=True)
|
||||||
|
self.to_v = nn.Linear(channels, channels, bias=True)
|
||||||
|
self.to_out = nn.ModuleList([
|
||||||
|
nn.Linear(channels, channels, bias=True),
|
||||||
|
nn.Dropout(0.0),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
# Group norm
|
||||||
|
hidden_states = self.group_norm(hidden_states)
|
||||||
|
|
||||||
|
# Flatten spatial dims: (B, C, H, W) -> (B, H*W, C)
|
||||||
|
batch_size, channel, height, width = hidden_states.shape
|
||||||
|
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||||
|
|
||||||
|
# QKV projection
|
||||||
|
query = self.to_q(hidden_states)
|
||||||
|
key = self.to_k(hidden_states)
|
||||||
|
value = self.to_v(hidden_states)
|
||||||
|
|
||||||
|
# Reshape for attention: (B, seq, dim) -> (B, heads, seq, head_dim)
|
||||||
|
inner_dim = key.shape[-1]
|
||||||
|
head_dim = inner_dim // self.heads
|
||||||
|
query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
|
||||||
|
key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
|
||||||
|
value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
# Scaled dot-product attention
|
||||||
|
hidden_states = torch.nn.functional.scaled_dot_product_attention(
|
||||||
|
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reshape back: (B, heads, seq, head_dim) -> (B, seq, heads*head_dim)
|
||||||
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
|
||||||
|
hidden_states = hidden_states.to(query.dtype)
|
||||||
|
|
||||||
|
# Output projection + dropout
|
||||||
|
hidden_states = self.to_out[0](hidden_states)
|
||||||
|
hidden_states = self.to_out[1](hidden_states)
|
||||||
|
|
||||||
|
# Reshape back to 4D and add residual
|
||||||
|
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||||
|
hidden_states = hidden_states + residual
|
||||||
|
|
||||||
|
# Rescale output factor
|
||||||
|
hidden_states = hidden_states / self.rescale_output_factor
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Downsample2D(nn.Module):
|
||||||
|
"""Downsampling layer matching diffusers Downsample2D with use_conv=True.
|
||||||
|
Key names: conv.weight/bias.
|
||||||
|
When padding=0, applies explicit F.pad before conv to match dimension.
|
||||||
|
"""
|
||||||
|
def __init__(self, in_channels, out_channels, padding=1):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
|
||||||
|
self.padding = padding
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
if self.padding == 0:
|
||||||
|
import torch.nn.functional as F
|
||||||
|
hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
|
||||||
|
return self.conv(hidden_states)
|
||||||
|
|
||||||
|
|
||||||
|
class Upsample2D(nn.Module):
|
||||||
|
"""Upsampling layer with key names matching diffusers checkpoint: conv.weight/bias."""
|
||||||
|
def __init__(self, in_channels, out_channels):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
|
||||||
|
return self.conv(hidden_states)
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels=3,
|
||||||
|
out_channels=3,
|
||||||
|
down_block_types=("DownEncoderBlock2D",),
|
||||||
|
block_out_channels=(64,),
|
||||||
|
layers_per_block=2,
|
||||||
|
norm_num_groups=32,
|
||||||
|
act_fn="silu",
|
||||||
|
double_z=True,
|
||||||
|
mid_block_add_attention=True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.layers_per_block = layers_per_block
|
||||||
|
|
||||||
|
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
self.down_blocks = nn.ModuleList([])
|
||||||
|
output_channel = block_out_channels[0]
|
||||||
|
for i, down_block_type in enumerate(down_block_types):
|
||||||
|
input_channel = output_channel
|
||||||
|
output_channel = block_out_channels[i]
|
||||||
|
is_final_block = i == len(block_out_channels) - 1
|
||||||
|
down_block = DownEncoderBlock2D(
|
||||||
|
in_channels=input_channel,
|
||||||
|
out_channels=output_channel,
|
||||||
|
num_layers=self.layers_per_block,
|
||||||
|
resnet_eps=1e-6,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
add_downsample=not is_final_block,
|
||||||
|
downsample_padding=0,
|
||||||
|
)
|
||||||
|
self.down_blocks.append(down_block)
|
||||||
|
|
||||||
|
# mid
|
||||||
|
self.mid_block = UNetMidBlock2D(
|
||||||
|
in_channels=block_out_channels[-1],
|
||||||
|
resnet_eps=1e-6,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
output_scale_factor=1,
|
||||||
|
resnet_time_scale_shift="default",
|
||||||
|
attention_head_dim=block_out_channels[-1],
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
temb_channels=None,
|
||||||
|
add_attention=mid_block_add_attention,
|
||||||
|
)
|
||||||
|
|
||||||
|
# out
|
||||||
|
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
|
||||||
|
self.conv_act = nn.SiLU()
|
||||||
|
conv_out_channels = 2 * out_channels if double_z else out_channels
|
||||||
|
self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
|
||||||
|
|
||||||
|
def forward(self, sample):
|
||||||
|
sample = self.conv_in(sample)
|
||||||
|
for down_block in self.down_blocks:
|
||||||
|
sample = down_block(sample)
|
||||||
|
sample = self.mid_block(sample)
|
||||||
|
sample = self.conv_norm_out(sample)
|
||||||
|
sample = self.conv_act(sample)
|
||||||
|
sample = self.conv_out(sample)
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels=3,
|
||||||
|
out_channels=3,
|
||||||
|
up_block_types=("UpDecoderBlock2D",),
|
||||||
|
block_out_channels=(64,),
|
||||||
|
layers_per_block=2,
|
||||||
|
norm_num_groups=32,
|
||||||
|
act_fn="silu",
|
||||||
|
norm_type="group",
|
||||||
|
mid_block_add_attention=True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.layers_per_block = layers_per_block
|
||||||
|
|
||||||
|
self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
self.up_blocks = nn.ModuleList([])
|
||||||
|
temb_channels = in_channels if norm_type == "spatial" else None
|
||||||
|
|
||||||
|
# mid
|
||||||
|
self.mid_block = UNetMidBlock2D(
|
||||||
|
in_channels=block_out_channels[-1],
|
||||||
|
resnet_eps=1e-6,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
output_scale_factor=1,
|
||||||
|
resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
|
||||||
|
attention_head_dim=block_out_channels[-1],
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
temb_channels=temb_channels,
|
||||||
|
add_attention=mid_block_add_attention,
|
||||||
|
)
|
||||||
|
|
||||||
|
# up
|
||||||
|
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||||
|
output_channel = reversed_block_out_channels[0]
|
||||||
|
for i, up_block_type in enumerate(up_block_types):
|
||||||
|
prev_output_channel = output_channel
|
||||||
|
output_channel = reversed_block_out_channels[i]
|
||||||
|
is_final_block = i == len(block_out_channels) - 1
|
||||||
|
up_block = UpDecoderBlock2D(
|
||||||
|
in_channels=prev_output_channel,
|
||||||
|
out_channels=output_channel,
|
||||||
|
num_layers=self.layers_per_block + 1,
|
||||||
|
resnet_eps=1e-6,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
add_upsample=not is_final_block,
|
||||||
|
temb_channels=temb_channels,
|
||||||
|
)
|
||||||
|
self.up_blocks.append(up_block)
|
||||||
|
prev_output_channel = output_channel
|
||||||
|
|
||||||
|
# out
|
||||||
|
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
|
||||||
|
self.conv_act = nn.SiLU()
|
||||||
|
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
|
||||||
|
|
||||||
|
def forward(self, sample, latent_embeds=None):
|
||||||
|
sample = self.conv_in(sample)
|
||||||
|
sample = self.mid_block(sample, latent_embeds)
|
||||||
|
for up_block in self.up_blocks:
|
||||||
|
sample = up_block(sample, latent_embeds)
|
||||||
|
sample = self.conv_norm_out(sample)
|
||||||
|
sample = self.conv_act(sample)
|
||||||
|
sample = self.conv_out(sample)
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusionVAE(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels=3,
|
||||||
|
out_channels=3,
|
||||||
|
down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"),
|
||||||
|
up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"),
|
||||||
|
block_out_channels=(128, 256, 512, 512),
|
||||||
|
layers_per_block=2,
|
||||||
|
act_fn="silu",
|
||||||
|
latent_channels=4,
|
||||||
|
norm_num_groups=32,
|
||||||
|
sample_size=512,
|
||||||
|
scaling_factor=0.18215,
|
||||||
|
shift_factor=None,
|
||||||
|
latents_mean=None,
|
||||||
|
latents_std=None,
|
||||||
|
force_upcast=True,
|
||||||
|
use_quant_conv=True,
|
||||||
|
use_post_quant_conv=True,
|
||||||
|
mid_block_add_attention=True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.encoder = Encoder(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=latent_channels,
|
||||||
|
down_block_types=down_block_types,
|
||||||
|
block_out_channels=block_out_channels,
|
||||||
|
layers_per_block=layers_per_block,
|
||||||
|
norm_num_groups=norm_num_groups,
|
||||||
|
act_fn=act_fn,
|
||||||
|
double_z=True,
|
||||||
|
mid_block_add_attention=mid_block_add_attention,
|
||||||
|
)
|
||||||
|
self.decoder = Decoder(
|
||||||
|
in_channels=latent_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
up_block_types=up_block_types,
|
||||||
|
block_out_channels=block_out_channels,
|
||||||
|
layers_per_block=layers_per_block,
|
||||||
|
norm_num_groups=norm_num_groups,
|
||||||
|
act_fn=act_fn,
|
||||||
|
mid_block_add_attention=mid_block_add_attention,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
|
||||||
|
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None
|
||||||
|
|
||||||
|
self.latents_mean = latents_mean
|
||||||
|
self.latents_std = latents_std
|
||||||
|
self.scaling_factor = scaling_factor
|
||||||
|
self.shift_factor = shift_factor
|
||||||
|
self.sample_size = sample_size
|
||||||
|
self.force_upcast = force_upcast
|
||||||
|
|
||||||
|
def _encode(self, x):
|
||||||
|
h = self.encoder(x)
|
||||||
|
if self.quant_conv is not None:
|
||||||
|
h = self.quant_conv(h)
|
||||||
|
return h
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
h = self._encode(x)
|
||||||
|
posterior = DiagonalGaussianDistribution(h)
|
||||||
|
return posterior
|
||||||
|
|
||||||
|
def _decode(self, z):
|
||||||
|
if self.post_quant_conv is not None:
|
||||||
|
z = self.post_quant_conv(z)
|
||||||
|
return self.decoder(z)
|
||||||
|
|
||||||
|
def decode(self, z):
|
||||||
|
return self._decode(z)
|
||||||
|
|
||||||
|
def forward(self, sample, sample_posterior=True, return_dict=True, generator=None):
|
||||||
|
posterior = self.encode(sample)
|
||||||
|
if sample_posterior:
|
||||||
|
z = posterior.sample(generator=generator)
|
||||||
|
else:
|
||||||
|
z = posterior.mode()
|
||||||
|
# Scale latent
|
||||||
|
z = z * self.scaling_factor
|
||||||
|
decode = self.decode(z)
|
||||||
|
if return_dict:
|
||||||
|
return {"sample": decode, "posterior": posterior, "latent_sample": z}
|
||||||
|
return decode, posterior
|
||||||
62
diffsynth/models/stable_diffusion_xl_text_encoder.py
Normal file
62
diffsynth/models/stable_diffusion_xl_text_encoder.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class SDXLTextEncoder2(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size=1280,
|
||||||
|
intermediate_size=5120,
|
||||||
|
num_hidden_layers=32,
|
||||||
|
num_attention_heads=20,
|
||||||
|
max_position_embeddings=77,
|
||||||
|
vocab_size=49408,
|
||||||
|
layer_norm_eps=1e-05,
|
||||||
|
hidden_act="gelu",
|
||||||
|
initializer_factor=1.0,
|
||||||
|
initializer_range=0.02,
|
||||||
|
bos_token_id=0,
|
||||||
|
eos_token_id=2,
|
||||||
|
pad_token_id=1,
|
||||||
|
projection_dim=1280,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
from transformers import CLIPTextConfig, CLIPTextModelWithProjection
|
||||||
|
|
||||||
|
config = CLIPTextConfig(
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
intermediate_size=intermediate_size,
|
||||||
|
num_hidden_layers=num_hidden_layers,
|
||||||
|
num_attention_heads=num_attention_heads,
|
||||||
|
max_position_embeddings=max_position_embeddings,
|
||||||
|
vocab_size=vocab_size,
|
||||||
|
layer_norm_eps=layer_norm_eps,
|
||||||
|
hidden_act=hidden_act,
|
||||||
|
initializer_factor=initializer_factor,
|
||||||
|
initializer_range=initializer_range,
|
||||||
|
bos_token_id=bos_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
projection_dim=projection_dim,
|
||||||
|
)
|
||||||
|
self.model = CLIPTextModelWithProjection(config)
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
position_ids=None,
|
||||||
|
output_hidden_states=True,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=True,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
if output_hidden_states:
|
||||||
|
return outputs.text_embeds, outputs.hidden_states
|
||||||
|
return outputs.text_embeds
|
||||||
922
diffsynth/models/stable_diffusion_xl_unet.py
Normal file
922
diffsynth/models/stable_diffusion_xl_unet.py
Normal file
@@ -0,0 +1,922 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import math
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
# ===== Time Embedding =====
|
||||||
|
|
||||||
|
class Timesteps(nn.Module):
|
||||||
|
def __init__(self, num_channels, flip_sin_to_cos=True, freq_shift=0):
|
||||||
|
super().__init__()
|
||||||
|
self.num_channels = num_channels
|
||||||
|
self.flip_sin_to_cos = flip_sin_to_cos
|
||||||
|
self.freq_shift = freq_shift
|
||||||
|
|
||||||
|
def forward(self, timesteps):
|
||||||
|
half_dim = self.num_channels // 2
|
||||||
|
exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
|
||||||
|
exponent = exponent / half_dim + self.freq_shift
|
||||||
|
emb = torch.exp(exponent)
|
||||||
|
emb = timesteps[:, None].float() * emb[None, :]
|
||||||
|
sin_emb = torch.sin(emb)
|
||||||
|
cos_emb = torch.cos(emb)
|
||||||
|
if self.flip_sin_to_cos:
|
||||||
|
emb = torch.cat([cos_emb, sin_emb], dim=-1)
|
||||||
|
else:
|
||||||
|
emb = torch.cat([sin_emb, cos_emb], dim=-1)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEmbedding(nn.Module):
|
||||||
|
def __init__(self, in_channels, time_embed_dim, act_fn="silu", out_dim=None):
|
||||||
|
super().__init__()
|
||||||
|
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
||||||
|
self.act = nn.SiLU() if act_fn == "silu" else nn.GELU()
|
||||||
|
out_dim = out_dim if out_dim is not None else time_embed_dim
|
||||||
|
self.linear_2 = nn.Linear(time_embed_dim, out_dim)
|
||||||
|
|
||||||
|
def forward(self, sample):
|
||||||
|
sample = self.linear_1(sample)
|
||||||
|
sample = self.act(sample)
|
||||||
|
sample = self.linear_2(sample)
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
# ===== ResNet Blocks =====
|
||||||
|
|
||||||
|
class ResnetBlock2D(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels=None,
|
||||||
|
conv_shortcut=False,
|
||||||
|
dropout=0.0,
|
||||||
|
temb_channels=512,
|
||||||
|
groups=32,
|
||||||
|
groups_out=None,
|
||||||
|
pre_norm=True,
|
||||||
|
eps=1e-6,
|
||||||
|
non_linearity="swish",
|
||||||
|
time_embedding_norm="default",
|
||||||
|
output_scale_factor=1.0,
|
||||||
|
use_in_shortcut=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.pre_norm = pre_norm
|
||||||
|
self.time_embedding_norm = time_embedding_norm
|
||||||
|
self.output_scale_factor = output_scale_factor
|
||||||
|
|
||||||
|
if groups_out is None:
|
||||||
|
groups_out = groups
|
||||||
|
|
||||||
|
self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps)
|
||||||
|
self.conv1 = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
if temb_channels is not None:
|
||||||
|
if self.time_embedding_norm == "default":
|
||||||
|
self.time_emb_proj = nn.Linear(temb_channels, out_channels or in_channels)
|
||||||
|
elif self.time_embedding_norm == "scale_shift":
|
||||||
|
self.time_emb_proj = nn.Linear(temb_channels, 2 * (out_channels or in_channels))
|
||||||
|
|
||||||
|
self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels or in_channels, eps=eps)
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
self.conv2 = nn.Conv2d(out_channels or in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
if non_linearity == "swish":
|
||||||
|
self.nonlinearity = nn.SiLU()
|
||||||
|
elif non_linearity == "silu":
|
||||||
|
self.nonlinearity = nn.SiLU()
|
||||||
|
elif non_linearity == "gelu":
|
||||||
|
self.nonlinearity = nn.GELU()
|
||||||
|
elif non_linearity == "relu":
|
||||||
|
self.nonlinearity = nn.ReLU()
|
||||||
|
|
||||||
|
self.use_conv_shortcut = conv_shortcut
|
||||||
|
self.conv_shortcut = None
|
||||||
|
if conv_shortcut:
|
||||||
|
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
|
else:
|
||||||
|
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0) if in_channels != (out_channels or in_channels) else None
|
||||||
|
|
||||||
|
def forward(self, input_tensor, temb=None):
|
||||||
|
hidden_states = input_tensor
|
||||||
|
hidden_states = self.norm1(hidden_states)
|
||||||
|
hidden_states = self.nonlinearity(hidden_states)
|
||||||
|
hidden_states = self.conv1(hidden_states)
|
||||||
|
|
||||||
|
if temb is not None:
|
||||||
|
temb = self.nonlinearity(temb)
|
||||||
|
temb = self.time_emb_proj(temb).unsqueeze(-1).unsqueeze(-1)
|
||||||
|
|
||||||
|
if temb is not None and self.time_embedding_norm == "default":
|
||||||
|
hidden_states = hidden_states + temb
|
||||||
|
|
||||||
|
hidden_states = self.norm2(hidden_states)
|
||||||
|
|
||||||
|
if temb is not None and self.time_embedding_norm == "scale_shift":
|
||||||
|
scale, shift = torch.chunk(temb, 2, dim=1)
|
||||||
|
hidden_states = hidden_states * (1 + scale) + shift
|
||||||
|
|
||||||
|
hidden_states = self.nonlinearity(hidden_states)
|
||||||
|
hidden_states = self.dropout(hidden_states)
|
||||||
|
hidden_states = self.conv2(hidden_states)
|
||||||
|
|
||||||
|
if self.conv_shortcut is not None:
|
||||||
|
input_tensor = self.conv_shortcut(input_tensor)
|
||||||
|
|
||||||
|
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
||||||
|
return output_tensor
|
||||||
|
|
||||||
|
|
||||||
|
# ===== Transformer Blocks =====
|
||||||
|
|
||||||
|
class GEGLU(nn.Module):
|
||||||
|
def __init__(self, dim_in, dim_out):
|
||||||
|
super().__init__()
|
||||||
|
self.proj = nn.Linear(dim_in, dim_out * 2)
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
|
||||||
|
return hidden_states * F.gelu(gate)
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
|
||||||
|
def __init__(self, dim, dim_out=None, dropout=0.0):
|
||||||
|
super().__init__()
|
||||||
|
self.net = nn.ModuleList([
|
||||||
|
GEGLU(dim, dim * 4),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
nn.Linear(dim * 4, dim if dim_out is None else dim_out),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
for module in self.net:
|
||||||
|
hidden_states = module(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
query_dim,
|
||||||
|
heads=8,
|
||||||
|
dim_head=64,
|
||||||
|
dropout=0.0,
|
||||||
|
bias=False,
|
||||||
|
upcast_attention=False,
|
||||||
|
cross_attention_dim=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = dim_head * heads
|
||||||
|
self.heads = heads
|
||||||
|
self.inner_dim = inner_dim
|
||||||
|
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
||||||
|
|
||||||
|
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
|
||||||
|
self.to_k = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias)
|
||||||
|
self.to_v = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias)
|
||||||
|
self.to_out = nn.ModuleList([
|
||||||
|
nn.Linear(inner_dim, query_dim, bias=True),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||||
|
query = self.to_q(hidden_states)
|
||||||
|
batch_size, seq_len, _ = query.shape
|
||||||
|
|
||||||
|
if encoder_hidden_states is None:
|
||||||
|
encoder_hidden_states = hidden_states
|
||||||
|
key = self.to_k(encoder_hidden_states)
|
||||||
|
value = self.to_v(encoder_hidden_states)
|
||||||
|
|
||||||
|
head_dim = self.inner_dim // self.heads
|
||||||
|
query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
|
||||||
|
key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
|
||||||
|
value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
hidden_states = F.scaled_dot_product_attention(
|
||||||
|
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.inner_dim)
|
||||||
|
hidden_states = hidden_states.to(query.dtype)
|
||||||
|
|
||||||
|
hidden_states = self.to_out[0](hidden_states)
|
||||||
|
hidden_states = self.to_out[1](hidden_states)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BasicTransformerBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
n_heads,
|
||||||
|
d_head,
|
||||||
|
dropout=0.0,
|
||||||
|
cross_attention_dim=None,
|
||||||
|
upcast_attention=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.norm1 = nn.LayerNorm(dim)
|
||||||
|
self.attn1 = Attention(
|
||||||
|
query_dim=dim,
|
||||||
|
heads=n_heads,
|
||||||
|
dim_head=d_head,
|
||||||
|
dropout=dropout,
|
||||||
|
bias=False,
|
||||||
|
upcast_attention=upcast_attention,
|
||||||
|
)
|
||||||
|
self.norm2 = nn.LayerNorm(dim)
|
||||||
|
self.attn2 = Attention(
|
||||||
|
query_dim=dim,
|
||||||
|
heads=n_heads,
|
||||||
|
dim_head=d_head,
|
||||||
|
dropout=dropout,
|
||||||
|
bias=False,
|
||||||
|
upcast_attention=upcast_attention,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
)
|
||||||
|
self.norm3 = nn.LayerNorm(dim)
|
||||||
|
self.ff = FeedForward(dim, dropout=dropout)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||||
|
attn_output = self.attn1(self.norm1(hidden_states))
|
||||||
|
hidden_states = attn_output + hidden_states
|
||||||
|
attn_output = self.attn2(self.norm2(hidden_states), encoder_hidden_states=encoder_hidden_states)
|
||||||
|
hidden_states = attn_output + hidden_states
|
||||||
|
ff_output = self.ff(self.norm3(hidden_states))
|
||||||
|
hidden_states = ff_output + hidden_states
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Transformer2DModel(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_attention_heads=16,
|
||||||
|
attention_head_dim=64,
|
||||||
|
in_channels=320,
|
||||||
|
num_layers=1,
|
||||||
|
dropout=0.0,
|
||||||
|
norm_num_groups=32,
|
||||||
|
cross_attention_dim=768,
|
||||||
|
upcast_attention=False,
|
||||||
|
use_linear_projection=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.attention_head_dim = attention_head_dim
|
||||||
|
inner_dim = num_attention_heads * attention_head_dim
|
||||||
|
self.use_linear_projection = use_linear_projection
|
||||||
|
|
||||||
|
self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6)
|
||||||
|
|
||||||
|
if use_linear_projection:
|
||||||
|
self.proj_in = nn.Linear(in_channels, inner_dim, bias=True)
|
||||||
|
else:
|
||||||
|
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, bias=True)
|
||||||
|
|
||||||
|
self.transformer_blocks = nn.ModuleList([
|
||||||
|
BasicTransformerBlock(
|
||||||
|
dim=inner_dim,
|
||||||
|
n_heads=num_attention_heads,
|
||||||
|
d_head=attention_head_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
upcast_attention=upcast_attention,
|
||||||
|
)
|
||||||
|
for _ in range(num_layers)
|
||||||
|
])
|
||||||
|
|
||||||
|
if use_linear_projection:
|
||||||
|
self.proj_out = nn.Linear(inner_dim, in_channels, bias=True)
|
||||||
|
else:
|
||||||
|
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, bias=True)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||||
|
batch, channel, height, width = hidden_states.shape
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
|
||||||
|
if self.use_linear_projection:
|
||||||
|
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, -1, channel)
|
||||||
|
hidden_states = self.proj_in(hidden_states)
|
||||||
|
else:
|
||||||
|
hidden_states = self.proj_in(hidden_states)
|
||||||
|
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, -1, channel)
|
||||||
|
|
||||||
|
for block in self.transformer_blocks:
|
||||||
|
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||||
|
|
||||||
|
if self.use_linear_projection:
|
||||||
|
hidden_states = self.proj_out(hidden_states)
|
||||||
|
hidden_states = hidden_states.reshape(batch, height, width, channel).permute(0, 3, 1, 2).contiguous()
|
||||||
|
else:
|
||||||
|
hidden_states = hidden_states.reshape(batch, height, width, channel).permute(0, 3, 1, 2).contiguous()
|
||||||
|
hidden_states = self.proj_out(hidden_states)
|
||||||
|
|
||||||
|
hidden_states = hidden_states + residual
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
# ===== Down/Up Blocks =====
|
||||||
|
|
||||||
|
class CrossAttnDownBlock2D(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
temb_channels=1280,
|
||||||
|
dropout=0.0,
|
||||||
|
num_layers=1,
|
||||||
|
transformer_layers_per_block=1,
|
||||||
|
resnet_eps=1e-6,
|
||||||
|
resnet_time_scale_shift="default",
|
||||||
|
resnet_act_fn="swish",
|
||||||
|
resnet_groups=32,
|
||||||
|
resnet_pre_norm=True,
|
||||||
|
cross_attention_dim=768,
|
||||||
|
attention_head_dim=1,
|
||||||
|
downsample=True,
|
||||||
|
use_linear_projection=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.has_cross_attention = True
|
||||||
|
|
||||||
|
resnets = []
|
||||||
|
attentions = []
|
||||||
|
|
||||||
|
for i in range(num_layers):
|
||||||
|
in_channels_i = in_channels if i == 0 else out_channels
|
||||||
|
resnets.append(
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=in_channels_i,
|
||||||
|
out_channels=out_channels,
|
||||||
|
temb_channels=temb_channels,
|
||||||
|
eps=resnet_eps,
|
||||||
|
groups=resnet_groups,
|
||||||
|
dropout=dropout,
|
||||||
|
time_embedding_norm=resnet_time_scale_shift,
|
||||||
|
non_linearity=resnet_act_fn,
|
||||||
|
output_scale_factor=1.0,
|
||||||
|
pre_norm=resnet_pre_norm,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
attentions.append(
|
||||||
|
Transformer2DModel(
|
||||||
|
num_attention_heads=attention_head_dim,
|
||||||
|
attention_head_dim=out_channels // attention_head_dim,
|
||||||
|
in_channels=out_channels,
|
||||||
|
num_layers=transformer_layers_per_block,
|
||||||
|
dropout=dropout,
|
||||||
|
norm_num_groups=resnet_groups,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
use_linear_projection=use_linear_projection,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.attentions = nn.ModuleList(attentions)
|
||||||
|
self.resnets = nn.ModuleList(resnets)
|
||||||
|
|
||||||
|
if downsample:
|
||||||
|
self.downsamplers = nn.ModuleList([
|
||||||
|
Downsample2D(out_channels, out_channels, padding=1)
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
self.downsamplers = None
|
||||||
|
|
||||||
|
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
||||||
|
output_states = []
|
||||||
|
|
||||||
|
for resnet, attn in zip(self.resnets, self.attentions):
|
||||||
|
hidden_states = resnet(hidden_states, temb)
|
||||||
|
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||||
|
output_states.append(hidden_states)
|
||||||
|
|
||||||
|
if self.downsamplers is not None:
|
||||||
|
for downsampler in self.downsamplers:
|
||||||
|
hidden_states = downsampler(hidden_states)
|
||||||
|
output_states.append(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states, tuple(output_states)
|
||||||
|
|
||||||
|
|
||||||
|
class DownBlock2D(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
temb_channels=1280,
|
||||||
|
dropout=0.0,
|
||||||
|
num_layers=1,
|
||||||
|
resnet_eps=1e-6,
|
||||||
|
resnet_time_scale_shift="default",
|
||||||
|
resnet_act_fn="swish",
|
||||||
|
resnet_groups=32,
|
||||||
|
resnet_pre_norm=True,
|
||||||
|
downsample=True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.has_cross_attention = False
|
||||||
|
|
||||||
|
resnets = []
|
||||||
|
for i in range(num_layers):
|
||||||
|
in_channels_i = in_channels if i == 0 else out_channels
|
||||||
|
resnets.append(
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=in_channels_i,
|
||||||
|
out_channels=out_channels,
|
||||||
|
temb_channels=temb_channels,
|
||||||
|
eps=resnet_eps,
|
||||||
|
groups=resnet_groups,
|
||||||
|
dropout=dropout,
|
||||||
|
time_embedding_norm=resnet_time_scale_shift,
|
||||||
|
non_linearity=resnet_act_fn,
|
||||||
|
output_scale_factor=1.0,
|
||||||
|
pre_norm=resnet_pre_norm,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.resnets = nn.ModuleList(resnets)
|
||||||
|
|
||||||
|
if downsample:
|
||||||
|
self.downsamplers = nn.ModuleList([
|
||||||
|
Downsample2D(out_channels, out_channels, padding=1)
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
self.downsamplers = None
|
||||||
|
|
||||||
|
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
||||||
|
output_states = []
|
||||||
|
for resnet in self.resnets:
|
||||||
|
hidden_states = resnet(hidden_states, temb)
|
||||||
|
output_states.append(hidden_states)
|
||||||
|
|
||||||
|
if self.downsamplers is not None:
|
||||||
|
for downsampler in self.downsamplers:
|
||||||
|
hidden_states = downsampler(hidden_states)
|
||||||
|
output_states.append(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states, tuple(output_states)
|
||||||
|
|
||||||
|
|
||||||
|
class CrossAttnUpBlock2D(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
prev_output_channel,
|
||||||
|
temb_channels=1280,
|
||||||
|
dropout=0.0,
|
||||||
|
num_layers=1,
|
||||||
|
transformer_layers_per_block=1,
|
||||||
|
resnet_eps=1e-6,
|
||||||
|
resnet_time_scale_shift="default",
|
||||||
|
resnet_act_fn="swish",
|
||||||
|
resnet_groups=32,
|
||||||
|
resnet_pre_norm=True,
|
||||||
|
cross_attention_dim=768,
|
||||||
|
attention_head_dim=1,
|
||||||
|
upsample=True,
|
||||||
|
use_linear_projection=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.has_cross_attention = True
|
||||||
|
|
||||||
|
resnets = []
|
||||||
|
attentions = []
|
||||||
|
|
||||||
|
for i in range(num_layers):
|
||||||
|
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
||||||
|
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
||||||
|
|
||||||
|
resnets.append(
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=resnet_in_channels + res_skip_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
temb_channels=temb_channels,
|
||||||
|
eps=resnet_eps,
|
||||||
|
groups=resnet_groups,
|
||||||
|
dropout=dropout,
|
||||||
|
time_embedding_norm=resnet_time_scale_shift,
|
||||||
|
non_linearity=resnet_act_fn,
|
||||||
|
output_scale_factor=1.0,
|
||||||
|
pre_norm=resnet_pre_norm,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
attentions.append(
|
||||||
|
Transformer2DModel(
|
||||||
|
num_attention_heads=attention_head_dim,
|
||||||
|
attention_head_dim=out_channels // attention_head_dim,
|
||||||
|
in_channels=out_channels,
|
||||||
|
num_layers=transformer_layers_per_block,
|
||||||
|
dropout=dropout,
|
||||||
|
norm_num_groups=resnet_groups,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
use_linear_projection=use_linear_projection,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.attentions = nn.ModuleList(attentions)
|
||||||
|
self.resnets = nn.ModuleList(resnets)
|
||||||
|
|
||||||
|
if upsample:
|
||||||
|
self.upsamplers = nn.ModuleList([
|
||||||
|
Upsample2D(out_channels, out_channels)
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
self.upsamplers = None
|
||||||
|
|
||||||
|
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None):
|
||||||
|
for resnet, attn in zip(self.resnets, self.attentions):
|
||||||
|
res_hidden_states = res_hidden_states_tuple[-1]
|
||||||
|
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||||
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||||
|
hidden_states = resnet(hidden_states, temb)
|
||||||
|
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||||
|
|
||||||
|
if self.upsamplers is not None:
|
||||||
|
for upsampler in self.upsamplers:
|
||||||
|
hidden_states = upsampler(hidden_states, upsample_size=upsample_size)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class UpBlock2D(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
prev_output_channel,
|
||||||
|
temb_channels=1280,
|
||||||
|
dropout=0.0,
|
||||||
|
num_layers=1,
|
||||||
|
resnet_eps=1e-6,
|
||||||
|
resnet_time_scale_shift="default",
|
||||||
|
resnet_act_fn="swish",
|
||||||
|
resnet_groups=32,
|
||||||
|
resnet_pre_norm=True,
|
||||||
|
upsample=True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.has_cross_attention = False
|
||||||
|
|
||||||
|
resnets = []
|
||||||
|
for i in range(num_layers):
|
||||||
|
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
||||||
|
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
||||||
|
|
||||||
|
resnets.append(
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=resnet_in_channels + res_skip_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
temb_channels=temb_channels,
|
||||||
|
eps=resnet_eps,
|
||||||
|
groups=resnet_groups,
|
||||||
|
dropout=dropout,
|
||||||
|
time_embedding_norm=resnet_time_scale_shift,
|
||||||
|
non_linearity=resnet_act_fn,
|
||||||
|
output_scale_factor=1.0,
|
||||||
|
pre_norm=resnet_pre_norm,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.resnets = nn.ModuleList(resnets)
|
||||||
|
|
||||||
|
if upsample:
|
||||||
|
self.upsamplers = nn.ModuleList([
|
||||||
|
Upsample2D(out_channels, out_channels)
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
self.upsamplers = None
|
||||||
|
|
||||||
|
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None):
|
||||||
|
for resnet in self.resnets:
|
||||||
|
res_hidden_states = res_hidden_states_tuple[-1]
|
||||||
|
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||||
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||||
|
hidden_states = resnet(hidden_states, temb)
|
||||||
|
|
||||||
|
if self.upsamplers is not None:
|
||||||
|
for upsampler in self.upsamplers:
|
||||||
|
hidden_states = upsampler(hidden_states, upsample_size=upsample_size)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
# ===== UNet Mid Block =====
|
||||||
|
|
||||||
|
class UNetMidBlock2DCrossAttn(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
temb_channels=1280,
|
||||||
|
dropout=0.0,
|
||||||
|
num_layers=1,
|
||||||
|
transformer_layers_per_block=1,
|
||||||
|
resnet_eps=1e-6,
|
||||||
|
resnet_time_scale_shift="default",
|
||||||
|
resnet_act_fn="swish",
|
||||||
|
resnet_groups=32,
|
||||||
|
resnet_pre_norm=True,
|
||||||
|
cross_attention_dim=768,
|
||||||
|
attention_head_dim=1,
|
||||||
|
use_linear_projection=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||||
|
|
||||||
|
resnets = [
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=in_channels,
|
||||||
|
temb_channels=temb_channels,
|
||||||
|
eps=resnet_eps,
|
||||||
|
groups=resnet_groups,
|
||||||
|
dropout=dropout,
|
||||||
|
time_embedding_norm=resnet_time_scale_shift,
|
||||||
|
non_linearity=resnet_act_fn,
|
||||||
|
output_scale_factor=1.0,
|
||||||
|
pre_norm=resnet_pre_norm,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
attentions = []
|
||||||
|
|
||||||
|
for _ in range(num_layers):
|
||||||
|
attentions.append(
|
||||||
|
Transformer2DModel(
|
||||||
|
num_attention_heads=attention_head_dim,
|
||||||
|
attention_head_dim=in_channels // attention_head_dim,
|
||||||
|
in_channels=in_channels,
|
||||||
|
num_layers=transformer_layers_per_block,
|
||||||
|
dropout=dropout,
|
||||||
|
norm_num_groups=resnet_groups,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
use_linear_projection=use_linear_projection,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
resnets.append(
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=in_channels,
|
||||||
|
temb_channels=temb_channels,
|
||||||
|
eps=resnet_eps,
|
||||||
|
groups=resnet_groups,
|
||||||
|
dropout=dropout,
|
||||||
|
time_embedding_norm=resnet_time_scale_shift,
|
||||||
|
non_linearity=resnet_act_fn,
|
||||||
|
output_scale_factor=1.0,
|
||||||
|
pre_norm=resnet_pre_norm,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.attentions = nn.ModuleList(attentions)
|
||||||
|
self.resnets = nn.ModuleList(resnets)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
||||||
|
hidden_states = self.resnets[0](hidden_states, temb)
|
||||||
|
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||||
|
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||||
|
hidden_states = resnet(hidden_states, temb)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
# ===== Downsample / Upsample =====
|
||||||
|
|
||||||
|
class Downsample2D(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, padding=1):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=padding)
|
||||||
|
self.padding = padding
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
if self.padding == 0:
|
||||||
|
hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
|
||||||
|
return self.conv(hidden_states)
|
||||||
|
|
||||||
|
|
||||||
|
class Upsample2D(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, upsample_size=None):
|
||||||
|
if upsample_size is not None:
|
||||||
|
hidden_states = F.interpolate(hidden_states, size=upsample_size, mode="nearest")
|
||||||
|
else:
|
||||||
|
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
|
||||||
|
return self.conv(hidden_states)
|
||||||
|
|
||||||
|
|
||||||
|
# ===== SDXL UNet2DConditionModel =====
|
||||||
|
|
||||||
|
class SDXLUNet2DConditionModel(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
sample_size=128,
|
||||||
|
in_channels=4,
|
||||||
|
out_channels=4,
|
||||||
|
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"),
|
||||||
|
up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"),
|
||||||
|
block_out_channels=(320, 640, 1280),
|
||||||
|
layers_per_block=2,
|
||||||
|
cross_attention_dim=2048,
|
||||||
|
attention_head_dim=5,
|
||||||
|
transformer_layers_per_block=1,
|
||||||
|
norm_num_groups=32,
|
||||||
|
norm_eps=1e-5,
|
||||||
|
dropout=0.0,
|
||||||
|
act_fn="silu",
|
||||||
|
time_embedding_type="positional",
|
||||||
|
flip_sin_to_cos=True,
|
||||||
|
freq_shift=0,
|
||||||
|
time_embedding_dim=None,
|
||||||
|
resnet_time_scale_shift="default",
|
||||||
|
upcast_attention=False,
|
||||||
|
use_linear_projection=False,
|
||||||
|
addition_embed_type=None,
|
||||||
|
addition_time_embed_dim=None,
|
||||||
|
projection_class_embeddings_input_dim=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.sample_size = sample_size
|
||||||
|
self.addition_embed_type = addition_embed_type
|
||||||
|
|
||||||
|
if isinstance(attention_head_dim, int):
|
||||||
|
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
||||||
|
if isinstance(transformer_layers_per_block, int):
|
||||||
|
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
||||||
|
|
||||||
|
timestep_embedding_dim = time_embedding_dim or block_out_channels[0]
|
||||||
|
self.time_proj = Timesteps(timestep_embedding_dim, flip_sin_to_cos=flip_sin_to_cos, freq_shift=freq_shift)
|
||||||
|
time_embed_dim = block_out_channels[0] * 4
|
||||||
|
self.time_embedding = TimestepEmbedding(timestep_embedding_dim, time_embed_dim)
|
||||||
|
|
||||||
|
if addition_embed_type == "text_time":
|
||||||
|
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=flip_sin_to_cos, freq_shift=freq_shift)
|
||||||
|
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
||||||
|
|
||||||
|
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1)
|
||||||
|
|
||||||
|
self.down_blocks = nn.ModuleList()
|
||||||
|
output_channel = block_out_channels[0]
|
||||||
|
for i, down_block_type in enumerate(down_block_types):
|
||||||
|
input_channel = output_channel
|
||||||
|
output_channel = block_out_channels[i]
|
||||||
|
is_final_block = i == len(block_out_channels) - 1
|
||||||
|
|
||||||
|
if "CrossAttn" in down_block_type:
|
||||||
|
down_block = CrossAttnDownBlock2D(
|
||||||
|
in_channels=input_channel,
|
||||||
|
out_channels=output_channel,
|
||||||
|
temb_channels=time_embed_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
num_layers=layers_per_block,
|
||||||
|
transformer_layers_per_block=transformer_layers_per_block[i],
|
||||||
|
resnet_eps=norm_eps,
|
||||||
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
attention_head_dim=attention_head_dim[i],
|
||||||
|
downsample=not is_final_block,
|
||||||
|
use_linear_projection=use_linear_projection,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
down_block = DownBlock2D(
|
||||||
|
in_channels=input_channel,
|
||||||
|
out_channels=output_channel,
|
||||||
|
temb_channels=time_embed_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
num_layers=layers_per_block,
|
||||||
|
resnet_eps=norm_eps,
|
||||||
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
downsample=not is_final_block,
|
||||||
|
)
|
||||||
|
self.down_blocks.append(down_block)
|
||||||
|
|
||||||
|
self.mid_block = UNetMidBlock2DCrossAttn(
|
||||||
|
in_channels=block_out_channels[-1],
|
||||||
|
temb_channels=time_embed_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
num_layers=1,
|
||||||
|
transformer_layers_per_block=transformer_layers_per_block[-1],
|
||||||
|
resnet_eps=norm_eps,
|
||||||
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
attention_head_dim=attention_head_dim[-1],
|
||||||
|
use_linear_projection=use_linear_projection,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.up_blocks = nn.ModuleList()
|
||||||
|
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||||
|
reversed_attention_head_dim = list(reversed(attention_head_dim))
|
||||||
|
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
|
||||||
|
output_channel = reversed_block_out_channels[0]
|
||||||
|
|
||||||
|
for i, up_block_type in enumerate(up_block_types):
|
||||||
|
prev_output_channel = output_channel
|
||||||
|
output_channel = reversed_block_out_channels[i]
|
||||||
|
is_final_block = i == len(block_out_channels) - 1
|
||||||
|
|
||||||
|
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
||||||
|
|
||||||
|
if "CrossAttn" in up_block_type:
|
||||||
|
up_block = CrossAttnUpBlock2D(
|
||||||
|
in_channels=input_channel,
|
||||||
|
out_channels=output_channel,
|
||||||
|
prev_output_channel=prev_output_channel,
|
||||||
|
temb_channels=time_embed_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
num_layers=layers_per_block + 1,
|
||||||
|
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
||||||
|
resnet_eps=norm_eps,
|
||||||
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
attention_head_dim=reversed_attention_head_dim[i],
|
||||||
|
upsample=not is_final_block,
|
||||||
|
use_linear_projection=use_linear_projection,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
up_block = UpBlock2D(
|
||||||
|
in_channels=input_channel,
|
||||||
|
out_channels=output_channel,
|
||||||
|
prev_output_channel=prev_output_channel,
|
||||||
|
temb_channels=time_embed_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
num_layers=layers_per_block + 1,
|
||||||
|
resnet_eps=norm_eps,
|
||||||
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
upsample=not is_final_block,
|
||||||
|
)
|
||||||
|
self.up_blocks.append(up_block)
|
||||||
|
|
||||||
|
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
|
||||||
|
self.conv_act = nn.SiLU()
|
||||||
|
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
|
||||||
|
|
||||||
|
def forward(self, sample, timestep, encoder_hidden_states, cross_attention_kwargs=None, timestep_cond=None, added_cond_kwargs=None, return_dict=True):
|
||||||
|
timesteps = timestep
|
||||||
|
if not torch.is_tensor(timesteps):
|
||||||
|
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||||
|
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||||
|
timesteps = timesteps[None].to(sample.device)
|
||||||
|
|
||||||
|
t_emb = self.time_proj(timesteps)
|
||||||
|
t_emb = t_emb.to(dtype=sample.dtype)
|
||||||
|
emb = self.time_embedding(t_emb)
|
||||||
|
|
||||||
|
if self.addition_embed_type == "text_time":
|
||||||
|
text_embeds = added_cond_kwargs.get("text_embeds")
|
||||||
|
time_ids = added_cond_kwargs.get("time_ids")
|
||||||
|
time_embeds = self.add_time_proj(time_ids.flatten())
|
||||||
|
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
||||||
|
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
||||||
|
add_embeds = add_embeds.to(emb.dtype)
|
||||||
|
aug_emb = self.add_embedding(add_embeds)
|
||||||
|
emb = emb + aug_emb
|
||||||
|
|
||||||
|
sample = self.conv_in(sample)
|
||||||
|
|
||||||
|
down_block_res_samples = (sample,)
|
||||||
|
for down_block in self.down_blocks:
|
||||||
|
sample, res_samples = down_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
)
|
||||||
|
down_block_res_samples += res_samples
|
||||||
|
|
||||||
|
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
|
||||||
|
|
||||||
|
for up_block in self.up_blocks:
|
||||||
|
res_samples = down_block_res_samples[-len(up_block.resnets):]
|
||||||
|
down_block_res_samples = down_block_res_samples[:-len(up_block.resnets)]
|
||||||
|
|
||||||
|
upsample_size = down_block_res_samples[-1].shape[2:] if down_block_res_samples else None
|
||||||
|
sample = up_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
res_hidden_states_tuple=res_samples,
|
||||||
|
upsample_size=upsample_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
sample = self.conv_norm_out(sample)
|
||||||
|
sample = self.conv_act(sample)
|
||||||
|
sample = self.conv_out(sample)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (sample,)
|
||||||
|
return sample
|
||||||
@@ -74,7 +74,7 @@ class AnimaImagePipeline(BasePipeline):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
# Prompt
|
# Prompt
|
||||||
prompt: str = "",
|
prompt: str,
|
||||||
negative_prompt: str = "",
|
negative_prompt: str = "",
|
||||||
cfg_scale: float = 4.0,
|
cfg_scale: float = 4.0,
|
||||||
# Image
|
# Image
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ class Flux2ImagePipeline(BasePipeline):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
# Prompt
|
# Prompt
|
||||||
prompt: str = "",
|
prompt: str,
|
||||||
negative_prompt: str = "",
|
negative_prompt: str = "",
|
||||||
cfg_scale: float = 1.0,
|
cfg_scale: float = 1.0,
|
||||||
embedded_guidance: float = 4.0,
|
embedded_guidance: float = 4.0,
|
||||||
@@ -83,7 +83,7 @@ class Flux2ImagePipeline(BasePipeline):
|
|||||||
input_image: Image.Image = None,
|
input_image: Image.Image = None,
|
||||||
denoising_strength: float = 1.0,
|
denoising_strength: float = 1.0,
|
||||||
# Edit
|
# Edit
|
||||||
edit_image: List[Image.Image] = None,
|
edit_image: Union[Image.Image, List[Image.Image]] = None,
|
||||||
edit_image_auto_resize: bool = True,
|
edit_image_auto_resize: bool = True,
|
||||||
# Shape
|
# Shape
|
||||||
height: int = 1024,
|
height: int = 1024,
|
||||||
|
|||||||
@@ -181,7 +181,7 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
# Prompt
|
# Prompt
|
||||||
prompt: str = "",
|
prompt: str,
|
||||||
negative_prompt: str = "",
|
negative_prompt: str = "",
|
||||||
cfg_scale: float = 1.0,
|
cfg_scale: float = 1.0,
|
||||||
embedded_guidance: float = 3.5,
|
embedded_guidance: float = 3.5,
|
||||||
@@ -199,6 +199,10 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
sigma_shift: float = None,
|
sigma_shift: float = None,
|
||||||
# Steps
|
# Steps
|
||||||
num_inference_steps: int = 30,
|
num_inference_steps: int = 30,
|
||||||
|
# local prompts
|
||||||
|
multidiffusion_prompts=(),
|
||||||
|
multidiffusion_masks=(),
|
||||||
|
multidiffusion_scales=(),
|
||||||
# Kontext
|
# Kontext
|
||||||
kontext_images: Union[list[Image.Image], Image.Image] = None,
|
kontext_images: Union[list[Image.Image], Image.Image] = None,
|
||||||
# ControlNet
|
# ControlNet
|
||||||
@@ -253,6 +257,7 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
"height": height, "width": width,
|
"height": height, "width": width,
|
||||||
"seed": seed, "rand_device": rand_device,
|
"seed": seed, "rand_device": rand_device,
|
||||||
"sigma_shift": sigma_shift, "num_inference_steps": num_inference_steps,
|
"sigma_shift": sigma_shift, "num_inference_steps": num_inference_steps,
|
||||||
|
"multidiffusion_prompts": multidiffusion_prompts, "multidiffusion_masks": multidiffusion_masks, "multidiffusion_scales": multidiffusion_scales,
|
||||||
"kontext_images": kontext_images,
|
"kontext_images": kontext_images,
|
||||||
"controlnet_inputs": controlnet_inputs,
|
"controlnet_inputs": controlnet_inputs,
|
||||||
"ipadapter_images": ipadapter_images, "ipadapter_scale": ipadapter_scale,
|
"ipadapter_images": ipadapter_images, "ipadapter_scale": ipadapter_scale,
|
||||||
|
|||||||
@@ -169,46 +169,46 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
# Prompt
|
# Prompt
|
||||||
prompt: str = "",
|
prompt: str,
|
||||||
negative_prompt: str = "",
|
negative_prompt: Optional[str] = "",
|
||||||
denoising_strength: float = 1.0,
|
denoising_strength: float = 1.0,
|
||||||
# Image-to-video
|
# Image-to-video
|
||||||
input_images: list[Image.Image] = None,
|
input_images: Optional[list[Image.Image]] = None,
|
||||||
input_images_indexes: list[int] = [0],
|
input_images_indexes: Optional[list[int]] = [0],
|
||||||
input_images_strength: float = 1.0,
|
input_images_strength: Optional[float] = 1.0,
|
||||||
# In-Context Video Control
|
# In-Context Video Control
|
||||||
in_context_videos: list[list[Image.Image]] = None,
|
in_context_videos: Optional[list[list[Image.Image]]] = None,
|
||||||
in_context_downsample_factor: int = 2,
|
in_context_downsample_factor: Optional[int] = 2,
|
||||||
# Video-to-video
|
# Video-to-video
|
||||||
retake_video: list[Image.Image] = None,
|
retake_video: Optional[list[Image.Image]] = None,
|
||||||
retake_video_regions: list[tuple[float, float]] = None,
|
retake_video_regions: Optional[list[tuple[float, float]]] = None,
|
||||||
# Audio-to-video
|
# Audio-to-video
|
||||||
retake_audio: torch.Tensor = None,
|
retake_audio: Optional[torch.Tensor] = None,
|
||||||
audio_sample_rate: int = 48000,
|
audio_sample_rate: Optional[int] = 48000,
|
||||||
retake_audio_regions: list[tuple[float, float]] = None,
|
retake_audio_regions: Optional[list[tuple[float, float]]] = None,
|
||||||
# Randomness
|
# Randomness
|
||||||
seed: int = None,
|
seed: Optional[int] = None,
|
||||||
rand_device: str = "cpu",
|
rand_device: Optional[str] = "cpu",
|
||||||
# Shape
|
# Shape
|
||||||
height: int = 512,
|
height: Optional[int] = 512,
|
||||||
width: int = 768,
|
width: Optional[int] = 768,
|
||||||
num_frames: int = 121,
|
num_frames: Optional[int] = 121,
|
||||||
frame_rate: int = 24,
|
frame_rate: Optional[int] = 24,
|
||||||
# Classifier-free guidance
|
# Classifier-free guidance
|
||||||
cfg_scale: float = 3.0,
|
cfg_scale: Optional[float] = 3.0,
|
||||||
# Scheduler
|
# Scheduler
|
||||||
num_inference_steps: int = 30,
|
num_inference_steps: Optional[int] = 30,
|
||||||
# VAE tiling
|
# VAE tiling
|
||||||
tiled: bool = True,
|
tiled: Optional[bool] = True,
|
||||||
tile_size_in_pixels: int = 512,
|
tile_size_in_pixels: Optional[int] = 512,
|
||||||
tile_overlap_in_pixels: int = 128,
|
tile_overlap_in_pixels: Optional[int] = 128,
|
||||||
tile_size_in_frames: int = 128,
|
tile_size_in_frames: Optional[int] = 128,
|
||||||
tile_overlap_in_frames: int = 24,
|
tile_overlap_in_frames: Optional[int] = 24,
|
||||||
# Special Pipelines
|
# Special Pipelines
|
||||||
use_two_stage_pipeline: bool = False,
|
use_two_stage_pipeline: Optional[bool] = False,
|
||||||
stage2_spatial_upsample_factor: int = 2,
|
stage2_spatial_upsample_factor: Optional[int] = 2,
|
||||||
clear_lora_before_state_two: bool = False,
|
clear_lora_before_state_two: Optional[bool] = False,
|
||||||
use_distilled_pipeline: bool = False,
|
use_distilled_pipeline: Optional[bool] = False,
|
||||||
# progress_bar
|
# progress_bar
|
||||||
progress_bar_cmd=tqdm,
|
progress_bar_cmd=tqdm,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -115,33 +115,33 @@ class MovaAudioVideoPipeline(BasePipeline):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
# Prompt
|
# Prompt
|
||||||
prompt: str = "",
|
prompt: str,
|
||||||
negative_prompt: str = "",
|
negative_prompt: Optional[str] = "",
|
||||||
# Image-to-video
|
# Image-to-video
|
||||||
input_image: Image.Image = None,
|
input_image: Optional[Image.Image] = None,
|
||||||
# First-last-frame-to-video
|
# First-last-frame-to-video
|
||||||
end_image: Image.Image = None,
|
end_image: Optional[Image.Image] = None,
|
||||||
# Video-to-video
|
# Video-to-video
|
||||||
denoising_strength: float = 1.0,
|
denoising_strength: Optional[float] = 1.0,
|
||||||
# Randomness
|
# Randomness
|
||||||
seed: int = None,
|
seed: Optional[int] = None,
|
||||||
rand_device: str = "cpu",
|
rand_device: Optional[str] = "cpu",
|
||||||
# Shape
|
# Shape
|
||||||
height: int = 352,
|
height: Optional[int] = 352,
|
||||||
width: int = 640,
|
width: Optional[int] = 640,
|
||||||
num_frames: int = 81,
|
num_frames: Optional[int] = 81,
|
||||||
frame_rate: int = 24,
|
frame_rate: Optional[int] = 24,
|
||||||
# Classifier-free guidance
|
# Classifier-free guidance
|
||||||
cfg_scale: float = 5.0,
|
cfg_scale: Optional[float] = 5.0,
|
||||||
# Boundary
|
# Boundary
|
||||||
switch_DiT_boundary: float = 0.9,
|
switch_DiT_boundary: Optional[float] = 0.9,
|
||||||
# Scheduler
|
# Scheduler
|
||||||
num_inference_steps: int = 50,
|
num_inference_steps: Optional[int] = 50,
|
||||||
sigma_shift: float = 5.0,
|
sigma_shift: Optional[float] = 5.0,
|
||||||
# VAE tiling
|
# VAE tiling
|
||||||
tiled: bool = True,
|
tiled: Optional[bool] = True,
|
||||||
tile_size: tuple[int, int] = (30, 52),
|
tile_size: Optional[tuple[int, int]] = (30, 52),
|
||||||
tile_stride: tuple[int, int] = (15, 26),
|
tile_stride: Optional[tuple[int, int]] = (15, 26),
|
||||||
# progress_bar
|
# progress_bar
|
||||||
progress_bar_cmd=tqdm,
|
progress_bar_cmd=tqdm,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ class QwenImagePipeline(BasePipeline):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
# Prompt
|
# Prompt
|
||||||
prompt: str = "",
|
prompt: str,
|
||||||
negative_prompt: str = "",
|
negative_prompt: str = "",
|
||||||
cfg_scale: float = 4.0,
|
cfg_scale: float = 4.0,
|
||||||
# Image
|
# Image
|
||||||
|
|||||||
230
diffsynth/pipelines/stable_diffusion.py
Normal file
230
diffsynth/pipelines/stable_diffusion.py
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from ..core.device.npu_compatible_device import get_device_type
|
||||||
|
from ..diffusion.ddim_scheduler import DDIMScheduler
|
||||||
|
from ..core import ModelConfig
|
||||||
|
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer, CLIPTextModel
|
||||||
|
from ..models.stable_diffusion_text_encoder import SDTextEncoder
|
||||||
|
from ..models.stable_diffusion_unet import UNet2DConditionModel
|
||||||
|
from ..models.stable_diffusion_vae import StableDiffusionVAE
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusionPipeline(BasePipeline):
|
||||||
|
|
||||||
|
def __init__(self, device=get_device_type(), torch_dtype=torch.float16):
|
||||||
|
super().__init__(
|
||||||
|
device=device, torch_dtype=torch_dtype,
|
||||||
|
height_division_factor=8, width_division_factor=8,
|
||||||
|
)
|
||||||
|
self.scheduler = DDIMScheduler()
|
||||||
|
self.text_encoder: SDTextEncoder = None
|
||||||
|
self.unet: UNet2DConditionModel = None
|
||||||
|
self.vae: StableDiffusionVAE = None
|
||||||
|
self.tokenizer: AutoTokenizer = None
|
||||||
|
|
||||||
|
self.in_iteration_models = ("unet",)
|
||||||
|
self.units = [
|
||||||
|
SDUnit_ShapeChecker(),
|
||||||
|
SDUnit_PromptEmbedder(),
|
||||||
|
SDUnit_NoiseInitializer(),
|
||||||
|
SDUnit_InputImageEmbedder(),
|
||||||
|
]
|
||||||
|
self.model_fn = model_fn_stable_diffusion
|
||||||
|
self.compilable_models = ["unet"]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_pretrained(
|
||||||
|
torch_dtype: torch.dtype = torch.float16,
|
||||||
|
device: Union[str, torch.device] = get_device_type(),
|
||||||
|
model_configs: list[ModelConfig] = [],
|
||||||
|
tokenizer_config: ModelConfig = None,
|
||||||
|
vram_limit: float = None,
|
||||||
|
):
|
||||||
|
pipe = StableDiffusionPipeline(device=device, torch_dtype=torch_dtype)
|
||||||
|
# Override vram_config to use the specified torch_dtype for all models
|
||||||
|
for mc in model_configs:
|
||||||
|
mc._vram_config_override = {
|
||||||
|
'onload_dtype': torch_dtype,
|
||||||
|
'computation_dtype': torch_dtype,
|
||||||
|
}
|
||||||
|
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
||||||
|
pipe.text_encoder = model_pool.fetch_model("stable_diffusion_text_encoder")
|
||||||
|
pipe.unet = model_pool.fetch_model("stable_diffusion_unet")
|
||||||
|
pipe.vae = model_pool.fetch_model("stable_diffusion_vae")
|
||||||
|
if tokenizer_config is not None:
|
||||||
|
tokenizer_config.download_if_necessary()
|
||||||
|
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
|
||||||
|
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||||
|
return pipe
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
negative_prompt: str = "",
|
||||||
|
cfg_scale: float = 7.5,
|
||||||
|
height: int = 512,
|
||||||
|
width: int = 512,
|
||||||
|
seed: int = None,
|
||||||
|
rand_device: str = "cpu",
|
||||||
|
num_inference_steps: int = 50,
|
||||||
|
eta: float = 0.0,
|
||||||
|
guidance_rescale: float = 0.0,
|
||||||
|
progress_bar_cmd=tqdm,
|
||||||
|
):
|
||||||
|
# 1. Scheduler
|
||||||
|
self.scheduler.set_timesteps(
|
||||||
|
num_inference_steps, eta=eta,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Three-dict input preparation
|
||||||
|
inputs_posi = {"prompt": prompt}
|
||||||
|
inputs_nega = {"negative_prompt": negative_prompt}
|
||||||
|
inputs_shared = {
|
||||||
|
"cfg_scale": cfg_scale,
|
||||||
|
"height": height, "width": width,
|
||||||
|
"seed": seed, "rand_device": rand_device,
|
||||||
|
"guidance_rescale": guidance_rescale,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 3. Unit chain execution
|
||||||
|
for unit in self.units:
|
||||||
|
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
|
||||||
|
unit, self, inputs_shared, inputs_posi, inputs_nega
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Denoise loop
|
||||||
|
self.load_models_to_device(self.in_iteration_models)
|
||||||
|
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||||
|
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||||
|
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
noise_pred = self.cfg_guided_model_fn(
|
||||||
|
self.model_fn, cfg_scale,
|
||||||
|
inputs_shared, inputs_posi, inputs_nega,
|
||||||
|
**models, timestep=timestep, progress_id=progress_id
|
||||||
|
)
|
||||||
|
inputs_shared["latents"] = self.step(
|
||||||
|
self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared
|
||||||
|
)
|
||||||
|
|
||||||
|
# 5. VAE decode
|
||||||
|
self.load_models_to_device(['vae'])
|
||||||
|
latents = inputs_shared["latents"] / self.vae.scaling_factor
|
||||||
|
image = self.vae.decode(latents)
|
||||||
|
image = self.vae_output_to_image(image)
|
||||||
|
self.load_models_to_device([])
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
class SDUnit_ShapeChecker(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
input_params=("height", "width"),
|
||||||
|
output_params=("height", "width"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, pipe: StableDiffusionPipeline, height, width):
|
||||||
|
height, width = pipe.check_resize_height_width(height, width)
|
||||||
|
return {"height": height, "width": width}
|
||||||
|
|
||||||
|
|
||||||
|
class SDUnit_PromptEmbedder(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
seperate_cfg=True,
|
||||||
|
input_params_posi={"prompt": "prompt"},
|
||||||
|
input_params_nega={"prompt": "negative_prompt"},
|
||||||
|
output_params=("prompt_embeds",),
|
||||||
|
onload_model_names=("text_encoder",)
|
||||||
|
)
|
||||||
|
|
||||||
|
def encode_prompt(
|
||||||
|
self,
|
||||||
|
pipe: StableDiffusionPipeline,
|
||||||
|
prompt: str,
|
||||||
|
device: torch.device,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
text_inputs = pipe.tokenizer(
|
||||||
|
prompt,
|
||||||
|
padding="max_length",
|
||||||
|
max_length=pipe.tokenizer.model_max_length,
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
text_input_ids = text_inputs.input_ids.to(device)
|
||||||
|
prompt_embeds = pipe.text_encoder(text_input_ids)
|
||||||
|
# TextEncoder returns (last_hidden_state, hidden_states) or just last_hidden_state.
|
||||||
|
# last_hidden_state is the post-final-layer-norm output, matching diffusers encode_prompt.
|
||||||
|
if isinstance(prompt_embeds, tuple):
|
||||||
|
prompt_embeds = prompt_embeds[0]
|
||||||
|
return prompt_embeds
|
||||||
|
|
||||||
|
def process(self, pipe: StableDiffusionPipeline, prompt):
|
||||||
|
pipe.load_models_to_device(self.onload_model_names)
|
||||||
|
prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device)
|
||||||
|
return {"prompt_embeds": prompt_embeds}
|
||||||
|
|
||||||
|
|
||||||
|
class SDUnit_NoiseInitializer(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
input_params=("height", "width", "seed", "rand_device"),
|
||||||
|
output_params=("noise",),
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, pipe: StableDiffusionPipeline, height, width, seed, rand_device):
|
||||||
|
noise = pipe.generate_noise(
|
||||||
|
(1, pipe.unet.in_channels, height // 8, width // 8),
|
||||||
|
seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype
|
||||||
|
)
|
||||||
|
return {"noise": noise}
|
||||||
|
|
||||||
|
|
||||||
|
class SDUnit_InputImageEmbedder(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
input_params=("input_image", "noise"),
|
||||||
|
output_params=("latents", "input_latents"),
|
||||||
|
onload_model_names=("vae",),
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, pipe: StableDiffusionPipeline, input_image, noise):
|
||||||
|
if input_image is None:
|
||||||
|
return {"latents": noise}
|
||||||
|
pipe.load_models_to_device(self.onload_model_names)
|
||||||
|
input_tensor = pipe.preprocess_image(input_image)
|
||||||
|
input_latents = pipe.vae.encode(input_tensor).sample() * pipe.vae.scaling_factor
|
||||||
|
latents = pipe.scheduler.add_noise(input_latents, noise, pipe.scheduler.timesteps[0])
|
||||||
|
if pipe.scheduler.training:
|
||||||
|
return {"latents": latents, "input_latents": input_latents}
|
||||||
|
else:
|
||||||
|
return {"latents": latents}
|
||||||
|
|
||||||
|
|
||||||
|
def model_fn_stable_diffusion(
|
||||||
|
unet: UNet2DConditionModel,
|
||||||
|
latents=None,
|
||||||
|
timestep=None,
|
||||||
|
prompt_embeds=None,
|
||||||
|
cross_attention_kwargs=None,
|
||||||
|
timestep_cond=None,
|
||||||
|
added_cond_kwargs=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# SD timestep is already in 0-999 range, no scaling needed
|
||||||
|
noise_pred = unet(
|
||||||
|
latents,
|
||||||
|
timestep,
|
||||||
|
encoder_hidden_states=prompt_embeds,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
timestep_cond=timestep_cond,
|
||||||
|
added_cond_kwargs=added_cond_kwargs,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
return noise_pred
|
||||||
331
diffsynth/pipelines/stable_diffusion_xl.py
Normal file
331
diffsynth/pipelines/stable_diffusion_xl.py
Normal file
@@ -0,0 +1,331 @@
|
|||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from ..core.device.npu_compatible_device import get_device_type
|
||||||
|
from ..diffusion.ddim_scheduler import DDIMScheduler
|
||||||
|
from ..core import ModelConfig
|
||||||
|
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer, CLIPTextModel
|
||||||
|
from ..models.stable_diffusion_text_encoder import SDTextEncoder
|
||||||
|
from ..models.stable_diffusion_xl_unet import SDXLUNet2DConditionModel
|
||||||
|
from ..models.stable_diffusion_xl_text_encoder import SDXLTextEncoder2
|
||||||
|
from ..models.stable_diffusion_vae import StableDiffusionVAE
|
||||||
|
|
||||||
|
|
||||||
|
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||||
|
"""Rescale noise_cfg based on guidance_rescale to prevent overexposure.
|
||||||
|
|
||||||
|
Based on Section 3.4 from "Common Diffusion Noise Schedules and Sample Steps are Flawed"
|
||||||
|
https://huggingface.co/papers/2305.08891
|
||||||
|
"""
|
||||||
|
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
||||||
|
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
||||||
|
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
||||||
|
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
||||||
|
return noise_cfg
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusionXLPipeline(BasePipeline):
|
||||||
|
|
||||||
|
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
|
||||||
|
super().__init__(
|
||||||
|
device=device, torch_dtype=torch_dtype,
|
||||||
|
height_division_factor=8, width_division_factor=8,
|
||||||
|
)
|
||||||
|
self.scheduler = DDIMScheduler()
|
||||||
|
self.text_encoder: SDTextEncoder = None
|
||||||
|
self.text_encoder_2: SDXLTextEncoder2 = None
|
||||||
|
self.unet: SDXLUNet2DConditionModel = None
|
||||||
|
self.vae: StableDiffusionVAE = None
|
||||||
|
self.tokenizer: AutoTokenizer = None
|
||||||
|
self.tokenizer_2: AutoTokenizer = None
|
||||||
|
|
||||||
|
self.in_iteration_models = ("unet",)
|
||||||
|
self.units = [
|
||||||
|
SDXLUnit_ShapeChecker(),
|
||||||
|
SDXLUnit_PromptEmbedder(),
|
||||||
|
SDXLUnit_NoiseInitializer(),
|
||||||
|
SDXLUnit_InputImageEmbedder(),
|
||||||
|
SDXLUnit_AddTimeIdsComputer(),
|
||||||
|
]
|
||||||
|
self.model_fn = model_fn_stable_diffusion_xl
|
||||||
|
self.compilable_models = ["unet"]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_pretrained(
|
||||||
|
torch_dtype: torch.dtype = torch.bfloat16,
|
||||||
|
device: Union[str, torch.device] = get_device_type(),
|
||||||
|
model_configs: list[ModelConfig] = [],
|
||||||
|
tokenizer_config: ModelConfig = None,
|
||||||
|
tokenizer_2_config: ModelConfig = None,
|
||||||
|
vram_limit: float = None,
|
||||||
|
):
|
||||||
|
pipe = StableDiffusionXLPipeline(device=device, torch_dtype=torch_dtype)
|
||||||
|
# Override vram_config to use the specified torch_dtype for all models
|
||||||
|
for mc in model_configs:
|
||||||
|
mc._vram_config_override = {
|
||||||
|
'onload_dtype': torch_dtype,
|
||||||
|
'computation_dtype': torch_dtype,
|
||||||
|
}
|
||||||
|
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
||||||
|
pipe.text_encoder = model_pool.fetch_model("stable_diffusion_text_encoder")
|
||||||
|
pipe.text_encoder_2 = model_pool.fetch_model("stable_diffusion_xl_text_encoder")
|
||||||
|
pipe.unet = model_pool.fetch_model("stable_diffusion_xl_unet")
|
||||||
|
pipe.vae = model_pool.fetch_model("stable_diffusion_xl_vae")
|
||||||
|
if tokenizer_config is not None:
|
||||||
|
tokenizer_config.download_if_necessary()
|
||||||
|
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
|
||||||
|
if tokenizer_2_config is not None:
|
||||||
|
tokenizer_2_config.download_if_necessary()
|
||||||
|
pipe.tokenizer_2 = AutoTokenizer.from_pretrained(tokenizer_2_config.path)
|
||||||
|
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||||
|
return pipe
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
negative_prompt: str = "",
|
||||||
|
cfg_scale: float = 5.0,
|
||||||
|
height: int = 1024,
|
||||||
|
width: int = 1024,
|
||||||
|
seed: int = None,
|
||||||
|
rand_device: str = "cpu",
|
||||||
|
num_inference_steps: int = 50,
|
||||||
|
guidance_rescale: float = 0.0,
|
||||||
|
progress_bar_cmd=tqdm,
|
||||||
|
):
|
||||||
|
# 1. Scheduler
|
||||||
|
self.scheduler.set_timesteps(num_inference_steps)
|
||||||
|
|
||||||
|
# 2. Three-dict input preparation
|
||||||
|
inputs_posi = {
|
||||||
|
"prompt": prompt,
|
||||||
|
}
|
||||||
|
inputs_nega = {
|
||||||
|
"prompt": negative_prompt,
|
||||||
|
}
|
||||||
|
inputs_shared = {
|
||||||
|
"cfg_scale": cfg_scale,
|
||||||
|
"height": height, "width": width,
|
||||||
|
"seed": seed, "rand_device": rand_device,
|
||||||
|
"guidance_rescale": guidance_rescale,
|
||||||
|
"crops_coords_top_left": (0, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 3. Unit chain execution
|
||||||
|
for unit in self.units:
|
||||||
|
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
|
||||||
|
unit, self, inputs_shared, inputs_posi, inputs_nega
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Denoise loop
|
||||||
|
self.load_models_to_device(self.in_iteration_models)
|
||||||
|
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||||
|
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||||
|
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
noise_pred = self.cfg_guided_model_fn(
|
||||||
|
self.model_fn, cfg_scale,
|
||||||
|
inputs_shared, inputs_posi, inputs_nega,
|
||||||
|
**models, timestep=timestep, progress_id=progress_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply guidance_rescale
|
||||||
|
if guidance_rescale > 0.0:
|
||||||
|
# cfg_guided_model_fn already applied CFG, now apply rescale
|
||||||
|
# We need the text-only prediction for rescale
|
||||||
|
noise_pred_text = self.model_fn(
|
||||||
|
self.unet,
|
||||||
|
inputs_shared["latents"],
|
||||||
|
timestep,
|
||||||
|
inputs_posi["prompt_embeds"],
|
||||||
|
pooled_prompt_embeds=inputs_posi["pooled_prompt_embeds"],
|
||||||
|
add_time_ids=inputs_posi["add_time_ids"],
|
||||||
|
)
|
||||||
|
noise_pred = rescale_noise_cfg(
|
||||||
|
noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
|
||||||
|
)
|
||||||
|
|
||||||
|
inputs_shared["latents"] = self.step(
|
||||||
|
self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6. VAE decode
|
||||||
|
self.load_models_to_device(['vae'])
|
||||||
|
latents = inputs_shared["latents"] / self.vae.scaling_factor
|
||||||
|
image = self.vae.decode(latents)
|
||||||
|
image = self.vae_output_to_image(image)
|
||||||
|
self.load_models_to_device([])
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
class SDXLUnit_ShapeChecker(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
input_params=("height", "width"),
|
||||||
|
output_params=("height", "width"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, pipe: StableDiffusionXLPipeline, height, width):
|
||||||
|
height, width = pipe.check_resize_height_width(height, width)
|
||||||
|
return {"height": height, "width": width}
|
||||||
|
|
||||||
|
|
||||||
|
class SDXLUnit_PromptEmbedder(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
seperate_cfg=True,
|
||||||
|
input_params_posi={"prompt": "prompt"},
|
||||||
|
input_params_nega={"prompt": "prompt"},
|
||||||
|
output_params=("prompt_embeds", "pooled_prompt_embeds"),
|
||||||
|
onload_model_names=("text_encoder", "text_encoder_2")
|
||||||
|
)
|
||||||
|
|
||||||
|
def encode_prompt(
|
||||||
|
self,
|
||||||
|
pipe: StableDiffusionXLPipeline,
|
||||||
|
prompt: str,
|
||||||
|
device: torch.device,
|
||||||
|
) -> tuple:
|
||||||
|
"""Encode prompt using both text encoders (same prompt for both).
|
||||||
|
|
||||||
|
Returns (prompt_embeds, pooled_prompt_embeds):
|
||||||
|
- prompt_embeds: concat(encoder1_output, encoder2_output) -> (B, 77, 2048)
|
||||||
|
- pooled_prompt_embeds: encoder2 pooled output -> (B, 1280)
|
||||||
|
"""
|
||||||
|
# Text Encoder 1 (CLIP-L, 768-dim)
|
||||||
|
text_input_ids_1 = pipe.tokenizer(
|
||||||
|
prompt,
|
||||||
|
padding="max_length",
|
||||||
|
max_length=pipe.tokenizer.model_max_length,
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
).input_ids.to(device)
|
||||||
|
prompt_embeds_1 = pipe.text_encoder(text_input_ids_1)
|
||||||
|
if isinstance(prompt_embeds_1, tuple):
|
||||||
|
prompt_embeds_1 = prompt_embeds_1[0]
|
||||||
|
|
||||||
|
# Text Encoder 2 (CLIP-bigG, 1280-dim) — uses penultimate hidden states + pooled
|
||||||
|
text_input_ids_2 = pipe.tokenizer_2(
|
||||||
|
prompt,
|
||||||
|
padding="max_length",
|
||||||
|
max_length=pipe.tokenizer_2.model_max_length,
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
).input_ids.to(device)
|
||||||
|
# SDXLTextEncoder2 forward returns (text_embeds/pooled, hidden_states_tuple)
|
||||||
|
pooled_prompt_embeds, hidden_states = pipe.text_encoder_2(text_input_ids_2, output_hidden_states=True)
|
||||||
|
# Use penultimate hidden state (same as diffusers: hidden_states[-2])
|
||||||
|
prompt_embeds_2 = hidden_states[-2]
|
||||||
|
|
||||||
|
# Concatenate both encoder outputs along feature dimension
|
||||||
|
prompt_embeds = torch.cat([prompt_embeds_1, prompt_embeds_2], dim=-1)
|
||||||
|
|
||||||
|
return prompt_embeds, pooled_prompt_embeds
|
||||||
|
|
||||||
|
def process(self, pipe: StableDiffusionXLPipeline, prompt):
|
||||||
|
pipe.load_models_to_device(self.onload_model_names)
|
||||||
|
prompt_embeds, pooled_prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device)
|
||||||
|
return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds}
|
||||||
|
|
||||||
|
|
||||||
|
class SDXLUnit_NoiseInitializer(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
input_params=("height", "width", "seed", "rand_device"),
|
||||||
|
output_params=("noise",),
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, pipe: StableDiffusionXLPipeline, height, width, seed, rand_device):
|
||||||
|
noise = pipe.generate_noise(
|
||||||
|
(1, pipe.unet.in_channels, height // 8, width // 8),
|
||||||
|
seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype
|
||||||
|
)
|
||||||
|
return {"noise": noise}
|
||||||
|
|
||||||
|
|
||||||
|
class SDXLUnit_InputImageEmbedder(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
input_params=("input_image", "noise"),
|
||||||
|
output_params=("latents", "input_latents"),
|
||||||
|
onload_model_names=("vae",),
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, pipe: StableDiffusionXLPipeline, input_image, noise):
|
||||||
|
if input_image is None:
|
||||||
|
return {"latents": noise}
|
||||||
|
pipe.load_models_to_device(self.onload_model_names)
|
||||||
|
input_tensor = pipe.preprocess_image(input_image)
|
||||||
|
input_latents = pipe.vae.encode(input_tensor).sample() * pipe.vae.scaling_factor
|
||||||
|
latents = pipe.scheduler.add_noise(input_latents, noise, pipe.scheduler.timesteps[0])
|
||||||
|
if pipe.scheduler.training:
|
||||||
|
return {"latents": latents, "input_latents": input_latents}
|
||||||
|
else:
|
||||||
|
return {"latents": latents}
|
||||||
|
|
||||||
|
|
||||||
|
class SDXLUnit_AddTimeIdsComputer(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
input_params=("height", "width"),
|
||||||
|
output_params=("add_time_ids",),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_add_time_ids(self, pipe, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim):
|
||||||
|
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||||
|
expected_add_embed_dim = pipe.unet.add_embedding.linear_1.in_features
|
||||||
|
addition_time_embed_dim = pipe.unet.add_time_proj.num_channels
|
||||||
|
passed_add_embed_dim = addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
||||||
|
if expected_add_embed_dim != passed_add_embed_dim:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, "
|
||||||
|
f"but a vector of {passed_add_embed_dim} was created."
|
||||||
|
)
|
||||||
|
add_time_ids = torch.tensor([add_time_ids], dtype=dtype, device=pipe.device)
|
||||||
|
return add_time_ids
|
||||||
|
|
||||||
|
def process(self, pipe: StableDiffusionXLPipeline, height, width):
|
||||||
|
original_size = (height, width)
|
||||||
|
target_size = (height, width)
|
||||||
|
crops_coords_top_left = (0, 0)
|
||||||
|
|
||||||
|
text_encoder_projection_dim = pipe.text_encoder_2.config.projection_dim
|
||||||
|
add_time_ids = self._get_add_time_ids(
|
||||||
|
pipe, original_size, crops_coords_top_left, target_size,
|
||||||
|
dtype=pipe.torch_dtype,
|
||||||
|
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||||
|
)
|
||||||
|
return {"add_time_ids": add_time_ids}
|
||||||
|
|
||||||
|
|
||||||
|
def model_fn_stable_diffusion_xl(
|
||||||
|
unet: SDXLUNet2DConditionModel,
|
||||||
|
latents=None,
|
||||||
|
timestep=None,
|
||||||
|
prompt_embeds=None,
|
||||||
|
pooled_prompt_embeds=None,
|
||||||
|
add_time_ids=None,
|
||||||
|
cross_attention_kwargs=None,
|
||||||
|
timestep_cond=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""SDXL model forward with added_cond_kwargs for micro-conditioning."""
|
||||||
|
added_cond_kwargs = {
|
||||||
|
"text_embeds": pooled_prompt_embeds,
|
||||||
|
"time_ids": add_time_ids,
|
||||||
|
}
|
||||||
|
noise_pred = unet(
|
||||||
|
latents,
|
||||||
|
timestep,
|
||||||
|
encoder_hidden_states=prompt_embeds,
|
||||||
|
added_cond_kwargs=added_cond_kwargs,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
timestep_cond=timestep_cond,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
return noise_pred
|
||||||
@@ -190,82 +190,82 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
# Prompt
|
# Prompt
|
||||||
prompt: str = "",
|
prompt: str,
|
||||||
negative_prompt: str = "",
|
negative_prompt: Optional[str] = "",
|
||||||
# Image-to-video
|
# Image-to-video
|
||||||
input_image: Image.Image = None,
|
input_image: Optional[Image.Image] = None,
|
||||||
# First-last-frame-to-video
|
# First-last-frame-to-video
|
||||||
end_image: Image.Image = None,
|
end_image: Optional[Image.Image] = None,
|
||||||
# Video-to-video
|
# Video-to-video
|
||||||
input_video: list[Image.Image] = None,
|
input_video: Optional[list[Image.Image]] = None,
|
||||||
denoising_strength: float = 1.0,
|
denoising_strength: Optional[float] = 1.0,
|
||||||
# Speech-to-video
|
# Speech-to-video
|
||||||
input_audio: np.array = None,
|
input_audio: Optional[np.array] = None,
|
||||||
audio_embeds: torch.Tensor = None,
|
audio_embeds: Optional[torch.Tensor] = None,
|
||||||
audio_sample_rate: int = 16000,
|
audio_sample_rate: Optional[int] = 16000,
|
||||||
s2v_pose_video: list[Image.Image] = None,
|
s2v_pose_video: Optional[list[Image.Image]] = None,
|
||||||
s2v_pose_latents: torch.Tensor = None,
|
s2v_pose_latents: Optional[torch.Tensor] = None,
|
||||||
motion_video: list[Image.Image] = None,
|
motion_video: Optional[list[Image.Image]] = None,
|
||||||
# ControlNet
|
# ControlNet
|
||||||
control_video: list[Image.Image] = None,
|
control_video: Optional[list[Image.Image]] = None,
|
||||||
reference_image: Image.Image = None,
|
reference_image: Optional[Image.Image] = None,
|
||||||
# Camera control
|
# Camera control
|
||||||
camera_control_direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"] = None,
|
camera_control_direction: Optional[Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"]] = None,
|
||||||
camera_control_speed: float = 1/54,
|
camera_control_speed: Optional[float] = 1/54,
|
||||||
camera_control_origin: tuple = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0),
|
camera_control_origin: Optional[tuple] = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0),
|
||||||
# VACE
|
# VACE
|
||||||
vace_video: list[Image.Image] = None,
|
vace_video: Optional[list[Image.Image]] = None,
|
||||||
vace_video_mask: Image.Image = None,
|
vace_video_mask: Optional[Image.Image] = None,
|
||||||
vace_reference_image: Image.Image = None,
|
vace_reference_image: Optional[Image.Image] = None,
|
||||||
vace_scale: float = 1.0,
|
vace_scale: Optional[float] = 1.0,
|
||||||
# Animate
|
# Animate
|
||||||
animate_pose_video: list[Image.Image] = None,
|
animate_pose_video: Optional[list[Image.Image]] = None,
|
||||||
animate_face_video: list[Image.Image] = None,
|
animate_face_video: Optional[list[Image.Image]] = None,
|
||||||
animate_inpaint_video: list[Image.Image] = None,
|
animate_inpaint_video: Optional[list[Image.Image]] = None,
|
||||||
animate_mask_video: list[Image.Image] = None,
|
animate_mask_video: Optional[list[Image.Image]] = None,
|
||||||
# VAP
|
# VAP
|
||||||
vap_video: list[Image.Image] = None,
|
vap_video: Optional[list[Image.Image]] = None,
|
||||||
vap_prompt: str = " ",
|
vap_prompt: Optional[str] = " ",
|
||||||
negative_vap_prompt: str = " ",
|
negative_vap_prompt: Optional[str] = " ",
|
||||||
# Randomness
|
# Randomness
|
||||||
seed: int = None,
|
seed: Optional[int] = None,
|
||||||
rand_device: str = "cpu",
|
rand_device: Optional[str] = "cpu",
|
||||||
# Shape
|
# Shape
|
||||||
height: int = 480,
|
height: Optional[int] = 480,
|
||||||
width: int = 832,
|
width: Optional[int] = 832,
|
||||||
num_frames: int = 81,
|
num_frames=81,
|
||||||
# Classifier-free guidance
|
# Classifier-free guidance
|
||||||
cfg_scale: float = 5.0,
|
cfg_scale: Optional[float] = 5.0,
|
||||||
cfg_merge: bool = False,
|
cfg_merge: Optional[bool] = False,
|
||||||
# Boundary
|
# Boundary
|
||||||
switch_DiT_boundary: float = 0.875,
|
switch_DiT_boundary: Optional[float] = 0.875,
|
||||||
# Scheduler
|
# Scheduler
|
||||||
num_inference_steps: int = 50,
|
num_inference_steps: Optional[int] = 50,
|
||||||
sigma_shift: float = 5.0,
|
sigma_shift: Optional[float] = 5.0,
|
||||||
# Speed control
|
# Speed control
|
||||||
motion_bucket_id: int = None,
|
motion_bucket_id: Optional[int] = None,
|
||||||
# LongCat-Video
|
# LongCat-Video
|
||||||
longcat_video: list[Image.Image] = None,
|
longcat_video: Optional[list[Image.Image]] = None,
|
||||||
# VAE tiling
|
# VAE tiling
|
||||||
tiled: bool = True,
|
tiled: Optional[bool] = True,
|
||||||
tile_size: tuple[int, int] = (30, 52),
|
tile_size: Optional[tuple[int, int]] = (30, 52),
|
||||||
tile_stride: tuple[int, int] = (15, 26),
|
tile_stride: Optional[tuple[int, int]] = (15, 26),
|
||||||
# Sliding window
|
# Sliding window
|
||||||
sliding_window_size: int = None,
|
sliding_window_size: Optional[int] = None,
|
||||||
sliding_window_stride: int = None,
|
sliding_window_stride: Optional[int] = None,
|
||||||
# Teacache
|
# Teacache
|
||||||
tea_cache_l1_thresh: float = None,
|
tea_cache_l1_thresh: Optional[float] = None,
|
||||||
tea_cache_model_id: str = "",
|
tea_cache_model_id: Optional[str] = "",
|
||||||
# WanToDance
|
# WanToDance
|
||||||
wantodance_music_path: str = None,
|
wantodance_music_path: Optional[str] = None,
|
||||||
wantodance_reference_image: Image.Image = None,
|
wantodance_reference_image: Optional[Image.Image] = None,
|
||||||
wantodance_fps: float = 30,
|
wantodance_fps: Optional[float] = 30,
|
||||||
wantodance_keyframes: list[Image.Image] = None,
|
wantodance_keyframes: Optional[list[Image.Image]] = None,
|
||||||
wantodance_keyframes_mask: list[int] = None,
|
wantodance_keyframes_mask: Optional[list[int]] = None,
|
||||||
framewise_decoding: bool = False,
|
framewise_decoding: bool = False,
|
||||||
# progress_bar
|
# progress_bar
|
||||||
progress_bar_cmd=tqdm,
|
progress_bar_cmd=tqdm,
|
||||||
output_type: Literal["quantized", "floatpoint"] = "quantized",
|
output_type: Optional[Literal["quantized", "floatpoint"]] = "quantized",
|
||||||
):
|
):
|
||||||
# Scheduler
|
# Scheduler
|
||||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
|
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
|
||||||
|
|||||||
@@ -95,7 +95,7 @@ class ZImagePipeline(BasePipeline):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
# Prompt
|
# Prompt
|
||||||
prompt: str = "",
|
prompt: str,
|
||||||
negative_prompt: str = "",
|
negative_prompt: str = "",
|
||||||
cfg_scale: float = 1.0,
|
cfg_scale: float = 1.0,
|
||||||
# Image
|
# Image
|
||||||
@@ -109,7 +109,7 @@ class ZImagePipeline(BasePipeline):
|
|||||||
width: int = 1024,
|
width: int = 1024,
|
||||||
# Randomness
|
# Randomness
|
||||||
seed: int = None,
|
seed: int = None,
|
||||||
rand_device: Union[str, torch.device] = "cpu",
|
rand_device: str = "cpu",
|
||||||
# Steps
|
# Steps
|
||||||
num_inference_steps: int = 8,
|
num_inference_steps: int = 8,
|
||||||
sigma_shift: float = None,
|
sigma_shift: float = None,
|
||||||
|
|||||||
@@ -1,9 +0,0 @@
|
|||||||
def DINOv3StateDictConverter(state_dict):
|
|
||||||
new_state_dict = {}
|
|
||||||
for key in state_dict:
|
|
||||||
value = state_dict[key]
|
|
||||||
if key.startswith("layer"):
|
|
||||||
new_state_dict["model." + key] = value
|
|
||||||
else:
|
|
||||||
new_state_dict[key] = value
|
|
||||||
return new_state_dict
|
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
def SDTextEncoderStateDictConverter(state_dict):
|
||||||
|
new_state_dict = {}
|
||||||
|
for key in state_dict:
|
||||||
|
if key.startswith("text_model.") and "position_ids" not in key:
|
||||||
|
new_key = "model." + key
|
||||||
|
new_state_dict[new_key] = state_dict[key]
|
||||||
|
return new_state_dict
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
def SDVAEStateDictConverter(state_dict):
|
||||||
|
new_state_dict = {}
|
||||||
|
for key in state_dict:
|
||||||
|
if ".query." in key:
|
||||||
|
new_key = key.replace(".query.", ".to_q.")
|
||||||
|
new_state_dict[new_key] = state_dict[key]
|
||||||
|
elif ".key." in key:
|
||||||
|
new_key = key.replace(".key.", ".to_k.")
|
||||||
|
new_state_dict[new_key] = state_dict[key]
|
||||||
|
elif ".value." in key:
|
||||||
|
new_key = key.replace(".value.", ".to_v.")
|
||||||
|
new_state_dict[new_key] = state_dict[key]
|
||||||
|
elif ".proj_attn." in key:
|
||||||
|
new_key = key.replace(".proj_attn.", ".to_out.0.")
|
||||||
|
new_state_dict[new_key] = state_dict[key]
|
||||||
|
else:
|
||||||
|
new_state_dict[key] = state_dict[key]
|
||||||
|
return new_state_dict
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
def SDXLTextEncoder2StateDictConverter(state_dict):
|
||||||
|
new_state_dict = {}
|
||||||
|
for key in state_dict:
|
||||||
|
if key == "text_projection.weight":
|
||||||
|
val = state_dict[key]
|
||||||
|
new_state_dict["model.text_projection.weight"] = val.float() if val.dtype == torch.float16 else val
|
||||||
|
elif key.startswith("text_model.") and "position_ids" not in key:
|
||||||
|
new_key = "model." + key
|
||||||
|
val = state_dict[key]
|
||||||
|
new_state_dict[new_key] = val.float() if val.dtype == torch.float16 else val
|
||||||
|
return new_state_dict
|
||||||
141
docs/en/Model_Details/Stable-Diffusion-XL.md
Normal file
141
docs/en/Model_Details/Stable-Diffusion-XL.md
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
# Stable Diffusion XL
|
||||||
|
|
||||||
|
Stable Diffusion XL (SDXL) is an open-source diffusion-based text-to-image generation model developed by Stability AI, supporting 1024x1024 resolution high-quality text-to-image generation with a dual text encoder (CLIP-L + CLIP-bigG) architecture.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
Before performing model inference and training, please install DiffSynth-Studio first.
|
||||||
|
|
||||||
|
```shell
|
||||||
|
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||||
|
cd DiffSynth-Studio
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
For more information on installation, please refer to [Setup Dependencies](../Pipeline_Usage/Setup.md).
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
Running the following code will quickly load the [stabilityai/stable-diffusion-xl-base-1.0](https://www.modelscope.cn/models/stabilityai/stable-diffusion-xl-base-1.0) model for inference. VRAM management is enabled, the framework automatically controls parameter loading based on available VRAM, requiring a minimum of 6GB VRAM.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from diffsynth.core import ModelConfig
|
||||||
|
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.float32,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.float32,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float32,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.float32,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
|
||||||
|
tokenizer_2_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt="a photo of an astronaut riding a horse on mars",
|
||||||
|
negative_prompt="",
|
||||||
|
cfg_scale=5.0,
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
seed=42,
|
||||||
|
num_inference_steps=50,
|
||||||
|
)
|
||||||
|
image.save("image.jpg")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Model Overview
|
||||||
|
|
||||||
|
|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|
||||||
|
|-|-|-|-|-|-|-|
|
||||||
|
|[stabilityai/stable-diffusion-xl-base-1.0](https://www.modelscope.cn/models/stabilityai/stable-diffusion-xl-base-1.0)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_inference/stable-diffusion-xl-base-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_inference_low_vram/stable-diffusion-xl-base-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/full/stable-diffusion-xl-base-1.0.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/validate_full/stable-diffusion-xl-base-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/lora/stable-diffusion-xl-base-1.0.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/validate_lora/stable-diffusion-xl-base-1.0.py)|
|
||||||
|
|
||||||
|
## Model Inference
|
||||||
|
|
||||||
|
The model is loaded via `StableDiffusionXLPipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models) for details.
|
||||||
|
|
||||||
|
The input parameters for `StableDiffusionXLPipeline` inference include:
|
||||||
|
|
||||||
|
* `prompt`: Text prompt.
|
||||||
|
* `negative_prompt`: Negative prompt, defaults to an empty string.
|
||||||
|
* `cfg_scale`: Classifier-Free Guidance scale factor, default 5.0.
|
||||||
|
* `height`: Output image height, default 1024.
|
||||||
|
* `width`: Output image width, default 1024.
|
||||||
|
* `seed`: Random seed, defaults to a random value if not set.
|
||||||
|
* `rand_device`: Noise generation device, defaults to "cpu".
|
||||||
|
* `num_inference_steps`: Number of inference steps, default 50.
|
||||||
|
* `guidance_rescale`: Guidance rescale factor, default 0.0.
|
||||||
|
* `progress_bar_cmd`: Progress bar callback function.
|
||||||
|
|
||||||
|
> `StableDiffusionXLPipeline` requires dual tokenizer configurations (`tokenizer_config` and `tokenizer_2_config`), corresponding to the CLIP-L and CLIP-bigG text encoders.
|
||||||
|
|
||||||
|
## Model Training
|
||||||
|
|
||||||
|
Models in the stable_diffusion_xl series are trained via `examples/stable_diffusion_xl/model_training/train.py`. The script parameters include:
|
||||||
|
|
||||||
|
* General Training Parameters
|
||||||
|
* Dataset Configuration
|
||||||
|
* `--dataset_base_path`: Root directory of the dataset.
|
||||||
|
* `--dataset_metadata_path`: Path to the dataset metadata file.
|
||||||
|
* `--dataset_repeat`: Number of dataset repeats per epoch.
|
||||||
|
* `--dataset_num_workers`: Number of processes per DataLoader.
|
||||||
|
* `--data_file_keys`: Field names to load from metadata, typically paths to image or video files, separated by `,`.
|
||||||
|
* Model Loading Configuration
|
||||||
|
* `--model_paths`: Paths to load models from, in JSON format.
|
||||||
|
* `--model_id_with_origin_paths`: Model IDs with original paths, separated by commas.
|
||||||
|
* `--extra_inputs`: Additional input parameters required by the model Pipeline, separated by `,`.
|
||||||
|
* `--fp8_models`: Models to load in FP8 format, currently only supported for models whose parameters are not updated by gradients.
|
||||||
|
* Basic Training Configuration
|
||||||
|
* `--learning_rate`: Learning rate.
|
||||||
|
* `--num_epochs`: Number of epochs.
|
||||||
|
* `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.
|
||||||
|
* `--find_unused_parameters`: Whether unused parameters exist in DDP training.
|
||||||
|
* `--weight_decay`: Weight decay magnitude.
|
||||||
|
* `--task`: Training task, defaults to `sft`.
|
||||||
|
* Output Configuration
|
||||||
|
* `--output_path`: Path to save the model.
|
||||||
|
* `--remove_prefix_in_ckpt`: Remove prefix in the model's state dict.
|
||||||
|
* `--save_steps`: Interval in training steps to save the model.
|
||||||
|
* LoRA Configuration
|
||||||
|
* `--lora_base_model`: Which model to add LoRA to.
|
||||||
|
* `--lora_target_modules`: Which layers to add LoRA to.
|
||||||
|
* `--lora_rank`: Rank of LoRA.
|
||||||
|
* `--lora_checkpoint`: Path to LoRA checkpoint.
|
||||||
|
* `--preset_lora_path`: Path to preset LoRA checkpoint for LoRA differential training.
|
||||||
|
* `--preset_lora_model`: Which model to integrate preset LoRA into, e.g., `dit`.
|
||||||
|
* Gradient Configuration
|
||||||
|
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
|
||||||
|
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
|
||||||
|
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
|
||||||
|
* Resolution Configuration
|
||||||
|
* `--height`: Height of the image/video. Leave empty to enable dynamic resolution.
|
||||||
|
* `--width`: Width of the image/video. Leave empty to enable dynamic resolution.
|
||||||
|
* `--max_pixels`: Maximum pixel area, images larger than this will be scaled down during dynamic resolution.
|
||||||
|
* `--num_frames`: Number of frames for video (video generation models only).
|
||||||
|
* Stable Diffusion XL Specific Parameters
|
||||||
|
* `--tokenizer_path`: Path to the first tokenizer.
|
||||||
|
* `--tokenizer_2_path`: Path to the second tokenizer, defaults to `stabilityai/stable-diffusion-xl-base-1.0:tokenizer_2/`.
|
||||||
|
|
||||||
|
Example dataset download:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion_xl/*" --local_dir ./data/diffsynth_example_dataset
|
||||||
|
```
|
||||||
|
|
||||||
|
[stable-diffusion-xl-base-1.0 training scripts](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/lora/stable-diffusion-xl-base-1.0.sh)
|
||||||
|
|
||||||
|
We provide recommended training scripts for each model, please refer to the table in "Model Overview" above. For guidance on writing model training scripts, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see [Training Framework Overview](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).
|
||||||
138
docs/en/Model_Details/Stable-Diffusion.md
Normal file
138
docs/en/Model_Details/Stable-Diffusion.md
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
# Stable Diffusion
|
||||||
|
|
||||||
|
Stable Diffusion is an open-source diffusion-based text-to-image generation model developed by Stability AI, supporting 512x512 resolution text-to-image generation.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
Before performing model inference and training, please install DiffSynth-Studio first.
|
||||||
|
|
||||||
|
```shell
|
||||||
|
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||||
|
cd DiffSynth-Studio
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
For more information on installation, please refer to [Setup Dependencies](../Pipeline_Usage/Setup.md).
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
Running the following code will quickly load the [AI-ModelScope/stable-diffusion-v1-5](https://www.modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5) model for inference. VRAM management is enabled, the framework automatically controls parameter loading based on available VRAM, requiring a minimum of 2GB VRAM.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from diffsynth.core import ModelConfig
|
||||||
|
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.float32,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.float32,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float32,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.float32,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt="a photo of an astronaut riding a horse on mars, high quality, detailed",
|
||||||
|
negative_prompt="blurry, low quality, deformed",
|
||||||
|
cfg_scale=7.5,
|
||||||
|
height=512,
|
||||||
|
width=512,
|
||||||
|
seed=42,
|
||||||
|
rand_device="cuda",
|
||||||
|
num_inference_steps=50,
|
||||||
|
)
|
||||||
|
image.save("image.jpg")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Model Overview
|
||||||
|
|
||||||
|
|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|
||||||
|
|-|-|-|-|-|-|-|
|
||||||
|
|[AI-ModelScope/stable-diffusion-v1-5](https://www.modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_inference/stable-diffusion-v1-5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_inference_low_vram/stable-diffusion-v1-5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/full/stable-diffusion-v1-5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/validate_full/stable-diffusion-v1-5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/lora/stable-diffusion-v1-5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/validate_lora/stable-diffusion-v1-5.py)|
|
||||||
|
|
||||||
|
## Model Inference
|
||||||
|
|
||||||
|
The model is loaded via `StableDiffusionPipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models) for details.
|
||||||
|
|
||||||
|
The input parameters for `StableDiffusionPipeline` inference include:
|
||||||
|
|
||||||
|
* `prompt`: Text prompt.
|
||||||
|
* `negative_prompt`: Negative prompt, defaults to an empty string.
|
||||||
|
* `cfg_scale`: Classifier-Free Guidance scale factor, default 7.5.
|
||||||
|
* `height`: Output image height, default 512.
|
||||||
|
* `width`: Output image width, default 512.
|
||||||
|
* `seed`: Random seed, defaults to a random value if not set.
|
||||||
|
* `rand_device`: Noise generation device, defaults to "cpu".
|
||||||
|
* `num_inference_steps`: Number of inference steps, default 50.
|
||||||
|
* `eta`: DDIM scheduler eta parameter, default 0.0.
|
||||||
|
* `guidance_rescale`: Guidance rescale factor, default 0.0.
|
||||||
|
* `progress_bar_cmd`: Progress bar callback function.
|
||||||
|
|
||||||
|
## Model Training
|
||||||
|
|
||||||
|
Models in the stable_diffusion series are trained via `examples/stable_diffusion/model_training/train.py`. The script parameters include:
|
||||||
|
|
||||||
|
* General Training Parameters
|
||||||
|
* Dataset Configuration
|
||||||
|
* `--dataset_base_path`: Root directory of the dataset.
|
||||||
|
* `--dataset_metadata_path`: Path to the dataset metadata file.
|
||||||
|
* `--dataset_repeat`: Number of dataset repeats per epoch.
|
||||||
|
* `--dataset_num_workers`: Number of processes per DataLoader.
|
||||||
|
* `--data_file_keys`: Field names to load from metadata, typically paths to image or video files, separated by `,`.
|
||||||
|
* Model Loading Configuration
|
||||||
|
* `--model_paths`: Paths to load models from, in JSON format.
|
||||||
|
* `--model_id_with_origin_paths`: Model IDs with original paths, separated by commas.
|
||||||
|
* `--extra_inputs`: Additional input parameters required by the model Pipeline, separated by `,`.
|
||||||
|
* `--fp8_models`: Models to load in FP8 format, currently only supported for models whose parameters are not updated by gradients.
|
||||||
|
* Basic Training Configuration
|
||||||
|
* `--learning_rate`: Learning rate.
|
||||||
|
* `--num_epochs`: Number of epochs.
|
||||||
|
* `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.
|
||||||
|
* `--find_unused_parameters`: Whether unused parameters exist in DDP training.
|
||||||
|
* `--weight_decay`: Weight decay magnitude.
|
||||||
|
* `--task`: Training task, defaults to `sft`.
|
||||||
|
* Output Configuration
|
||||||
|
* `--output_path`: Path to save the model.
|
||||||
|
* `--remove_prefix_in_ckpt`: Remove prefix in the model's state dict.
|
||||||
|
* `--save_steps`: Interval in training steps to save the model.
|
||||||
|
* LoRA Configuration
|
||||||
|
* `--lora_base_model`: Which model to add LoRA to.
|
||||||
|
* `--lora_target_modules`: Which layers to add LoRA to.
|
||||||
|
* `--lora_rank`: Rank of LoRA.
|
||||||
|
* `--lora_checkpoint`: Path to LoRA checkpoint.
|
||||||
|
* `--preset_lora_path`: Path to preset LoRA checkpoint for LoRA differential training.
|
||||||
|
* `--preset_lora_model`: Which model to integrate preset LoRA into, e.g., `dit`.
|
||||||
|
* Gradient Configuration
|
||||||
|
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
|
||||||
|
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
|
||||||
|
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
|
||||||
|
* Resolution Configuration
|
||||||
|
* `--height`: Height of the image/video. Leave empty to enable dynamic resolution.
|
||||||
|
* `--width`: Width of the image/video. Leave empty to enable dynamic resolution.
|
||||||
|
* `--max_pixels`: Maximum pixel area, images larger than this will be scaled down during dynamic resolution.
|
||||||
|
* `--num_frames`: Number of frames for video (video generation models only).
|
||||||
|
* Stable Diffusion Specific Parameters
|
||||||
|
* `--tokenizer_path`: Tokenizer path, defaults to `AI-ModelScope/stable-diffusion-v1-5:tokenizer/`.
|
||||||
|
|
||||||
|
Example dataset download:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion/*" --local_dir ./data/diffsynth_example_dataset
|
||||||
|
```
|
||||||
|
|
||||||
|
[stable-diffusion-v1-5 training scripts](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/lora/stable-diffusion-v1-5.sh)
|
||||||
|
|
||||||
|
We provide recommended training scripts for each model, please refer to the table in "Model Overview" above. For guidance on writing model training scripts, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see [Training Framework Overview](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).
|
||||||
@@ -32,6 +32,8 @@ Welcome to DiffSynth-Studio's Documentation
|
|||||||
Model_Details/LTX-2
|
Model_Details/LTX-2
|
||||||
Model_Details/ERNIE-Image
|
Model_Details/ERNIE-Image
|
||||||
Model_Details/JoyAI-Image
|
Model_Details/JoyAI-Image
|
||||||
|
Model_Details/Stable-Diffusion
|
||||||
|
Model_Details/Stable-Diffusion-XL
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:maxdepth: 2
|
:maxdepth: 2
|
||||||
|
|||||||
141
docs/zh/Model_Details/Stable-Diffusion-XL.md
Normal file
141
docs/zh/Model_Details/Stable-Diffusion-XL.md
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
# Stable Diffusion XL
|
||||||
|
|
||||||
|
Stable Diffusion XL (SDXL) 是由 Stability AI 开发的开源扩散式文本到图像生成模型,支持 1024x1024 分辨率的高质量文本到图像生成,采用双文本编码器(CLIP-L + CLIP-bigG)架构。
|
||||||
|
|
||||||
|
## 安装
|
||||||
|
|
||||||
|
在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。
|
||||||
|
|
||||||
|
```shell
|
||||||
|
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||||
|
cd DiffSynth-Studio
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
更多关于安装的信息,请参考[安装依赖](../Pipeline_Usage/Setup.md)。
|
||||||
|
|
||||||
|
## 快速开始
|
||||||
|
|
||||||
|
运行以下代码可以快速加载 [stabilityai/stable-diffusion-xl-base-1.0](https://www.modelscope.cn/models/stabilityai/stable-diffusion-xl-base-1.0) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 6GB 显存即可运行。
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from diffsynth.core import ModelConfig
|
||||||
|
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.float32,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.float32,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float32,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.float32,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
|
||||||
|
tokenizer_2_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt="a photo of an astronaut riding a horse on mars",
|
||||||
|
negative_prompt="",
|
||||||
|
cfg_scale=5.0,
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
seed=42,
|
||||||
|
num_inference_steps=50,
|
||||||
|
)
|
||||||
|
image.save("image.jpg")
|
||||||
|
```
|
||||||
|
|
||||||
|
## 模型总览
|
||||||
|
|
||||||
|
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||||
|
|-|-|-|-|-|-|-|
|
||||||
|
|[stabilityai/stable-diffusion-xl-base-1.0](https://www.modelscope.cn/models/stabilityai/stable-diffusion-xl-base-1.0)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_inference/stable-diffusion-xl-base-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_inference_low_vram/stable-diffusion-xl-base-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/full/stable-diffusion-xl-base-1.0.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/validate_full/stable-diffusion-xl-base-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/lora/stable-diffusion-xl-base-1.0.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/validate_lora/stable-diffusion-xl-base-1.0.py)|
|
||||||
|
|
||||||
|
## 模型推理
|
||||||
|
|
||||||
|
模型通过 `StableDiffusionXLPipeline.from_pretrained` 加载,详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。
|
||||||
|
|
||||||
|
`StableDiffusionXLPipeline` 的推理输入参数包括:
|
||||||
|
|
||||||
|
* `prompt`: 文本提示词。
|
||||||
|
* `negative_prompt`: 负面提示词,默认为空字符串。
|
||||||
|
* `cfg_scale`: Classifier-Free Guidance 缩放系数,默认 5.0。
|
||||||
|
* `height`: 输出图像高度,默认 1024。
|
||||||
|
* `width`: 输出图像宽度,默认 1024。
|
||||||
|
* `seed`: 随机种子,默认不设置时使用随机种子。
|
||||||
|
* `rand_device`: 噪声生成设备,默认 "cpu"。
|
||||||
|
* `num_inference_steps`: 推理步数,默认 50。
|
||||||
|
* `guidance_rescale`: Guidance rescale 系数,默认 0.0。
|
||||||
|
* `progress_bar_cmd`: 进度条回调函数。
|
||||||
|
|
||||||
|
> `StableDiffusionXLPipeline` 需要双 tokenizer 配置(`tokenizer_config` 和 `tokenizer_2_config`),分别对应 CLIP-L 和 CLIP-bigG 文本编码器。
|
||||||
|
|
||||||
|
## 模型训练
|
||||||
|
|
||||||
|
stable_diffusion_xl 系列模型通过 `examples/stable_diffusion_xl/model_training/train.py` 进行训练,脚本的参数包括:
|
||||||
|
|
||||||
|
* 通用训练参数
|
||||||
|
* 数据集基础配置
|
||||||
|
* `--dataset_base_path`: 数据集的根目录。
|
||||||
|
* `--dataset_metadata_path`: 数据集的元数据文件路径。
|
||||||
|
* `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
|
||||||
|
* `--dataset_num_workers`: 每个 Dataloader 的进程数量。
|
||||||
|
* `--data_file_keys`: 元数据中需要加载的字段名称,通常是图像或视频文件的路径,以 `,` 分隔。
|
||||||
|
* 模型加载配置
|
||||||
|
* `--model_paths`: 要加载的模型路径。JSON 格式。
|
||||||
|
* `--model_id_with_origin_paths`: 带原始路径的模型 ID。用逗号分隔。
|
||||||
|
* `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,以 `,` 分隔。
|
||||||
|
* `--fp8_models`: 以 FP8 格式加载的模型,目前仅支持参数不被梯度更新的模型。
|
||||||
|
* 训练基础配置
|
||||||
|
* `--learning_rate`: 学习率。
|
||||||
|
* `--num_epochs`: 轮数(Epoch)。
|
||||||
|
* `--trainable_models`: 可训练的模型,例如 `dit`、`vae`、`text_encoder`。
|
||||||
|
* `--find_unused_parameters`: DDP 训练中是否存在未使用的参数。
|
||||||
|
* `--weight_decay`: 权重衰减大小。
|
||||||
|
* `--task`: 训练任务,默认为 `sft`。
|
||||||
|
* 输出配置
|
||||||
|
* `--output_path`: 模型保存路径。
|
||||||
|
* `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。
|
||||||
|
* `--save_steps`: 保存模型的训练步数间隔。
|
||||||
|
* LoRA 配置
|
||||||
|
* `--lora_base_model`: LoRA 添加到哪个模型上。
|
||||||
|
* `--lora_target_modules`: LoRA 添加到哪些层上。
|
||||||
|
* `--lora_rank`: LoRA 的秩(Rank)。
|
||||||
|
* `--lora_checkpoint`: LoRA 检查点的路径。
|
||||||
|
* `--preset_lora_path`: 预置 LoRA 检查点路径,用于 LoRA 差分训练。
|
||||||
|
* `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`。
|
||||||
|
* 梯度配置
|
||||||
|
* `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。
|
||||||
|
* `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。
|
||||||
|
* `--gradient_accumulation_steps`: 梯度累积步数。
|
||||||
|
* 分辨率配置
|
||||||
|
* `--height`: 图像/视频的高度。留空启用动态分辨率。
|
||||||
|
* `--width`: 图像/视频的宽度。留空启用动态分辨率。
|
||||||
|
* `--max_pixels`: 最大像素面积,动态分辨率时大于此值的图片会被缩小。
|
||||||
|
* `--num_frames`: 视频的帧数(仅视频生成模型)。
|
||||||
|
* Stable Diffusion XL 专有参数
|
||||||
|
* `--tokenizer_path`: 第一个 Tokenizer 路径。
|
||||||
|
* `--tokenizer_2_path`: 第二个 Tokenizer 路径,默认为 `stabilityai/stable-diffusion-xl-base-1.0:tokenizer_2/`。
|
||||||
|
|
||||||
|
样例数据集下载:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion_xl/*" --local_dir ./data/diffsynth_example_dataset
|
||||||
|
```
|
||||||
|
|
||||||
|
[stable-diffusion-xl-base-1.0 训练脚本](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/lora/stable-diffusion-xl-base-1.0.sh)
|
||||||
|
|
||||||
|
我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](../Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。
|
||||||
138
docs/zh/Model_Details/Stable-Diffusion.md
Normal file
138
docs/zh/Model_Details/Stable-Diffusion.md
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
# Stable Diffusion
|
||||||
|
|
||||||
|
Stable Diffusion 是由 Stability AI 开发的开源扩散式文本到图像生成模型,支持 512x512 分辨率的文本到图像生成。
|
||||||
|
|
||||||
|
## 安装
|
||||||
|
|
||||||
|
在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。
|
||||||
|
|
||||||
|
```shell
|
||||||
|
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||||
|
cd DiffSynth-Studio
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
更多关于安装的信息,请参考[安装依赖](../Pipeline_Usage/Setup.md)。
|
||||||
|
|
||||||
|
## 快速开始
|
||||||
|
|
||||||
|
运行以下代码可以快速加载 [AI-ModelScope/stable-diffusion-v1-5](https://www.modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 2GB 显存即可运行。
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from diffsynth.core import ModelConfig
|
||||||
|
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.float32,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.float32,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float32,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.float32,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt="a photo of an astronaut riding a horse on mars, high quality, detailed",
|
||||||
|
negative_prompt="blurry, low quality, deformed",
|
||||||
|
cfg_scale=7.5,
|
||||||
|
height=512,
|
||||||
|
width=512,
|
||||||
|
seed=42,
|
||||||
|
rand_device="cuda",
|
||||||
|
num_inference_steps=50,
|
||||||
|
)
|
||||||
|
image.save("image.jpg")
|
||||||
|
```
|
||||||
|
|
||||||
|
## 模型总览
|
||||||
|
|
||||||
|
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||||
|
|-|-|-|-|-|-|-|
|
||||||
|
|[AI-ModelScope/stable-diffusion-v1-5](https://www.modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_inference/stable-diffusion-v1-5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_inference_low_vram/stable-diffusion-v1-5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/full/stable-diffusion-v1-5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/validate_full/stable-diffusion-v1-5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/lora/stable-diffusion-v1-5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/validate_lora/stable-diffusion-v1-5.py)|
|
||||||
|
|
||||||
|
## 模型推理
|
||||||
|
|
||||||
|
模型通过 `StableDiffusionPipeline.from_pretrained` 加载,详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。
|
||||||
|
|
||||||
|
`StableDiffusionPipeline` 的推理输入参数包括:
|
||||||
|
|
||||||
|
* `prompt`: 文本提示词。
|
||||||
|
* `negative_prompt`: 负面提示词,默认为空字符串。
|
||||||
|
* `cfg_scale`: Classifier-Free Guidance 缩放系数,默认 7.5。
|
||||||
|
* `height`: 输出图像高度,默认 512。
|
||||||
|
* `width`: 输出图像宽度,默认 512。
|
||||||
|
* `seed`: 随机种子,默认不设置时使用随机种子。
|
||||||
|
* `rand_device`: 噪声生成设备,默认 "cpu"。
|
||||||
|
* `num_inference_steps`: 推理步数,默认 50。
|
||||||
|
* `eta`: DDIM 调度器的 eta 参数,默认 0.0。
|
||||||
|
* `guidance_rescale`: Guidance rescale 系数,默认 0.0。
|
||||||
|
* `progress_bar_cmd`: 进度条回调函数。
|
||||||
|
|
||||||
|
## 模型训练
|
||||||
|
|
||||||
|
stable_diffusion 系列模型通过 `examples/stable_diffusion/model_training/train.py` 进行训练,脚本的参数包括:
|
||||||
|
|
||||||
|
* 通用训练参数
|
||||||
|
* 数据集基础配置
|
||||||
|
* `--dataset_base_path`: 数据集的根目录。
|
||||||
|
* `--dataset_metadata_path`: 数据集的元数据文件路径。
|
||||||
|
* `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
|
||||||
|
* `--dataset_num_workers`: 每个 Dataloader 的进程数量。
|
||||||
|
* `--data_file_keys`: 元数据中需要加载的字段名称,通常是图像或视频文件的路径,以 `,` 分隔。
|
||||||
|
* 模型加载配置
|
||||||
|
* `--model_paths`: 要加载的模型路径。JSON 格式。
|
||||||
|
* `--model_id_with_origin_paths`: 带原始路径的模型 ID。用逗号分隔。
|
||||||
|
* `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,以 `,` 分隔。
|
||||||
|
* `--fp8_models`: 以 FP8 格式加载的模型,目前仅支持参数不被梯度更新的模型。
|
||||||
|
* 训练基础配置
|
||||||
|
* `--learning_rate`: 学习率。
|
||||||
|
* `--num_epochs`: 轮数(Epoch)。
|
||||||
|
* `--trainable_models`: 可训练的模型,例如 `dit`、`vae`、`text_encoder`。
|
||||||
|
* `--find_unused_parameters`: DDP 训练中是否存在未使用的参数。
|
||||||
|
* `--weight_decay`: 权重衰减大小。
|
||||||
|
* `--task`: 训练任务,默认为 `sft`。
|
||||||
|
* 输出配置
|
||||||
|
* `--output_path`: 模型保存路径。
|
||||||
|
* `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。
|
||||||
|
* `--save_steps`: 保存模型的训练步数间隔。
|
||||||
|
* LoRA 配置
|
||||||
|
* `--lora_base_model`: LoRA 添加到哪个模型上。
|
||||||
|
* `--lora_target_modules`: LoRA 添加到哪些层上。
|
||||||
|
* `--lora_rank`: LoRA 的秩(Rank)。
|
||||||
|
* `--lora_checkpoint`: LoRA 检查点的路径。
|
||||||
|
* `--preset_lora_path`: 预置 LoRA 检查点路径,用于 LoRA 差分训练。
|
||||||
|
* `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`。
|
||||||
|
* 梯度配置
|
||||||
|
* `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。
|
||||||
|
* `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。
|
||||||
|
* `--gradient_accumulation_steps`: 梯度累积步数。
|
||||||
|
* 分辨率配置
|
||||||
|
* `--height`: 图像/视频的高度。留空启用动态分辨率。
|
||||||
|
* `--width`: 图像/视频的宽度。留空启用动态分辨率。
|
||||||
|
* `--max_pixels`: 最大像素面积,动态分辨率时大于此值的图片会被缩小。
|
||||||
|
* `--num_frames`: 视频的帧数(仅视频生成模型)。
|
||||||
|
* Stable Diffusion 专有参数
|
||||||
|
* `--tokenizer_path`: Tokenizer 路径,默认为 `AI-ModelScope/stable-diffusion-v1-5:tokenizer/`。
|
||||||
|
|
||||||
|
样例数据集下载:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion/*" --local_dir ./data/diffsynth_example_dataset
|
||||||
|
```
|
||||||
|
|
||||||
|
[stable-diffusion-v1-5 训练脚本](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/lora/stable-diffusion-v1-5.sh)
|
||||||
|
|
||||||
|
我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](../Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。
|
||||||
@@ -32,6 +32,8 @@
|
|||||||
Model_Details/LTX-2
|
Model_Details/LTX-2
|
||||||
Model_Details/ERNIE-Image
|
Model_Details/ERNIE-Image
|
||||||
Model_Details/JoyAI-Image
|
Model_Details/JoyAI-Image
|
||||||
|
Model_Details/Stable-Diffusion
|
||||||
|
Model_Details/Stable-Diffusion-XL
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:maxdepth: 2
|
:maxdepth: 2
|
||||||
|
|||||||
@@ -1,331 +0,0 @@
|
|||||||
import importlib, inspect, pkgutil, traceback, torch, os, re, typing
|
|
||||||
from typing import Union, List, Optional, Tuple, Iterable, Dict, Literal
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from diffsynth.utils.data import VideoData
|
|
||||||
import streamlit as st
|
|
||||||
from diffsynth import ModelConfig
|
|
||||||
from diffsynth.diffusion.base_pipeline import ControlNetInput
|
|
||||||
from PIL import Image
|
|
||||||
from tqdm import tqdm
|
|
||||||
st.set_page_config(layout="wide")
|
|
||||||
|
|
||||||
class StreamlitTqdmWrapper:
|
|
||||||
"""Wrapper class that combines tqdm and streamlit progress bar"""
|
|
||||||
def __init__(self, iterable, st_progress_bar=None):
|
|
||||||
self.iterable = iterable
|
|
||||||
self.st_progress_bar = st_progress_bar
|
|
||||||
self.tqdm_bar = tqdm(iterable)
|
|
||||||
self.total = len(iterable) if hasattr(iterable, '__len__') else None
|
|
||||||
self.current = 0
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
for item in self.tqdm_bar:
|
|
||||||
if self.st_progress_bar is not None and self.total is not None:
|
|
||||||
self.current += 1
|
|
||||||
self.st_progress_bar.progress(self.current / self.total)
|
|
||||||
yield item
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, *args):
|
|
||||||
if hasattr(self.tqdm_bar, '__exit__'):
|
|
||||||
self.tqdm_bar.__exit__(*args)
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def catch_error(error_value):
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
except Exception as e:
|
|
||||||
error_message = traceback.format_exc()
|
|
||||||
print(f"Error {error_value}:\n{error_message}")
|
|
||||||
|
|
||||||
def parse_model_configs_from_an_example(path):
|
|
||||||
model_configs = []
|
|
||||||
with open(path, "r") as f:
|
|
||||||
for code in f.readlines():
|
|
||||||
code = code.strip()
|
|
||||||
if not code.startswith("ModelConfig"):
|
|
||||||
continue
|
|
||||||
pairs = re.findall(r'(\w+)\s*=\s*["\']([^"\']+)["\']', code)
|
|
||||||
config_dict = {k: v for k, v in pairs}
|
|
||||||
model_configs.append(ModelConfig(model_id=config_dict["model_id"], origin_file_pattern=config_dict["origin_file_pattern"]))
|
|
||||||
return model_configs
|
|
||||||
|
|
||||||
def list_examples(path, keyword=None):
|
|
||||||
examples = []
|
|
||||||
if os.path.isdir(path):
|
|
||||||
for file_name in os.listdir(path):
|
|
||||||
examples.extend(list_examples(os.path.join(path, file_name), keyword=keyword))
|
|
||||||
elif path.endswith(".py"):
|
|
||||||
with open(path, "r") as f:
|
|
||||||
code = f.read()
|
|
||||||
if keyword is None or keyword in code:
|
|
||||||
examples.extend([path])
|
|
||||||
return examples
|
|
||||||
|
|
||||||
def parse_available_pipelines():
|
|
||||||
from diffsynth.diffusion.base_pipeline import BasePipeline
|
|
||||||
import diffsynth.pipelines as _pipelines_pkg
|
|
||||||
available_pipelines = {}
|
|
||||||
for _, name, _ in pkgutil.iter_modules(_pipelines_pkg.__path__):
|
|
||||||
with catch_error(f"Failed: import diffsynth.pipelines.{name}"):
|
|
||||||
mod = importlib.import_module(f"diffsynth.pipelines.{name}")
|
|
||||||
classes = {
|
|
||||||
cls_name: cls for cls_name, cls in inspect.getmembers(mod, inspect.isclass)
|
|
||||||
if issubclass(cls, BasePipeline) and cls is not BasePipeline and cls.__module__ == mod.__name__
|
|
||||||
}
|
|
||||||
available_pipelines.update(classes)
|
|
||||||
return available_pipelines
|
|
||||||
|
|
||||||
def parse_available_examples(path, available_pipelines):
|
|
||||||
available_examples = {}
|
|
||||||
for pipeline_name in available_pipelines:
|
|
||||||
examples = ["None"] + list_examples(path, keyword=f"{pipeline_name}.from_pretrained")
|
|
||||||
available_examples[pipeline_name] = examples
|
|
||||||
return available_examples
|
|
||||||
|
|
||||||
def draw_selectbox(label, options, option_map, value=None, disabled=False):
|
|
||||||
default_index = 0 if value is None else tuple(options).index([option for option in option_map if option_map[option]==value][0])
|
|
||||||
option = st.selectbox(label=label, options=tuple(options), index=default_index, disabled=disabled)
|
|
||||||
return option_map.get(option)
|
|
||||||
|
|
||||||
def parse_params(fn):
|
|
||||||
params = []
|
|
||||||
for name, param in inspect.signature(fn).parameters.items():
|
|
||||||
annotation = param.annotation if param.annotation is not inspect.Parameter.empty else None
|
|
||||||
default = param.default if param.default is not inspect.Parameter.empty else None
|
|
||||||
params.append({"name": name, "dtype": annotation, "value": default})
|
|
||||||
return params
|
|
||||||
|
|
||||||
def draw_model_config(model_config=None, key_suffix="", disabled=False):
|
|
||||||
with st.container(border=True):
|
|
||||||
if model_config is None:
|
|
||||||
model_config = ModelConfig()
|
|
||||||
path = st.text_input(label="path", key="path" + key_suffix, value=model_config.path, disabled=disabled)
|
|
||||||
col1, col2 = st.columns(2)
|
|
||||||
with col1:
|
|
||||||
model_id = st.text_input(label="model_id", key="model_id" + key_suffix, value=model_config.model_id, disabled=disabled)
|
|
||||||
with col2:
|
|
||||||
origin_file_pattern = st.text_input(label="origin_file_pattern", key="origin_file_pattern" + key_suffix, value=model_config.origin_file_pattern, disabled=disabled)
|
|
||||||
model_config = ModelConfig(
|
|
||||||
path=None if path == "" else path,
|
|
||||||
model_id=model_id,
|
|
||||||
origin_file_pattern=origin_file_pattern,
|
|
||||||
)
|
|
||||||
return model_config
|
|
||||||
|
|
||||||
def draw_multi_model_config(name="", value=None, disabled=False):
|
|
||||||
model_configs = []
|
|
||||||
with st.container(border=True):
|
|
||||||
st.markdown(name)
|
|
||||||
num = st.number_input(f"num_{name}", min_value=0, max_value=20, value=0 if value is None else len(value), disabled=disabled)
|
|
||||||
for i in range(num):
|
|
||||||
model_config = draw_model_config(key_suffix=f"_{name}_{i}", model_config=None if value is None else value[i], disabled=disabled)
|
|
||||||
model_configs.append(model_config)
|
|
||||||
return model_configs
|
|
||||||
|
|
||||||
def draw_single_model_config(name="", value=None, disabled=False):
|
|
||||||
with st.container(border=True):
|
|
||||||
st.markdown(name)
|
|
||||||
model_config = draw_model_config(value, key_suffix=f"_{name}", disabled=disabled)
|
|
||||||
return model_config
|
|
||||||
|
|
||||||
def draw_multi_images(name="", value=None, disabled=False):
|
|
||||||
images = []
|
|
||||||
with st.container(border=True):
|
|
||||||
st.markdown(name)
|
|
||||||
num = st.number_input(f"num_{name}", min_value=0, max_value=20, value=0 if value is None else len(value), disabled=disabled)
|
|
||||||
for i in range(num):
|
|
||||||
image = st.file_uploader(name, type=["png", "jpg", "jpeg", "webp"], key=f"{name}_{i}", disabled=disabled)
|
|
||||||
if image is not None: images.append(Image.open(image))
|
|
||||||
return images
|
|
||||||
|
|
||||||
def draw_multi_elements(st_element, name="", value=None, disabled=False, kwargs=None):
|
|
||||||
if kwargs is None:
|
|
||||||
kwargs = {}
|
|
||||||
elements = []
|
|
||||||
with st.container(border=True):
|
|
||||||
st.markdown(name)
|
|
||||||
num = st.number_input(f"num_{name}", min_value=0, max_value=20, value=0 if value is None else len(value), disabled=disabled)
|
|
||||||
for i in range(num):
|
|
||||||
element = st_element(name, key=f"{name}_{i}", disabled=disabled, value=None if value is None else value[i], **kwargs)
|
|
||||||
elements.append(element)
|
|
||||||
return elements
|
|
||||||
|
|
||||||
def draw_controlnet_input(name="", value=None, disabled=False):
|
|
||||||
with st.container(border=True):
|
|
||||||
st.markdown(name)
|
|
||||||
controlnet_id = st.number_input("controlnet_id", value=0, min_value=0, max_value=20, step=1, key=f"{name}_controlnet_id")
|
|
||||||
scale = st.number_input("scale", value=1.0, min_value=0.0, max_value=10.0, key=f"{name}_scale")
|
|
||||||
image = st.file_uploader("image", type=["png", "jpg", "jpeg", "webp"], disabled=disabled, key=f"{name}_image")
|
|
||||||
if image is not None: image = Image.open(image)
|
|
||||||
inpaint_image = st.file_uploader("inpaint_image", type=["png", "jpg", "jpeg", "webp"], disabled=disabled, key=f"{name}_inpaint_image")
|
|
||||||
if inpaint_image is not None: inpaint_image = Image.open(inpaint_image)
|
|
||||||
inpaint_mask = st.file_uploader("inpaint_mask", type=["png", "jpg", "jpeg", "webp"], disabled=disabled, key=f"{name}_inpaint_mask")
|
|
||||||
if inpaint_mask is not None: inpaint_mask = Image.open(inpaint_mask)
|
|
||||||
return ControlNetInput(controlnet_id=controlnet_id, scale=scale, image=image, inpaint_image=inpaint_image, inpaint_mask=inpaint_mask)
|
|
||||||
|
|
||||||
def draw_controlnet_inputs(name, value=None, disabled=False):
|
|
||||||
controlnet_inputs = []
|
|
||||||
with st.container(border=True):
|
|
||||||
st.markdown(name)
|
|
||||||
num = st.number_input(f"num_{name}", min_value=0, max_value=20, value=0 if value is None else len(value), disabled=disabled)
|
|
||||||
for i in range(num):
|
|
||||||
controlnet_input = draw_controlnet_input(name=f"{name}_{i}", value=None, disabled=disabled)
|
|
||||||
controlnet_inputs.append(controlnet_input)
|
|
||||||
return controlnet_inputs
|
|
||||||
|
|
||||||
def draw_ui_element(name, dtype, value):
|
|
||||||
unsupported_dtype = [
|
|
||||||
Dict[str, torch.Tensor],
|
|
||||||
torch.Tensor,
|
|
||||||
]
|
|
||||||
if dtype in unsupported_dtype:
|
|
||||||
return
|
|
||||||
if value is None:
|
|
||||||
with st.container(border=True):
|
|
||||||
enable = st.checkbox(f"Enable {name}", value=False)
|
|
||||||
ui = draw_ui_element_safely(name, dtype, value=value, disabled=not enable)
|
|
||||||
if enable:
|
|
||||||
return ui
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
return draw_ui_element_safely(name, dtype, value)
|
|
||||||
|
|
||||||
def draw_video(name, value=None, disabled=False):
|
|
||||||
ui = st.file_uploader(name, type=["mp4"], disabled=disabled)
|
|
||||||
if ui is not None:
|
|
||||||
ui = VideoData(ui)
|
|
||||||
ui = [ui[i] for i in range(len(ui))]
|
|
||||||
return ui
|
|
||||||
|
|
||||||
def draw_ui_element_safely(name, dtype, value, disabled=False):
|
|
||||||
if dtype == torch.dtype:
|
|
||||||
option_map = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
|
|
||||||
ui = draw_selectbox(name, option_map.keys(), option_map, value=value, disabled=disabled)
|
|
||||||
elif dtype == Union[str, torch.device]:
|
|
||||||
option_map = {"cuda": "cuda", "cpu": "cpu"}
|
|
||||||
ui = draw_selectbox(name, option_map.keys(), option_map, value=value, disabled=disabled)
|
|
||||||
elif dtype == bool:
|
|
||||||
ui = st.checkbox(name, value=value, disabled=disabled)
|
|
||||||
elif dtype == ModelConfig:
|
|
||||||
ui = draw_single_model_config(name, value=value, disabled=disabled)
|
|
||||||
elif dtype in [list[ModelConfig], List[ModelConfig], Union[list[ModelConfig], ModelConfig, str]]:
|
|
||||||
if name == "model_configs" and "model_configs_from_example" in st.session_state:
|
|
||||||
model_configs = st.session_state["model_configs_from_example"]
|
|
||||||
del st.session_state["model_configs_from_example"]
|
|
||||||
ui = draw_multi_model_config(name, model_configs, disabled=disabled)
|
|
||||||
else:
|
|
||||||
ui = draw_multi_model_config(name, disabled=disabled)
|
|
||||||
elif dtype == str:
|
|
||||||
if "prompt" in name:
|
|
||||||
ui = st.text_area(name, value=value, height=3, disabled=disabled)
|
|
||||||
else:
|
|
||||||
ui = st.text_input(name, value=value, disabled=disabled)
|
|
||||||
elif dtype == float:
|
|
||||||
ui = st.number_input(name, value=value, disabled=disabled)
|
|
||||||
elif dtype == int:
|
|
||||||
ui = st.number_input(name, value=value, step=1, disabled=disabled)
|
|
||||||
elif dtype == Image.Image:
|
|
||||||
ui = st.file_uploader(name, type=["png", "jpg", "jpeg", "webp"], disabled=disabled)
|
|
||||||
if ui is not None: ui = Image.open(ui)
|
|
||||||
elif dtype in [List[Image.Image], list[Image.Image], Union[list[Image.Image], Image.Image], Union[List[Image.Image], Image.Image]]:
|
|
||||||
if "video" in name:
|
|
||||||
ui = draw_video(name, value=value, disabled=disabled)
|
|
||||||
else:
|
|
||||||
ui = draw_multi_images(name, value=value, disabled=disabled)
|
|
||||||
elif dtype in [List[ControlNetInput], list[ControlNetInput]]:
|
|
||||||
ui = draw_controlnet_inputs(name, value=value, disabled=disabled)
|
|
||||||
elif dtype in [List[str], list[str]]:
|
|
||||||
ui = draw_multi_elements(st.text_input, name, value=value, disabled=disabled)
|
|
||||||
elif dtype in [List[float], list[float], Union[list[float], float], Union[List[float], float]]:
|
|
||||||
ui = draw_multi_elements(st.number_input, name, value=value, disabled=disabled)
|
|
||||||
elif dtype in [List[int], list[int]]:
|
|
||||||
ui = draw_multi_elements(st.number_input, name, value=value, disabled=disabled, kwargs={"step": 1})
|
|
||||||
elif dtype in [List[List[Image.Image]], list[list[Image.Image]]]:
|
|
||||||
ui = draw_multi_elements(draw_video, name, value=value, disabled=disabled)
|
|
||||||
elif dtype in [tuple[int, int], Tuple[int, int]]:
|
|
||||||
with st.container(border=True):
|
|
||||||
st.markdown(name)
|
|
||||||
ui = (st.text_input(f"{name}_0", value=value[0], disabled=disabled), st.text_input(f"{name}_1", value=value[1], disabled=disabled))
|
|
||||||
elif isinstance(dtype, typing._LiteralGenericAlias):
|
|
||||||
with st.container(border=True):
|
|
||||||
st.markdown(f"{name} ({dtype})")
|
|
||||||
ui = st.text_input(name, value=value, disabled=disabled, label_visibility="hidden")
|
|
||||||
elif dtype is None:
|
|
||||||
if name == "progress_bar_cmd":
|
|
||||||
ui = value
|
|
||||||
else:
|
|
||||||
st.markdown(f"(`{name}` is not not configurable in WebUI). dtype: `{dtype}`.")
|
|
||||||
ui = value
|
|
||||||
return ui
|
|
||||||
|
|
||||||
|
|
||||||
def launch_webui():
|
|
||||||
input_col, output_col = st.columns(2)
|
|
||||||
with input_col:
|
|
||||||
if "available_pipelines" not in st.session_state:
|
|
||||||
st.session_state["available_pipelines"] = parse_available_pipelines()
|
|
||||||
if "available_examples" not in st.session_state:
|
|
||||||
st.session_state["available_examples"] = parse_available_examples("./examples", st.session_state["available_pipelines"])
|
|
||||||
|
|
||||||
with st.expander("Pipeline", expanded=True):
|
|
||||||
pipeline_class = draw_selectbox("Pipeline Class", st.session_state["available_pipelines"].keys(), st.session_state["available_pipelines"], value=st.session_state["available_pipelines"]["ZImagePipeline"])
|
|
||||||
example = st.selectbox("Parse model configs from an example (optional)", st.session_state["available_examples"][pipeline_class.__name__])
|
|
||||||
|
|
||||||
# Clear if pipeline is changed
|
|
||||||
if "prev_pipeline_class" in st.session_state and st.session_state["prev_pipeline_class"] != pipeline_class:
|
|
||||||
if "pipeline_class" in st.session_state: del st.session_state["pipeline_class"]
|
|
||||||
if "model_configs_from_example" in st.session_state: del st.session_state["model_configs_from_example"]
|
|
||||||
if "prev_example" in st.session_state and st.session_state["prev_example"] != example:
|
|
||||||
if "model_configs_from_example" in st.session_state: del st.session_state["model_configs_from_example"]
|
|
||||||
st.session_state["prev_pipeline_class"] = pipeline_class
|
|
||||||
st.session_state["prev_example"] = example
|
|
||||||
|
|
||||||
if st.button("Step 1: Parse Pipeline", type="primary"):
|
|
||||||
st.session_state["pipeline_class"] = pipeline_class
|
|
||||||
if example != "None":
|
|
||||||
st.session_state["model_configs_from_example"] = parse_model_configs_from_an_example(example)
|
|
||||||
|
|
||||||
if "pipeline_class" not in st.session_state:
|
|
||||||
return
|
|
||||||
with st.expander("Model", expanded=True):
|
|
||||||
input_params = {}
|
|
||||||
params = parse_params(pipeline_class.from_pretrained)
|
|
||||||
for param in params:
|
|
||||||
input_params[param["name"]] = draw_ui_element(**param)
|
|
||||||
if st.button("Step 2: Load Models", type="primary"):
|
|
||||||
with st.spinner("Loading models", show_time=True):
|
|
||||||
if "pipe" in st.session_state:
|
|
||||||
del st.session_state["pipe"]
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
st.session_state["pipe"] = pipeline_class.from_pretrained(**input_params)
|
|
||||||
|
|
||||||
if "pipe" not in st.session_state:
|
|
||||||
return
|
|
||||||
with st.expander("Input", expanded=True):
|
|
||||||
pipe = st.session_state["pipe"]
|
|
||||||
input_params = {}
|
|
||||||
params = parse_params(pipeline_class.__call__)
|
|
||||||
for param in params:
|
|
||||||
if param["name"] in ["self"]:
|
|
||||||
continue
|
|
||||||
input_params[param["name"]] = draw_ui_element(**param)
|
|
||||||
|
|
||||||
with output_col:
|
|
||||||
if st.button("Step 3: Generate", type="primary"):
|
|
||||||
if "progress_bar_cmd" in input_params:
|
|
||||||
input_params["progress_bar_cmd"] = lambda iterable: StreamlitTqdmWrapper(iterable, st.progress(0))
|
|
||||||
result = pipe(**input_params)
|
|
||||||
st.session_state["result"] = result
|
|
||||||
|
|
||||||
if "result" in st.session_state:
|
|
||||||
result = st.session_state["result"]
|
|
||||||
if isinstance(result, Image.Image):
|
|
||||||
st.image(result)
|
|
||||||
else:
|
|
||||||
print(f"unsupported result format: {result}")
|
|
||||||
|
|
||||||
launch_webui()
|
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth.core import ModelConfig
|
||||||
|
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline
|
||||||
|
|
||||||
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors"),
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors"),
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
|
||||||
|
)
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt="a photo of an astronaut riding a horse on mars, high quality, detailed",
|
||||||
|
negative_prompt="blurry, low quality, deformed",
|
||||||
|
cfg_scale=7.5,
|
||||||
|
height=512,
|
||||||
|
width=512,
|
||||||
|
seed=42,
|
||||||
|
rand_device="cuda",
|
||||||
|
num_inference_steps=50,
|
||||||
|
)
|
||||||
|
image.save("image.jpg")
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth.core import ModelConfig
|
||||||
|
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.float32,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.float32,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float32,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.float32,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt="a photo of an astronaut riding a horse on mars, high quality, detailed",
|
||||||
|
negative_prompt="blurry, low quality, deformed",
|
||||||
|
cfg_scale=7.5,
|
||||||
|
height=512,
|
||||||
|
width=512,
|
||||||
|
seed=42,
|
||||||
|
rand_device="cuda",
|
||||||
|
num_inference_steps=50,
|
||||||
|
)
|
||||||
|
image.save("image.jpg")
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion/stable-diffusion-v1-5/*" --local_dir ./data/diffsynth_example_dataset
|
||||||
|
|
||||||
|
accelerate launch examples/stable_diffusion/model_training/train.py \
|
||||||
|
--dataset_base_path data/diffsynth_example_dataset/stable_diffusion/stable-diffusion-v1-5 \
|
||||||
|
--dataset_metadata_path data/diffsynth_example_dataset/stable_diffusion/stable-diffusion-v1-5/metadata.csv \
|
||||||
|
--height 512 \
|
||||||
|
--width 512 \
|
||||||
|
--dataset_repeat 50 \
|
||||||
|
--model_id_with_origin_paths "AI-ModelScope/stable-diffusion-v1-5:text_encoder/model.safetensors,AI-ModelScope/stable-diffusion-v1-5:unet/diffusion_pytorch_model.safetensors,AI-ModelScope/stable-diffusion-v1-5:vae/diffusion_pytorch_model.safetensors" \
|
||||||
|
--learning_rate 1e-5 \
|
||||||
|
--num_epochs 2 \
|
||||||
|
--trainable_models "unet" \
|
||||||
|
--remove_prefix_in_ckpt "pipe.unet." \
|
||||||
|
--output_path "./models/train/stable-diffusion-v1-5_full" \
|
||||||
|
--use_gradient_checkpointing
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion/stable-diffusion-v1-5/*" --local_dir ./data/diffsynth_example_dataset
|
||||||
|
|
||||||
|
accelerate launch examples/stable_diffusion/model_training/train.py \
|
||||||
|
--dataset_base_path data/diffsynth_example_dataset/stable_diffusion/stable-diffusion-v1-5 \
|
||||||
|
--dataset_metadata_path data/diffsynth_example_dataset/stable_diffusion/stable-diffusion-v1-5/metadata.csv \
|
||||||
|
--height 512 \
|
||||||
|
--width 512 \
|
||||||
|
--dataset_repeat 50 \
|
||||||
|
--model_id_with_origin_paths "AI-ModelScope/stable-diffusion-v1-5:text_encoder/model.safetensors,AI-ModelScope/stable-diffusion-v1-5:unet/diffusion_pytorch_model.safetensors,AI-ModelScope/stable-diffusion-v1-5:vae/diffusion_pytorch_model.safetensors" \
|
||||||
|
--learning_rate 1e-4 \
|
||||||
|
--num_epochs 5 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.unet." \
|
||||||
|
--output_path "./models/train/stable-diffusion-v1-5_lora" \
|
||||||
|
--lora_base_model "unet" \
|
||||||
|
--lora_target_modules "" \
|
||||||
|
--lora_rank 32 \
|
||||||
|
--use_gradient_checkpointing
|
||||||
142
examples/stable_diffusion/model_training/train.py
Normal file
142
examples/stable_diffusion/model_training/train.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
import torch, os, argparse, accelerate
|
||||||
|
from diffsynth.core import UnifiedDataset
|
||||||
|
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline, ModelConfig
|
||||||
|
from diffsynth.diffusion import *
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusionTrainingModule(DiffusionTrainingModule):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_paths=None, model_id_with_origin_paths=None,
|
||||||
|
tokenizer_path=None,
|
||||||
|
trainable_models=None,
|
||||||
|
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
|
||||||
|
preset_lora_path=None, preset_lora_model=None,
|
||||||
|
use_gradient_checkpointing=True,
|
||||||
|
use_gradient_checkpointing_offload=False,
|
||||||
|
extra_inputs=None,
|
||||||
|
fp8_models=None,
|
||||||
|
offload_models=None,
|
||||||
|
device="cpu",
|
||||||
|
task="sft",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
# Load models
|
||||||
|
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
|
||||||
|
tokenizer_config = self.parse_path_or_model_id(tokenizer_path, ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"))
|
||||||
|
self.pipe = StableDiffusionPipeline.from_pretrained(torch_dtype=torch.float32, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
|
||||||
|
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
|
||||||
|
|
||||||
|
# Training mode
|
||||||
|
self.switch_pipe_to_training_mode(
|
||||||
|
self.pipe, trainable_models,
|
||||||
|
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
|
||||||
|
preset_lora_path, preset_lora_model,
|
||||||
|
task=task,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Other configs
|
||||||
|
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||||
|
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||||
|
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
||||||
|
self.fp8_models = fp8_models
|
||||||
|
self.task = task
|
||||||
|
self.task_to_loss = {
|
||||||
|
"sft:data_process": lambda pipe, *args: args,
|
||||||
|
"direct_distill:data_process": lambda pipe, *args: args,
|
||||||
|
"sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
|
||||||
|
"sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
|
||||||
|
"direct_distill": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
|
||||||
|
"direct_distill:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_pipeline_inputs(self, data):
|
||||||
|
inputs_posi = {"prompt": data["prompt"]}
|
||||||
|
inputs_nega = {"negative_prompt": ""}
|
||||||
|
inputs_shared = {
|
||||||
|
# Assume you are using this pipeline for inference,
|
||||||
|
# please fill in the input parameters.
|
||||||
|
"input_image": data["image"],
|
||||||
|
"height": data["image"].size[1],
|
||||||
|
"width": data["image"].size[0],
|
||||||
|
# Please do not modify the following parameters
|
||||||
|
# unless you clearly know what this will cause.
|
||||||
|
"cfg_scale": 1,
|
||||||
|
"rand_device": self.pipe.device,
|
||||||
|
"use_gradient_checkpointing": self.use_gradient_checkpointing,
|
||||||
|
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
|
||||||
|
}
|
||||||
|
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
|
||||||
|
return inputs_shared, inputs_posi, inputs_nega
|
||||||
|
|
||||||
|
def forward(self, data, inputs=None):
|
||||||
|
if inputs is None: inputs = self.get_pipeline_inputs(data)
|
||||||
|
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
|
||||||
|
for unit in self.pipe.units:
|
||||||
|
inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)
|
||||||
|
loss = self.task_to_loss[self.task](self.pipe, *inputs)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def parser():
|
||||||
|
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||||
|
parser = add_general_config(parser)
|
||||||
|
parser = add_image_size_config(parser)
|
||||||
|
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
accelerator = accelerate.Accelerator(
|
||||||
|
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||||
|
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
|
||||||
|
)
|
||||||
|
dataset = UnifiedDataset(
|
||||||
|
base_path=args.dataset_base_path,
|
||||||
|
metadata_path=args.dataset_metadata_path,
|
||||||
|
repeat=args.dataset_repeat,
|
||||||
|
data_file_keys=args.data_file_keys.split(","),
|
||||||
|
main_data_operator=UnifiedDataset.default_image_operator(
|
||||||
|
base_path=args.dataset_base_path,
|
||||||
|
max_pixels=args.max_pixels,
|
||||||
|
height=args.height,
|
||||||
|
width=args.width,
|
||||||
|
height_division_factor=32,
|
||||||
|
width_division_factor=32,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
model = StableDiffusionTrainingModule(
|
||||||
|
model_paths=args.model_paths,
|
||||||
|
model_id_with_origin_paths=args.model_id_with_origin_paths,
|
||||||
|
tokenizer_path=args.tokenizer_path,
|
||||||
|
trainable_models=args.trainable_models,
|
||||||
|
lora_base_model=args.lora_base_model,
|
||||||
|
lora_target_modules=args.lora_target_modules,
|
||||||
|
lora_rank=args.lora_rank,
|
||||||
|
lora_checkpoint=args.lora_checkpoint,
|
||||||
|
preset_lora_path=args.preset_lora_path,
|
||||||
|
preset_lora_model=args.preset_lora_model,
|
||||||
|
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||||
|
extra_inputs=args.extra_inputs,
|
||||||
|
fp8_models=args.fp8_models,
|
||||||
|
offload_models=args.offload_models,
|
||||||
|
task=args.task,
|
||||||
|
device=accelerator.device,
|
||||||
|
)
|
||||||
|
model_logger = ModelLogger(
|
||||||
|
args.output_path,
|
||||||
|
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
|
||||||
|
)
|
||||||
|
launcher_map = {
|
||||||
|
"sft:data_process": launch_data_process_task,
|
||||||
|
"direct_distill:data_process": launch_data_process_task,
|
||||||
|
"sft": launch_training_task,
|
||||||
|
"sft:train": launch_training_task,
|
||||||
|
"direct_distill": launch_training_task,
|
||||||
|
"direct_distill:train": launch_training_task,
|
||||||
|
}
|
||||||
|
launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)
|
||||||
@@ -0,0 +1,27 @@
|
|||||||
|
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline, ModelConfig
|
||||||
|
from diffsynth.core import load_state_dict
|
||||||
|
import torch
|
||||||
|
|
||||||
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors"),
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors"),
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
|
||||||
|
)
|
||||||
|
state_dict = load_state_dict("./models/train/stable-diffusion-v1-5_full/epoch-1.safetensors", torch_dtype=torch.float32)
|
||||||
|
pipe.unet.load_state_dict(state_dict)
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt="a dog",
|
||||||
|
negative_prompt="blurry, low quality, deformed",
|
||||||
|
cfg_scale=7.5,
|
||||||
|
height=512,
|
||||||
|
width=512,
|
||||||
|
seed=42,
|
||||||
|
rand_device="cuda",
|
||||||
|
num_inference_steps=50,
|
||||||
|
)
|
||||||
|
image.save("image_stable-diffusion-v1-5_full.jpg")
|
||||||
@@ -0,0 +1,26 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth.core import ModelConfig
|
||||||
|
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline
|
||||||
|
|
||||||
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors"),
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors"),
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
|
||||||
|
)
|
||||||
|
pipe.load_lora(pipe.unet, "models/train/stable-diffusion-v1-5_lora/epoch-4.safetensors")
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt="a dog",
|
||||||
|
negative_prompt="blurry, low quality, deformed",
|
||||||
|
cfg_scale=7.5,
|
||||||
|
height=512,
|
||||||
|
width=512,
|
||||||
|
seed=42,
|
||||||
|
rand_device="cuda",
|
||||||
|
num_inference_steps=50,
|
||||||
|
)
|
||||||
|
image.save("image_stable-diffusion-v1-5.jpg")
|
||||||
@@ -0,0 +1,26 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth.core import ModelConfig
|
||||||
|
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
|
||||||
|
|
||||||
|
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors"),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors"),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors"),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
|
||||||
|
tokenizer_2_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
|
||||||
|
)
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt="a photo of an astronaut riding a horse on mars",
|
||||||
|
negative_prompt="",
|
||||||
|
cfg_scale=5.0,
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
seed=42,
|
||||||
|
num_inference_steps=50,
|
||||||
|
)
|
||||||
|
image.save("image.jpg")
|
||||||
@@ -0,0 +1,37 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth.core import ModelConfig
|
||||||
|
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.float32,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.float32,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float32,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.float32,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
|
||||||
|
tokenizer_2_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt="a photo of an astronaut riding a horse on mars",
|
||||||
|
negative_prompt="",
|
||||||
|
cfg_scale=5.0,
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
seed=42,
|
||||||
|
num_inference_steps=50,
|
||||||
|
)
|
||||||
|
image.save("image.jpg")
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion_xl/stable-diffusion-xl-base-1.0/*" --local_dir ./data/diffsynth_example_dataset
|
||||||
|
|
||||||
|
accelerate launch examples/stable_diffusion_xl/model_training/train.py \
|
||||||
|
--dataset_base_path data/diffsynth_example_dataset/stable_diffusion_xl/stable-diffusion-xl-base-1.0 \
|
||||||
|
--dataset_metadata_path data/diffsynth_example_dataset/stable_diffusion_xl/stable-diffusion-xl-base-1.0/metadata.csv \
|
||||||
|
--height 1024 \
|
||||||
|
--width 1024 \
|
||||||
|
--dataset_repeat 10 \
|
||||||
|
--model_id_with_origin_paths "stabilityai/stable-diffusion-xl-base-1.0:text_encoder/model.safetensors,stabilityai/stable-diffusion-xl-base-1.0:text_encoder_2/model.safetensors,stabilityai/stable-diffusion-xl-base-1.0:unet/diffusion_pytorch_model.safetensors,stabilityai/stable-diffusion-xl-base-1.0:vae/diffusion_pytorch_model.safetensors" \
|
||||||
|
--learning_rate 1e-5 \
|
||||||
|
--num_epochs 2 \
|
||||||
|
--trainable_models "unet" \
|
||||||
|
--remove_prefix_in_ckpt "pipe.unet." \
|
||||||
|
--output_path "./models/train/stable-diffusion-xl-base-1.0_full" \
|
||||||
|
--use_gradient_checkpointing
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion_xl/stable-diffusion-xl-base-1.0/*" --local_dir ./data/diffsynth_example_dataset
|
||||||
|
|
||||||
|
accelerate launch examples/stable_diffusion_xl/model_training/train.py \
|
||||||
|
--dataset_base_path data/diffsynth_example_dataset/stable_diffusion_xl/stable-diffusion-xl-base-1.0 \
|
||||||
|
--dataset_metadata_path data/diffsynth_example_dataset/stable_diffusion_xl/stable-diffusion-xl-base-1.0/metadata.csv \
|
||||||
|
--height 1024 \
|
||||||
|
--width 1024 \
|
||||||
|
--dataset_repeat 10 \
|
||||||
|
--model_id_with_origin_paths "stabilityai/stable-diffusion-xl-base-1.0:text_encoder/model.safetensors,stabilityai/stable-diffusion-xl-base-1.0:text_encoder_2/model.safetensors,stabilityai/stable-diffusion-xl-base-1.0:unet/diffusion_pytorch_model.safetensors,stabilityai/stable-diffusion-xl-base-1.0:vae/diffusion_pytorch_model.safetensors" \
|
||||||
|
--learning_rate 1e-4 \
|
||||||
|
--num_epochs 5 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.unet." \
|
||||||
|
--output_path "./models/train/stable-diffusion-xl-base-1.0_lora" \
|
||||||
|
--lora_base_model "unet" \
|
||||||
|
--lora_target_modules "" \
|
||||||
|
--lora_rank 32 \
|
||||||
|
--use_gradient_checkpointing
|
||||||
144
examples/stable_diffusion_xl/model_training/train.py
Normal file
144
examples/stable_diffusion_xl/model_training/train.py
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
import torch, os, argparse, accelerate
|
||||||
|
from diffsynth.core import UnifiedDataset
|
||||||
|
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline, ModelConfig
|
||||||
|
from diffsynth.diffusion import *
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusionXLTrainingModule(DiffusionTrainingModule):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_paths=None, model_id_with_origin_paths=None,
|
||||||
|
tokenizer_path=None,
|
||||||
|
trainable_models=None,
|
||||||
|
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
|
||||||
|
preset_lora_path=None, preset_lora_model=None,
|
||||||
|
use_gradient_checkpointing=True,
|
||||||
|
use_gradient_checkpointing_offload=False,
|
||||||
|
extra_inputs=None,
|
||||||
|
fp8_models=None,
|
||||||
|
offload_models=None,
|
||||||
|
device="cpu",
|
||||||
|
task="sft",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
# Load models
|
||||||
|
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
|
||||||
|
tokenizer_config = self.parse_path_or_model_id(tokenizer_path, ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"))
|
||||||
|
tokenizer_2_config = self.parse_path_or_model_id(tokenizer_path, ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"))
|
||||||
|
self.pipe = StableDiffusionXLPipeline.from_pretrained(torch_dtype=torch.float32, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config, tokenizer_2_config=tokenizer_2_config)
|
||||||
|
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
|
||||||
|
|
||||||
|
# Training mode
|
||||||
|
self.switch_pipe_to_training_mode(
|
||||||
|
self.pipe, trainable_models,
|
||||||
|
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
|
||||||
|
preset_lora_path, preset_lora_model,
|
||||||
|
task=task,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Other configs
|
||||||
|
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||||
|
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||||
|
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
||||||
|
self.fp8_models = fp8_models
|
||||||
|
self.task = task
|
||||||
|
self.task_to_loss = {
|
||||||
|
"sft:data_process": lambda pipe, *args: args,
|
||||||
|
"direct_distill:data_process": lambda pipe, *args: args,
|
||||||
|
"sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
|
||||||
|
"sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
|
||||||
|
"direct_distill": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
|
||||||
|
"direct_distill:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_pipeline_inputs(self, data):
|
||||||
|
inputs_posi = {"prompt": data["prompt"]}
|
||||||
|
inputs_nega = {"negative_prompt": ""}
|
||||||
|
inputs_shared = {
|
||||||
|
# Assume you are using this pipeline for inference,
|
||||||
|
# please fill in the input parameters.
|
||||||
|
"input_image": data["image"],
|
||||||
|
"height": data["image"].size[1],
|
||||||
|
"width": data["image"].size[0],
|
||||||
|
# Please do not modify the following parameters
|
||||||
|
# unless you clearly know what this will cause.
|
||||||
|
"cfg_scale": 1,
|
||||||
|
"rand_device": self.pipe.device,
|
||||||
|
"use_gradient_checkpointing": self.use_gradient_checkpointing,
|
||||||
|
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
|
||||||
|
}
|
||||||
|
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
|
||||||
|
return inputs_shared, inputs_posi, inputs_nega
|
||||||
|
|
||||||
|
def forward(self, data, inputs=None):
|
||||||
|
if inputs is None: inputs = self.get_pipeline_inputs(data)
|
||||||
|
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
|
||||||
|
for unit in self.pipe.units:
|
||||||
|
inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)
|
||||||
|
loss = self.task_to_loss[self.task](self.pipe, *inputs)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def parser():
|
||||||
|
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||||
|
parser = add_general_config(parser)
|
||||||
|
parser = add_image_size_config(parser)
|
||||||
|
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
|
||||||
|
parser.add_argument("--tokenizer_2_path", type=str, default=None, help="Path to tokenizer 2.")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
accelerator = accelerate.Accelerator(
|
||||||
|
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||||
|
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
|
||||||
|
)
|
||||||
|
dataset = UnifiedDataset(
|
||||||
|
base_path=args.dataset_base_path,
|
||||||
|
metadata_path=args.dataset_metadata_path,
|
||||||
|
repeat=args.dataset_repeat,
|
||||||
|
data_file_keys=args.data_file_keys.split(","),
|
||||||
|
main_data_operator=UnifiedDataset.default_image_operator(
|
||||||
|
base_path=args.dataset_base_path,
|
||||||
|
max_pixels=args.max_pixels,
|
||||||
|
height=args.height,
|
||||||
|
width=args.width,
|
||||||
|
height_division_factor=32,
|
||||||
|
width_division_factor=32,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
model = StableDiffusionXLTrainingModule(
|
||||||
|
model_paths=args.model_paths,
|
||||||
|
model_id_with_origin_paths=args.model_id_with_origin_paths,
|
||||||
|
tokenizer_path=args.tokenizer_path,
|
||||||
|
trainable_models=args.trainable_models,
|
||||||
|
lora_base_model=args.lora_base_model,
|
||||||
|
lora_target_modules=args.lora_target_modules,
|
||||||
|
lora_rank=args.lora_rank,
|
||||||
|
lora_checkpoint=args.lora_checkpoint,
|
||||||
|
preset_lora_path=args.preset_lora_path,
|
||||||
|
preset_lora_model=args.preset_lora_model,
|
||||||
|
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||||
|
extra_inputs=args.extra_inputs,
|
||||||
|
fp8_models=args.fp8_models,
|
||||||
|
offload_models=args.offload_models,
|
||||||
|
task=args.task,
|
||||||
|
device=accelerator.device,
|
||||||
|
)
|
||||||
|
model_logger = ModelLogger(
|
||||||
|
args.output_path,
|
||||||
|
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
|
||||||
|
)
|
||||||
|
launcher_map = {
|
||||||
|
"sft:data_process": launch_data_process_task,
|
||||||
|
"direct_distill:data_process": launch_data_process_task,
|
||||||
|
"sft": launch_training_task,
|
||||||
|
"sft:train": launch_training_task,
|
||||||
|
"direct_distill": launch_training_task,
|
||||||
|
"direct_distill:train": launch_training_task,
|
||||||
|
}
|
||||||
|
launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline, ModelConfig
|
||||||
|
from diffsynth.core import load_state_dict
|
||||||
|
import torch
|
||||||
|
|
||||||
|
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors"),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors"),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors"),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
|
||||||
|
tokenizer_2_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
|
||||||
|
)
|
||||||
|
state_dict = load_state_dict("./models/train/stable-diffusion-xl-base-1.0_full/epoch-1.safetensors", torch_dtype=torch.float32)
|
||||||
|
pipe.unet.load_state_dict(state_dict)
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt="a dog",
|
||||||
|
negative_prompt="",
|
||||||
|
cfg_scale=7.0,
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
seed=42,
|
||||||
|
num_inference_steps=50,
|
||||||
|
)
|
||||||
|
image.save("image_stable-diffusion-xl-base-1.0_full.jpg")
|
||||||
@@ -0,0 +1,27 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth.core import ModelConfig
|
||||||
|
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
|
||||||
|
|
||||||
|
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors"),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors"),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors"),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
|
||||||
|
tokenizer_2_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
|
||||||
|
)
|
||||||
|
pipe.load_lora(pipe.unet, "models/train/stable-diffusion-xl-base-1.0_lora/epoch-4.safetensors")
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt="a dog",
|
||||||
|
negative_prompt="",
|
||||||
|
cfg_scale=7.0,
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
seed=42,
|
||||||
|
num_inference_steps=50,
|
||||||
|
)
|
||||||
|
image.save("image_stable-diffusion-xl-base-1.0.jpg")
|
||||||
Reference in New Issue
Block a user