Compare commits


8 Commits
main ... sd

|Author|SHA1|Message|Date|
|-|-|-|-|
|mi804|d5934719f8|style fix|2026-04-24 15:55:34 +08:00|
|mi804|54345f8678|sd docs|2026-04-24 15:41:13 +08:00|
|mi804|2d7d5137ea|add full training|2026-04-24 15:11:34 +08:00|
|Artiprocher|3799bdc23a|update sd training scripts|2026-04-24 14:30:09 +08:00|
|mi804|5cdab9ed01|sd and sdxl training|2026-04-24 10:19:58 +08:00|
|mi804|a8a0f082bb|sdxl pipeline|2026-04-23 19:39:05 +08:00|
|mi804|9453700a30|sdxl modelcode|2026-04-23 18:26:25 +08:00|
|mi804|82e482286c|sd|2026-04-23 17:35:24 +08:00|

47 changed files with 4933 additions and 459 deletions

README.md
View File

@@ -34,6 +34,8 @@ We believe that a well-developed open-source code framework can lower the thresh
> This project currently has a limited number of developers, with most of the work handled by [Artiprocher](https://github.com/Artiprocher) and [mi804](https://github.com/mi804). As a result, new feature development will progress relatively slowly, and our capacity to respond to and resolve issues is limited. We apologize for this and ask for developers' understanding.
- **April 24, 2026**: Added support for Stable Diffusion v1.5 and SDXL, including inference, low VRAM inference, and training capabilities. For details, please refer to the [Stable Diffusion documentation](/docs/en/Model_Details/Stable-Diffusion.md), the [Stable Diffusion XL documentation](/docs/en/Model_Details/Stable-Diffusion-XL.md), and the [example code](/examples/stable_diffusion/).
- **April 14, 2026**: JoyAI-Image has been open-sourced; welcome a new member to the image editing model family! Support includes instruction-guided image editing, low VRAM inference, and training capabilities. For details, please refer to the [documentation](/docs/en/Model_Details/JoyAI-Image.md) and [example code](/examples/joyai_image/).
- **March 19, 2026**: Added support for [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) and [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/Wan.md) and [example code](/examples/mova/) are now available.
@@ -299,6 +301,129 @@ Example code for Z-Image is available at: [/examples/z_image/](/examples/z_image
</details>
#### Stable Diffusion: [/docs/en/Model_Details/Stable-Diffusion.md](/docs/en/Model_Details/Stable-Diffusion.md)
<details>
<summary>Quick Start</summary>
Running the following code will quickly load the [AI-ModelScope/stable-diffusion-v1-5](https://www.modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5) model for inference. VRAM management is enabled: the framework automatically controls parameter loading based on available VRAM, so a minimum of 2GB of VRAM is sufficient.
```python
import torch
from diffsynth.core import ModelConfig
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline

vram_config = {
    "offload_dtype": torch.float32,
    "offload_device": "cpu",
    "onload_dtype": torch.float32,
    "onload_device": "cpu",
    "preparing_dtype": torch.float32,
    "preparing_device": "cuda",
    "computation_dtype": torch.float32,
    "computation_device": "cuda",
}
pipe = StableDiffusionPipeline.from_pretrained(
    torch_dtype=torch.float32,
    model_configs=[
        ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
        ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
        ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
    prompt="a photo of an astronaut riding a horse on mars, high quality, detailed",
    negative_prompt="blurry, low quality, deformed",
    cfg_scale=7.5,
    height=512,
    width=512,
    seed=42,
    rand_device="cuda",
    num_inference_steps=50,
)
image.save("image.jpg")
```
</details>
<details>
<summary>Examples</summary>
Example code for Stable Diffusion is available at: [/examples/stable_diffusion/](/examples/stable_diffusion/)
|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|-|-|-|-|-|-|-|
|[AI-ModelScope/stable-diffusion-v1-5](https://www.modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5)|[code](/examples/stable_diffusion/model_inference/stable-diffusion-v1-5.py)|[code](/examples/stable_diffusion/model_inference_low_vram/stable-diffusion-v1-5.py)|[code](/examples/stable_diffusion/model_training/full/stable-diffusion-v1-5.sh)|[code](/examples/stable_diffusion/model_training/validate_full/stable-diffusion-v1-5.py)|[code](/examples/stable_diffusion/model_training/lora/stable-diffusion-v1-5.sh)|[code](/examples/stable_diffusion/model_training/validate_lora/stable-diffusion-v1-5.py)|
</details>
#### Stable Diffusion XL: [/docs/en/Model_Details/Stable-Diffusion-XL.md](/docs/en/Model_Details/Stable-Diffusion-XL.md)
<details>
<summary>Quick Start</summary>
Running the following code will quickly load the [stabilityai/stable-diffusion-xl-base-1.0](https://www.modelscope.cn/models/stabilityai/stable-diffusion-xl-base-1.0) model for inference. VRAM management is enabled: the framework automatically controls parameter loading based on available VRAM, so a minimum of 6GB of VRAM is sufficient.
```python
import torch
from diffsynth.core import ModelConfig
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline

vram_config = {
    "offload_dtype": torch.float32,
    "offload_device": "cpu",
    "onload_dtype": torch.float32,
    "onload_device": "cpu",
    "preparing_dtype": torch.float32,
    "preparing_device": "cuda",
    "computation_dtype": torch.float32,
    "computation_device": "cuda",
}
pipe = StableDiffusionXLPipeline.from_pretrained(
    torch_dtype=torch.float32,
    model_configs=[
        ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
        ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors", **vram_config),
        ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
        ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
    tokenizer_2_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
    prompt="a photo of an astronaut riding a horse on mars",
    negative_prompt="",
    cfg_scale=5.0,
    height=1024,
    width=1024,
    seed=42,
    num_inference_steps=50,
)
image.save("image.jpg")
```
</details>
<details>
<summary>Examples</summary>
Example code for Stable Diffusion XL is available at: [/examples/stable_diffusion_xl/](/examples/stable_diffusion_xl/)
|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|-|-|-|-|-|-|-|
|[stabilityai/stable-diffusion-xl-base-1.0](https://www.modelscope.cn/models/stabilityai/stable-diffusion-xl-base-1.0)|[code](/examples/stable_diffusion_xl/model_inference/stable-diffusion-xl-base-1.0.py)|[code](/examples/stable_diffusion_xl/model_inference_low_vram/stable-diffusion-xl-base-1.0.py)|[code](/examples/stable_diffusion_xl/model_training/full/stable-diffusion-xl-base-1.0.sh)|[code](/examples/stable_diffusion_xl/model_training/validate_full/stable-diffusion-xl-base-1.0.py)|[code](/examples/stable_diffusion_xl/model_training/lora/stable-diffusion-xl-base-1.0.sh)|[code](/examples/stable_diffusion_xl/model_training/validate_lora/stable-diffusion-xl-base-1.0.py)|
</details>
#### FLUX.2: [/docs/en/Model_Details/FLUX2.md](/docs/en/Model_Details/FLUX2.md)
<details>

View File

@@ -34,6 +34,8 @@ DiffSynth currently consists of two open-source projects:
> This project currently has a limited number of developers, with most of the work handled by [Artiprocher](https://github.com/Artiprocher) and [mi804](https://github.com/mi804). As a result, new feature development will progress relatively slowly, and our capacity to respond to and resolve issues is limited. We sincerely apologize for this and ask for developers' understanding.
- **April 24, 2026**: Added support for Stable Diffusion v1.5 and SDXL, including inference, low VRAM inference, and training capabilities. For details, please refer to the [documentation](/docs/zh/Model_Details/Stable-Diffusion.md) and [example code](/examples/stable_diffusion/).
- **April 14, 2026**: JoyAI-Image has been open-sourced; welcome a new member to the image editing model family! Support includes instruction-guided image editing inference, low VRAM inference, and training capabilities. For details, please refer to the [documentation](/docs/zh/Model_Details/JoyAI-Image.md) and [example code](/examples/joyai_image/).
- **March 19, 2026**: Added support for the [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) and [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) models, including full training and inference capabilities. [Documentation](/docs/zh/Model_Details/Wan.md) and [example code](/examples/mova/) are now available.
@@ -299,6 +301,129 @@ Example code for Z-Image is available at: [/examples/z_image/](/examples/z_image/)
</details>
#### Stable Diffusion: [/docs/zh/Model_Details/Stable-Diffusion.md](/docs/zh/Model_Details/Stable-Diffusion.md)
<details>
<summary>Quick Start</summary>
Running the following code will quickly load the [AI-ModelScope/stable-diffusion-v1-5](https://www.modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5) model for inference. VRAM management is enabled: the framework automatically controls parameter loading based on the remaining VRAM, so a minimum of 2GB of VRAM is sufficient.
```python
import torch
from diffsynth.core import ModelConfig
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline

vram_config = {
    "offload_dtype": torch.float32,
    "offload_device": "cpu",
    "onload_dtype": torch.float32,
    "onload_device": "cpu",
    "preparing_dtype": torch.float32,
    "preparing_device": "cuda",
    "computation_dtype": torch.float32,
    "computation_device": "cuda",
}
pipe = StableDiffusionPipeline.from_pretrained(
    torch_dtype=torch.float32,
    model_configs=[
        ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
        ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
        ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
    prompt="a photo of an astronaut riding a horse on mars, high quality, detailed",
    negative_prompt="blurry, low quality, deformed",
    cfg_scale=7.5,
    height=512,
    width=512,
    seed=42,
    rand_device="cuda",
    num_inference_steps=50,
)
image.save("image.jpg")
```
</details>
<details>
<summary>Examples</summary>
Example code for Stable Diffusion is available at: [/examples/stable_diffusion/](/examples/stable_diffusion/)
|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|-|-|-|-|-|-|-|
|[AI-ModelScope/stable-diffusion-v1-5](https://www.modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5)|[code](/examples/stable_diffusion/model_inference/stable-diffusion-v1-5.py)|[code](/examples/stable_diffusion/model_inference_low_vram/stable-diffusion-v1-5.py)|[code](/examples/stable_diffusion/model_training/full/stable-diffusion-v1-5.sh)|[code](/examples/stable_diffusion/model_training/validate_full/stable-diffusion-v1-5.py)|[code](/examples/stable_diffusion/model_training/lora/stable-diffusion-v1-5.sh)|[code](/examples/stable_diffusion/model_training/validate_lora/stable-diffusion-v1-5.py)|
</details>
#### Stable Diffusion XL: [/docs/zh/Model_Details/Stable-Diffusion-XL.md](/docs/zh/Model_Details/Stable-Diffusion-XL.md)
<details>
<summary>Quick Start</summary>
Running the following code will quickly load the [stabilityai/stable-diffusion-xl-base-1.0](https://www.modelscope.cn/models/stabilityai/stable-diffusion-xl-base-1.0) model for inference. VRAM management is enabled: the framework automatically controls parameter loading based on the remaining VRAM, so a minimum of 6GB of VRAM is sufficient.
```python
import torch
from diffsynth.core import ModelConfig
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline

vram_config = {
    "offload_dtype": torch.float32,
    "offload_device": "cpu",
    "onload_dtype": torch.float32,
    "onload_device": "cpu",
    "preparing_dtype": torch.float32,
    "preparing_device": "cuda",
    "computation_dtype": torch.float32,
    "computation_device": "cuda",
}
pipe = StableDiffusionXLPipeline.from_pretrained(
    torch_dtype=torch.float32,
    model_configs=[
        ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
        ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors", **vram_config),
        ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
        ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
    tokenizer_2_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
    prompt="a photo of an astronaut riding a horse on mars",
    negative_prompt="",
    cfg_scale=5.0,
    height=1024,
    width=1024,
    seed=42,
    num_inference_steps=50,
)
image.save("image.jpg")
```
</details>
<details>
<summary>Examples</summary>
Example code for Stable Diffusion XL is available at: [/examples/stable_diffusion_xl/](/examples/stable_diffusion_xl/)
|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|-|-|-|-|-|-|-|
|[stabilityai/stable-diffusion-xl-base-1.0](https://www.modelscope.cn/models/stabilityai/stable-diffusion-xl-base-1.0)|[code](/examples/stable_diffusion_xl/model_inference/stable-diffusion-xl-base-1.0.py)|[code](/examples/stable_diffusion_xl/model_inference_low_vram/stable-diffusion-xl-base-1.0.py)|[code](/examples/stable_diffusion_xl/model_training/full/stable-diffusion-xl-base-1.0.sh)|[code](/examples/stable_diffusion_xl/model_training/validate_full/stable-diffusion-xl-base-1.0.py)|[code](/examples/stable_diffusion_xl/model_training/lora/stable-diffusion-xl-base-1.0.sh)|[code](/examples/stable_diffusion_xl/model_training/validate_lora/stable-diffusion-xl-base-1.0.py)|
</details>
#### FLUX.2: [/docs/zh/Model_Details/FLUX2.md](/docs/zh/Model_Details/FLUX2.md)
<details>

View File

@@ -42,7 +42,6 @@ qwen_image_series = [
"model_hash": "5722b5c873720009de96422993b15682",
"model_name": "dinov3_image_encoder",
"model_class": "diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.dino_v3.DINOv3StateDictConverter",
},
{
# Example:
@@ -901,6 +900,61 @@ mova_series = [
"model_class": "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge",
},
]
stable_diffusion_xl_series = [
{
# Example: ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors")
"model_hash": "142b114f67f5ab3a6d83fb5788f12ded",
"model_name": "stable_diffusion_xl_unet",
"model_class": "diffsynth.models.stable_diffusion_xl_unet.SDXLUNet2DConditionModel",
"extra_kwargs": {"attention_head_dim": [5, 10, 20], "transformer_layers_per_block": [1, 2, 10], "use_linear_projection": True, "addition_embed_type": "text_time", "addition_time_embed_dim": 256, "projection_class_embeddings_input_dim": 2816},
},
{
# Example: ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors")
"model_hash": "98cc34ccc5b54ae0e56bdea8688dcd5a",
"model_name": "stable_diffusion_xl_text_encoder",
"model_class": "diffsynth.models.stable_diffusion_xl_text_encoder.SDXLTextEncoder2",
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_xl_text_encoder.SDXLTextEncoder2StateDictConverter",
},
{
# Example: ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors")
"model_hash": "94eefa3dac9cec93cb1ebaf1747d7b78",
"model_name": "stable_diffusion_text_encoder",
"model_class": "diffsynth.models.stable_diffusion_text_encoder.SDTextEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_text_encoder.SDTextEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
"model_hash": "13115dd45a6e1c39860f91ab073b8a78",
"model_name": "stable_diffusion_xl_vae",
"model_class": "diffsynth.models.stable_diffusion_vae.StableDiffusionVAE",
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_vae.SDVAEStateDictConverter",
"extra_kwargs": {"scaling_factor": 0.13025, "sample_size": 1024, "force_upcast": True},
},
]
stable_diffusion_series = [
{
# Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors")
"model_hash": "ffd1737ae9df7fd43f5fbed653bdad67",
"model_name": "stable_diffusion_text_encoder",
"model_class": "diffsynth.models.stable_diffusion_text_encoder.SDTextEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_text_encoder.SDTextEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
"model_hash": "f86d5683ed32433be8ca69969c67ba69",
"model_name": "stable_diffusion_vae",
"model_class": "diffsynth.models.stable_diffusion_vae.StableDiffusionVAE",
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_vae.SDVAEStateDictConverter",
},
{
# Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors")
"model_hash": "025a4b86a84829399d89f613e580757b",
"model_name": "stable_diffusion_unet",
"model_class": "diffsynth.models.stable_diffusion_unet.UNet2DConditionModel",
},
]
joyai_image_series = [
{
# Example: ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth")
@@ -917,4 +971,4 @@ joyai_image_series = [
},
]
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series + z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series
MODEL_CONFIGS = stable_diffusion_xl_series + stable_diffusion_series + qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series + z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series

View File

@@ -295,6 +295,45 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.stable_diffusion_unet.UNet2DConditionModel": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.stable_diffusion_vae.StableDiffusionVAE": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.stable_diffusion_vae.Upsample2D": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.stable_diffusion_vae.Downsample2D": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.stable_diffusion_text_encoder.SDTextEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.clip.modeling_clip.CLIPTextTransformer": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.clip.modeling_clip.CLIPEncoderLayer": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.clip.modeling_clip.CLIPAttention": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.stable_diffusion_xl_unet.SDXLUNet2DConditionModel": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.stable_diffusion_xl_text_encoder.SDXLTextEncoder2": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.clip.modeling_clip.CLIPTextTransformer": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.clip.modeling_clip.CLIPEncoderLayer": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.clip.modeling_clip.CLIPAttention": "diffsynth.core.vram.layers.AutoWrappedModule",
},
}
def QwenImageTextEncoder_Module_Map_Updater():

View File

@@ -0,0 +1,107 @@
import torch, math
class DDIMScheduler():
    def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", prediction_type="epsilon", rescale_zero_terminal_snr=False):
        self.num_train_timesteps = num_train_timesteps
        if beta_schedule == "scaled_linear":
            betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32))
        elif beta_schedule == "linear":
            betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented")
        self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
        if rescale_zero_terminal_snr:
            self.alphas_cumprod = self.rescale_zero_terminal_snr(self.alphas_cumprod)
        self.alphas_cumprod = self.alphas_cumprod.tolist()
        self.set_timesteps(10)
        self.prediction_type = prediction_type
        self.training = False

    def rescale_zero_terminal_snr(self, alphas_cumprod):
        alphas_bar_sqrt = alphas_cumprod.sqrt()
        # Store old values.
        alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
        alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
        # Shift so the last timestep is zero.
        alphas_bar_sqrt -= alphas_bar_sqrt_T
        # Scale so the first timestep is back to the old value.
        alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
        # Convert alphas_bar_sqrt to betas
        alphas_bar = alphas_bar_sqrt.square()  # Revert sqrt
        return alphas_bar

    def set_timesteps(self, num_inference_steps, denoising_strength=1.0, training=False, **kwargs):
        # The timesteps are aligned to 999...0, which is different from other implementations,
        # but I think this implementation is more reasonable in theory.
        max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0)
        num_inference_steps = min(num_inference_steps, max_timestep + 1)
        if num_inference_steps == 1:
            self.timesteps = torch.Tensor([max_timestep])
        else:
            step_length = max_timestep / (num_inference_steps - 1)
            self.timesteps = torch.Tensor([round(max_timestep - i*step_length) for i in range(num_inference_steps)])
        self.training = training

    def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev):
        if self.prediction_type == "epsilon":
            weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t)
            weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t)
            prev_sample = sample * weight_x + model_output * weight_e
        elif self.prediction_type == "v_prediction":
            weight_e = -math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t)) + math.sqrt(alpha_prod_t * (1 - alpha_prod_t_prev))
            weight_x = math.sqrt(alpha_prod_t * alpha_prod_t_prev) + math.sqrt((1 - alpha_prod_t) * (1 - alpha_prod_t_prev))
            prev_sample = sample * weight_x + model_output * weight_e
        else:
            raise NotImplementedError(f"{self.prediction_type} is not implemented")
        return prev_sample

    def step(self, model_output, timestep, sample, to_final=False):
        alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        if to_final or timestep_id + 1 >= len(self.timesteps):
            alpha_prod_t_prev = 1.0
        else:
            timestep_prev = int(self.timesteps[timestep_id + 1])
            alpha_prod_t_prev = self.alphas_cumprod[timestep_prev]
        return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev)

    def return_to_timestep(self, timestep, sample, sample_stablized):
        alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
        noise_pred = (sample - math.sqrt(alpha_prod_t) * sample_stablized) / math.sqrt(1 - alpha_prod_t)
        return noise_pred

    def add_noise(self, original_samples, noise, timestep):
        sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
        sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def training_target(self, sample, noise, timestep):
        if self.prediction_type == "epsilon":
            return noise
        else:
            sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
            sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
            target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
            return target

    def training_weight(self, timestep):
        return 1.0
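
For orientation, here is a minimal sampling-loop sketch driving the scheduler defined above. The latent shape and the zero-valued noise prediction are placeholders standing in for the UNet's output; only the `DDIMScheduler` API shown in this file is assumed.

```python
import torch

# Sketch only: zeros stand in for the UNet's predicted noise.
scheduler = DDIMScheduler(prediction_type="epsilon")
scheduler.set_timesteps(num_inference_steps=50)

latents = torch.randn(1, 4, 64, 64)  # placeholder SD v1.5-sized latent
for timestep in scheduler.timesteps:
    noise_pred = torch.zeros_like(latents)  # would come from the UNet in the real pipeline
    latents = scheduler.step(noise_pred, timestep, latents)

# During training, add_noise() builds the noisy input and training_target() the regression target.
noisy = scheduler.add_noise(latents, torch.randn_like(latents), torch.tensor([500]))
```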

View File

@@ -1,5 +1,5 @@
from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTModel, DINOv3ViTConfig
from transformers import DINOv3ViTImageProcessor
from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast
from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
import torch
from ..core.device.npu_compatible_device import get_device_type
@@ -40,7 +40,7 @@ class DINOv3ImageEncoder(DINOv3ViTModel):
value_bias = False
)
super().__init__(config)
self.processor = DINOv3ViTImageProcessor(
self.processor = DINOv3ViTImageProcessorFast(
crop_size = None,
data_format = "channels_first",
default_to_square = True,
@@ -56,7 +56,7 @@ class DINOv3ImageEncoder(DINOv3ViTModel):
0.456,
0.406
],
image_processor_type = "DINOv3ViTImageProcessor",
image_processor_type = "DINOv3ViTImageProcessorFast",
image_std = [
0.229,
0.224,
@@ -82,7 +82,7 @@ class DINOv3ImageEncoder(DINOv3ViTModel):
hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
position_embeddings = self.rope_embeddings(pixel_values)
for i, layer_module in enumerate(self.model.layer):
for i, layer_module in enumerate(self.layer):
layer_head_mask = head_mask[i] if head_mask is not None else None
hidden_states = layer_module(
hidden_states,

View File

@@ -1,11 +1,11 @@
from transformers.models.siglip.modeling_siglip import SiglipVisionModel, SiglipVisionConfig
from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessor
from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer, SiglipVisionConfig
from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast
import torch
from diffsynth.core.device.npu_compatible_device import get_device_type
class Siglip2ImageEncoder(SiglipVisionModel):
class Siglip2ImageEncoder(SiglipVisionTransformer):
def __init__(self):
config = SiglipVisionConfig(
attention_dropout = 0.0,
@@ -90,7 +90,7 @@ class Siglip2ImageEncoder428M(Siglip2VisionModel):
transformers_version = "4.57.1"
)
super().__init__(config)
self.processor = Siglip2ImageProcessor(
self.processor = Siglip2ImageProcessorFast(
**{
"data_format": "channels_first",
"default_to_square": True,
@@ -106,7 +106,7 @@ class Siglip2ImageEncoder428M(Siglip2VisionModel):
0.5,
0.5
],
"image_processor_type": "Siglip2ImageProcessor",
"image_processor_type": "Siglip2ImageProcessorFast",
"image_std": [
0.5,
0.5,

View File

@@ -0,0 +1,78 @@
import torch
class SDTextEncoder(torch.nn.Module):
    def __init__(
        self,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        max_position_embeddings=77,
        vocab_size=49408,
        layer_norm_eps=1e-05,
        hidden_act="quick_gelu",
        initializer_factor=1.0,
        initializer_range=0.02,
        bos_token_id=0,
        eos_token_id=2,
        pad_token_id=1,
        projection_dim=768,
    ):
        super().__init__()
        from transformers import CLIPConfig, CLIPTextModel
        config = CLIPConfig(
            text_config={
                "hidden_size": hidden_size,
                "intermediate_size": intermediate_size,
                "num_hidden_layers": num_hidden_layers,
                "num_attention_heads": num_attention_heads,
                "max_position_embeddings": max_position_embeddings,
                "vocab_size": vocab_size,
                "layer_norm_eps": layer_norm_eps,
                "hidden_act": hidden_act,
                "initializer_factor": initializer_factor,
                "initializer_range": initializer_range,
                "bos_token_id": bos_token_id,
                "eos_token_id": eos_token_id,
                "pad_token_id": pad_token_id,
                "projection_dim": projection_dim,
                "dropout": 0.0,
            },
            vision_config={
                "hidden_size": hidden_size,
                "intermediate_size": intermediate_size,
                "num_hidden_layers": num_hidden_layers,
                "num_attention_heads": num_attention_heads,
                "max_position_embeddings": max_position_embeddings,
                "layer_norm_eps": layer_norm_eps,
                "hidden_act": hidden_act,
                "initializer_factor": initializer_factor,
                "initializer_range": initializer_range,
                "projection_dim": projection_dim,
            },
            projection_dim=projection_dim,
        )
        self.model = CLIPTextModel(config.text_config)
        self.config = config

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        output_hidden_states=True,
        **kwargs,
    ):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            **kwargs,
        )
        if output_hidden_states:
            return outputs.last_hidden_state, outputs.hidden_states
        return outputs.last_hidden_state
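
A rough usage sketch for this encoder (not part of the file): a freshly constructed instance has random weights, which is enough to check shapes. The tokenizer checkpoint named below is an assumption for illustration only; in the pipeline the tokenizer is loaded through the `tokenizer/` ModelConfig shown in the README.

```python
import torch
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")  # hypothetical source
text_encoder = SDTextEncoder()  # random weights; real weights come from ModelConfig

tokens = tokenizer(
    "a photo of an astronaut riding a horse on mars",
    padding="max_length", max_length=77, truncation=True, return_tensors="pt",
)
with torch.no_grad():
    last_hidden_state, hidden_states = text_encoder(input_ids=tokens.input_ids)
print(last_hidden_state.shape)  # torch.Size([1, 77, 768]) with the default config
```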

View File

@@ -0,0 +1,912 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional
# ===== Time Embedding =====
class Timesteps(nn.Module):
def __init__(self, num_channels, flip_sin_to_cos=True, freq_shift=0):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.freq_shift = freq_shift
def forward(self, timesteps):
half_dim = self.num_channels // 2
exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
exponent = exponent / half_dim + self.freq_shift
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
sin_emb = torch.sin(emb)
cos_emb = torch.cos(emb)
if self.flip_sin_to_cos:
emb = torch.cat([cos_emb, sin_emb], dim=-1)
else:
emb = torch.cat([sin_emb, cos_emb], dim=-1)
return emb
class TimestepEmbedding(nn.Module):
def __init__(self, in_channels, time_embed_dim, act_fn="silu", out_dim=None):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
self.act = nn.SiLU() if act_fn == "silu" else nn.GELU()
out_dim = out_dim if out_dim is not None else time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, out_dim)
def forward(self, sample):
sample = self.linear_1(sample)
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
# ===== ResNet Blocks =====
class ResnetBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
groups_out=None,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
output_scale_factor=1.0,
use_in_shortcut=None,
):
super().__init__()
self.pre_norm = pre_norm
self.time_embedding_norm = time_embedding_norm
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps)
self.conv1 = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
if temb_channels is not None:
if self.time_embedding_norm == "default":
self.time_emb_proj = nn.Linear(temb_channels, out_channels or in_channels)
elif self.time_embedding_norm == "scale_shift":
self.time_emb_proj = nn.Linear(temb_channels, 2 * (out_channels or in_channels))
self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels or in_channels, eps=eps)
self.dropout = nn.Dropout(dropout)
self.conv2 = nn.Conv2d(out_channels or in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = nn.SiLU()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
elif non_linearity == "gelu":
self.nonlinearity = nn.GELU()
elif non_linearity == "relu":
self.nonlinearity = nn.ReLU()
self.use_conv_shortcut = conv_shortcut
self.conv_shortcut = None
if conv_shortcut:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0)
else:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0) if in_channels != (out_channels or in_channels) else None
def forward(self, input_tensor, temb=None):
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
if temb is not None:
temb = self.nonlinearity(temb)
temb = self.time_emb_proj(temb).unsqueeze(-1).unsqueeze(-1)
if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb
hidden_states = self.norm2(hidden_states)
if temb is not None and self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
return output_tensor
# ===== Transformer Blocks =====
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, hidden_states):
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
return hidden_states * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, dropout=0.0):
super().__init__()
self.net = nn.ModuleList([
GEGLU(dim, dim * 4),
nn.Dropout(dropout),
nn.Linear(dim * 4, dim if dim_out is None else dim_out),
])
def forward(self, hidden_states):
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
class Attention(nn.Module):
"""Attention block matching diffusers checkpoint key format.
Keys: to_q.weight, to_k.weight, to_v.weight, to_out.0.weight, to_out.0.bias
"""
def __init__(
self,
query_dim,
heads=8,
dim_head=64,
dropout=0.0,
bias=False,
upcast_attention=False,
cross_attention_dim=None,
):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.inner_dim = inner_dim
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
self.to_k = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias)
self.to_v = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias)
self.to_out = nn.ModuleList([
nn.Linear(inner_dim, query_dim, bias=True),
nn.Dropout(dropout),
])
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
# Query
query = self.to_q(hidden_states)
batch_size, seq_len, _ = query.shape
# Key/Value
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
# Reshape for multi-head attention
head_dim = self.inner_dim // self.heads
query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
# Scaled dot-product attention
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
)
# Reshape back
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.inner_dim)
hidden_states = hidden_states.to(query.dtype)
# Output projection
hidden_states = self.to_out[0](hidden_states)
hidden_states = self.to_out[1](hidden_states)
return hidden_states
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
cross_attention_dim=None,
upcast_attention=False,
):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn1 = Attention(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
bias=False,
upcast_attention=upcast_attention,
)
self.norm2 = nn.LayerNorm(dim)
self.attn2 = Attention(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
bias=False,
upcast_attention=upcast_attention,
cross_attention_dim=cross_attention_dim,
)
self.norm3 = nn.LayerNorm(dim)
self.ff = FeedForward(dim, dropout=dropout)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
# Self-attention
attn_output = self.attn1(self.norm1(hidden_states))
hidden_states = attn_output + hidden_states
# Cross-attention
attn_output = self.attn2(self.norm2(hidden_states), encoder_hidden_states=encoder_hidden_states)
hidden_states = attn_output + hidden_states
# Feed-forward
ff_output = self.ff(self.norm3(hidden_states))
hidden_states = ff_output + hidden_states
return hidden_states
class Transformer2DModel(nn.Module):
"""2D Transformer block wrapper matching diffusers checkpoint structure.
Keys: norm.weight/bias, proj_in.weight/bias, transformer_blocks.X.*, proj_out.weight/bias
"""
def __init__(
self,
num_attention_heads=16,
attention_head_dim=64,
in_channels=320,
num_layers=1,
dropout=0.0,
norm_num_groups=32,
cross_attention_dim=768,
upcast_attention=False,
):
super().__init__()
self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6)
self.proj_in = nn.Conv2d(in_channels, num_attention_heads * attention_head_dim, kernel_size=1, bias=True)
self.transformer_blocks = nn.ModuleList([
BasicTransformerBlock(
dim=num_attention_heads * attention_head_dim,
n_heads=num_attention_heads,
d_head=attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
)
for _ in range(num_layers)
])
self.proj_out = nn.Conv2d(num_attention_heads * attention_head_dim, in_channels, kernel_size=1, bias=True)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch, channel, height, width = hidden_states.shape
residual = hidden_states
# Normalize and project to sequence
hidden_states = self.norm(hidden_states)
hidden_states = self.proj_in(hidden_states)
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, -1, channel)
# Transformer blocks
for block in self.transformer_blocks:
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
# Project back to 2D
hidden_states = hidden_states.reshape(batch, height, width, channel).permute(0, 3, 1, 2).contiguous()
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states + residual
return hidden_states
# ===== Down/Up Blocks =====
class CrossAttnDownBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
temb_channels=1280,
dropout=0.0,
num_layers=1,
transformer_layers_per_block=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
cross_attention_dim=768,
attention_head_dim=1,
downsample=True,
):
super().__init__()
self.has_cross_attention = True
resnets = []
attentions = []
for i in range(num_layers):
in_channels_i = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels_i,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
attentions.append(
Transformer2DModel(
num_attention_heads=attention_head_dim,
attention_head_dim=out_channels // attention_head_dim,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
dropout=dropout,
norm_num_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if downsample:
self.downsamplers = nn.ModuleList([
Downsample2D(out_channels, out_channels, padding=1)
])
else:
self.downsamplers = None
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = []
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
output_states.append(hidden_states)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states.append(hidden_states)
return hidden_states, tuple(output_states)
class DownBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
temb_channels=1280,
dropout=0.0,
num_layers=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
downsample=True,
):
super().__init__()
self.has_cross_attention = False
resnets = []
for i in range(num_layers):
in_channels_i = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels_i,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if downsample:
self.downsamplers = nn.ModuleList([
Downsample2D(out_channels, out_channels, padding=1)
])
else:
self.downsamplers = None
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = []
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb)
output_states.append(hidden_states)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states.append(hidden_states)
return hidden_states, tuple(output_states)
class CrossAttnUpBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
prev_output_channel,
temb_channels=1280,
dropout=0.0,
num_layers=1,
transformer_layers_per_block=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
cross_attention_dim=768,
attention_head_dim=1,
upsample=True,
):
super().__init__()
self.has_cross_attention = True
resnets = []
attentions = []
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
attentions.append(
Transformer2DModel(
num_attention_heads=attention_head_dim,
attention_head_dim=out_channels // attention_head_dim,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
dropout=dropout,
norm_num_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if upsample:
self.upsamplers = nn.ModuleList([
Upsample2D(out_channels, out_channels)
])
else:
self.upsamplers = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None):
for resnet, attn in zip(self.resnets, self.attentions):
# Pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size=upsample_size)
return hidden_states
class UpBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
prev_output_channel,
temb_channels=1280,
dropout=0.0,
num_layers=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
upsample=True,
):
super().__init__()
self.has_cross_attention = False
resnets = []
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if upsample:
self.upsamplers = nn.ModuleList([
Upsample2D(out_channels, out_channels)
])
else:
self.upsamplers = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None):
for resnet in self.resnets:
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size=upsample_size)
return hidden_states
# ===== UNet Mid Block =====
class UNetMidBlock2DCrossAttn(nn.Module):
def __init__(
self,
in_channels,
temb_channels=1280,
dropout=0.0,
num_layers=1,
transformer_layers_per_block=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
cross_attention_dim=768,
attention_head_dim=1,
):
super().__init__()
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
# There is always at least one resnet
resnets = [
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
]
attentions = []
for _ in range(num_layers):
attentions.append(
Transformer2DModel(
num_attention_heads=attention_head_dim,
attention_head_dim=in_channels // attention_head_dim,
in_channels=in_channels,
num_layers=transformer_layers_per_block,
dropout=dropout,
norm_num_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
)
)
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
hidden_states = resnet(hidden_states, temb)
return hidden_states
# ===== Downsample / Upsample =====
class Downsample2D(nn.Module):
def __init__(self, in_channels, out_channels, padding=1):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=padding)
self.padding = padding
def forward(self, hidden_states):
if self.padding == 0:
hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
return self.conv(hidden_states)
class Upsample2D(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
def forward(self, hidden_states, upsample_size=None):
if upsample_size is not None:
hidden_states = F.interpolate(hidden_states, size=upsample_size, mode="nearest")
else:
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
return self.conv(hidden_states)
# ===== UNet2DConditionModel =====
class UNet2DConditionModel(nn.Module):
"""Stable Diffusion UNet with cross-attention conditioning.
state_dict keys match the diffusers UNet2DConditionModel checkpoint format.
"""
def __init__(
self,
sample_size=64,
in_channels=4,
out_channels=4,
down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
block_out_channels=(320, 640, 1280, 1280),
layers_per_block=2,
cross_attention_dim=768,
attention_head_dim=8,
norm_num_groups=32,
norm_eps=1e-5,
dropout=0.0,
act_fn="silu",
time_embedding_type="positional",
flip_sin_to_cos=True,
freq_shift=0,
time_embedding_dim=None,
resnet_time_scale_shift="default",
upcast_attention=False,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.sample_size = sample_size
# Time embedding
timestep_embedding_dim = time_embedding_dim or block_out_channels[0]
self.time_proj = Timesteps(timestep_embedding_dim, flip_sin_to_cos=flip_sin_to_cos, freq_shift=freq_shift)
time_embed_dim = block_out_channels[0] * 4
self.time_embedding = TimestepEmbedding(timestep_embedding_dim, time_embed_dim)
# Input
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1)
# Down blocks
self.down_blocks = nn.ModuleList()
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
if "CrossAttn" in down_block_type:
down_block = CrossAttnDownBlock2D(
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=layers_per_block,
transformer_layers_per_block=1,
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim,
downsample=not is_final_block,
)
else:
down_block = DownBlock2D(
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=layers_per_block,
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
downsample=not is_final_block,
)
self.down_blocks.append(down_block)
# Mid block
self.mid_block = UNetMidBlock2DCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=1,
transformer_layers_per_block=1,
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim,
)
# Up blocks
self.up_blocks = nn.ModuleList()
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
# in_channels for up blocks: diffusers uses reversed_block_out_channels[min(i+1, len-1)]
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
if "CrossAttn" in up_block_type:
up_block = CrossAttnUpBlock2D(
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=layers_per_block + 1,
transformer_layers_per_block=1,
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim,
upsample=not is_final_block,
)
else:
up_block = UpBlock2D(
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=layers_per_block + 1,
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
upsample=not is_final_block,
)
self.up_blocks.append(up_block)
# Output
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
def forward(self, sample, timestep, encoder_hidden_states, cross_attention_kwargs=None, timestep_cond=None, added_cond_kwargs=None, return_dict=True):
# 1. Time embedding
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
t_emb = self.time_proj(timesteps)
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb)
# 2. Pre-process
sample = self.conv_in(sample)
# 3. Down
down_block_res_samples = (sample,)
for down_block in self.down_blocks:
sample, res_samples = down_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
)
down_block_res_samples += res_samples
# 4. Mid
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
# 5. Up
for up_block in self.up_blocks:
res_samples = down_block_res_samples[-len(up_block.resnets):]
down_block_res_samples = down_block_res_samples[:-len(up_block.resnets)]
upsample_size = down_block_res_samples[-1].shape[2:] if down_block_res_samples else None
sample = up_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
res_hidden_states_tuple=res_samples,
upsample_size=upsample_size,
)
# 6. Post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if not return_dict:
return (sample,)
return sample
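
A quick, randomly initialized shape check of the UNet defined above (a sketch only; real weights are loaded through ModelConfig). The 4x64x64 latent and 77x768 text embedding follow the SD v1.5 defaults used in this file.

```python
import torch

unet = UNet2DConditionModel()                 # random weights, shapes only
latents = torch.randn(1, 4, 64, 64)           # latent for a 512x512 image
timestep = torch.tensor([999])
text_emb = torch.randn(1, 77, 768)            # stand-in for CLIP text embeddings
with torch.no_grad():
    out = unet(latents, timestep, text_emb)
print(out.shape)                              # torch.Size([1, 4, 64, 64])
```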

View File

@@ -0,0 +1,642 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from typing import Optional
class DiagonalGaussianDistribution:
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
)
def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
# torch.randn_like does not take a generator argument, so draw with torch.randn at an explicit shape
sample = torch.randn(self.mean.shape, generator=generator,
device=self.parameters.device, dtype=self.parameters.dtype)
return self.mean + self.std * sample
def kl(self, other: Optional["DiagonalGaussianDistribution"] = None) -> torch.Tensor:
if self.deterministic:
return torch.tensor([0.0])
if other is None:
return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3],
)
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3],
)
def mode(self) -> torch.Tensor:
return self.mean
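# Note: `parameters` is the encoder output of width 2 * latent_channels, chunked along dim=1
# into mean and log-variance; sample() applies the reparameterization mean + std * eps, while
# mode() returns the mean for deterministic encoding.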
class ResnetBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
groups_out=None,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
output_scale_factor=1.0,
use_in_shortcut=None,
):
super().__init__()
self.pre_norm = pre_norm
self.time_embedding_norm = time_embedding_norm
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps)
self.conv1 = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
if temb_channels is not None:
if self.time_embedding_norm == "default":
self.time_emb_proj = nn.Linear(temb_channels, out_channels or in_channels)
elif self.time_embedding_norm == "scale_shift":
self.time_emb_proj = nn.Linear(temb_channels, 2 * (out_channels or in_channels))
self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels or in_channels, eps=eps)
self.dropout = nn.Dropout(dropout)
self.conv2 = nn.Conv2d(out_channels or in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = nn.SiLU()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
elif non_linearity == "gelu":
self.nonlinearity = nn.GELU()
elif non_linearity == "relu":
self.nonlinearity = nn.ReLU()
else:
raise ValueError(f"Unsupported non_linearity: {non_linearity}")
self.use_conv_shortcut = conv_shortcut
self.conv_shortcut = None
if conv_shortcut:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0)
else:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0) if in_channels != (out_channels or in_channels) else None
def forward(self, input_tensor, temb=None):
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
if temb is not None:
temb = self.nonlinearity(temb)
temb = self.time_emb_proj(temb).unsqueeze(-1).unsqueeze(-1)
if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb
hidden_states = self.norm2(hidden_states)
if temb is not None and self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
return output_tensor
class DownEncoderBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
dropout=0.0,
num_layers=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
output_scale_factor=1.0,
add_downsample=True,
downsample_padding=1,
):
super().__init__()
resnets = []
for i in range(num_layers):
in_channels_i = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels_i,
out_channels=out_channels,
temb_channels=None,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if add_downsample:
self.downsamplers = nn.ModuleList([
Downsample2D(out_channels, out_channels, padding=downsample_padding)
])
else:
self.downsamplers = None
def forward(self, hidden_states, *args, **kwargs):
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=None)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
return hidden_states
class UpDecoderBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
dropout=0.0,
num_layers=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
output_scale_factor=1.0,
add_upsample=True,
temb_channels=None,
):
super().__init__()
resnets = []
for i in range(num_layers):
in_channels_i = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels_i,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if add_upsample:
self.upsamplers = nn.ModuleList([
Upsample2D(out_channels, out_channels)
])
else:
self.upsamplers = None
def forward(self, hidden_states, temb=None):
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=temb)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states
class UNetMidBlock2D(nn.Module):
def __init__(
self,
in_channels,
temb_channels=None,
dropout=0.0,
num_layers=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
add_attention=True,
attention_head_dim=1,
output_scale_factor=1.0,
):
super().__init__()
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
self.add_attention = add_attention
# there is always at least one resnet
resnets = [
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
]
attentions = []
if attention_head_dim is None:
attention_head_dim = in_channels
for _ in range(num_layers):
if self.add_attention:
attentions.append(
AttentionBlock(
in_channels,
num_groups=resnet_groups,
eps=resnet_eps,
)
)
else:
attentions.append(None)
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
hidden_states = attn(hidden_states)
hidden_states = resnet(hidden_states, temb)
return hidden_states
class AttentionBlock(nn.Module):
"""Simple attention block for VAE mid block.
Mirrors the diffusers Attention class with AttnProcessor2_0 for the VAE use case.
Uses the modern key names (to_q, to_k, to_v, to_out) matching the in-memory diffusers structure.
The checkpoint stores the deprecated keys (query, key, value, proj_attn), which are mapped via the converter.
"""
def __init__(self, channels, num_groups=32, eps=1e-6):
super().__init__()
self.channels = channels
self.eps = eps
self.heads = 1
self.rescale_output_factor = 1.0
self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=eps, affine=True)
self.to_q = nn.Linear(channels, channels, bias=True)
self.to_k = nn.Linear(channels, channels, bias=True)
self.to_v = nn.Linear(channels, channels, bias=True)
self.to_out = nn.ModuleList([
nn.Linear(channels, channels, bias=True),
nn.Dropout(0.0),
])
def forward(self, hidden_states):
residual = hidden_states
# Group norm
hidden_states = self.group_norm(hidden_states)
# Flatten spatial dims: (B, C, H, W) -> (B, H*W, C)
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
# QKV projection
query = self.to_q(hidden_states)
key = self.to_k(hidden_states)
value = self.to_v(hidden_states)
# Reshape for attention: (B, seq, dim) -> (B, heads, seq, head_dim)
inner_dim = key.shape[-1]
head_dim = inner_dim // self.heads
query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
# Scaled dot-product attention
hidden_states = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
)
# Reshape back: (B, heads, seq, head_dim) -> (B, seq, heads*head_dim)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# Output projection + dropout
hidden_states = self.to_out[0](hidden_states)
hidden_states = self.to_out[1](hidden_states)
# Reshape back to 4D and add residual
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
hidden_states = hidden_states + residual
# Rescale output factor
hidden_states = hidden_states / self.rescale_output_factor
return hidden_states
class Downsample2D(nn.Module):
"""Downsampling layer matching diffusers Downsample2D with use_conv=True.
Key names: conv.weight/bias.
When padding=0, applies explicit F.pad before conv to match dimension.
"""
def __init__(self, in_channels, out_channels, padding=1):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=padding)
self.padding = padding
def forward(self, hidden_states):
if self.padding == 0:
import torch.nn.functional as F
hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
return self.conv(hidden_states)
class Upsample2D(nn.Module):
"""Upsampling layer with key names matching diffusers checkpoint: conv.weight/bias."""
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
def forward(self, hidden_states):
hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
return self.conv(hidden_states)
class Encoder(nn.Module):
def __init__(
self,
in_channels=3,
out_channels=3,
down_block_types=("DownEncoderBlock2D",),
block_out_channels=(64,),
layers_per_block=2,
norm_num_groups=32,
act_fn="silu",
double_z=True,
mid_block_add_attention=True,
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
self.down_blocks = nn.ModuleList([])
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = DownEncoderBlock2D(
in_channels=input_channel,
out_channels=output_channel,
num_layers=self.layers_per_block,
resnet_eps=1e-6,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
add_downsample=not is_final_block,
downsample_padding=0,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default",
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=None,
add_attention=mid_block_add_attention,
)
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
conv_out_channels = 2 * out_channels if double_z else out_channels
self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
def forward(self, sample):
sample = self.conv_in(sample)
for down_block in self.down_blocks:
sample = down_block(sample)
sample = self.mid_block(sample)
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
return sample
class Decoder(nn.Module):
def __init__(
self,
in_channels=3,
out_channels=3,
up_block_types=("UpDecoderBlock2D",),
block_out_channels=(64,),
layers_per_block=2,
norm_num_groups=32,
act_fn="silu",
norm_type="group",
mid_block_add_attention=True,
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
self.up_blocks = nn.ModuleList([])
temb_channels = in_channels if norm_type == "spatial" else None
# mid
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=temb_channels,
add_attention=mid_block_add_attention,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
up_block = UpDecoderBlock2D(
in_channels=prev_output_channel,
out_channels=output_channel,
num_layers=self.layers_per_block + 1,
resnet_eps=1e-6,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
add_upsample=not is_final_block,
temb_channels=temb_channels,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
def forward(self, sample, latent_embeds=None):
sample = self.conv_in(sample)
sample = self.mid_block(sample, latent_embeds)
for up_block in self.up_blocks:
sample = up_block(sample, latent_embeds)
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
return sample
class StableDiffusionVAE(nn.Module):
def __init__(
self,
in_channels=3,
out_channels=3,
down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"),
up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"),
block_out_channels=(128, 256, 512, 512),
layers_per_block=2,
act_fn="silu",
latent_channels=4,
norm_num_groups=32,
sample_size=512,
scaling_factor=0.18215,
shift_factor=None,
latents_mean=None,
latents_std=None,
force_upcast=True,
use_quant_conv=True,
use_post_quant_conv=True,
mid_block_add_attention=True,
):
super().__init__()
self.encoder = Encoder(
in_channels=in_channels,
out_channels=latent_channels,
down_block_types=down_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
double_z=True,
mid_block_add_attention=mid_block_add_attention,
)
self.decoder = Decoder(
in_channels=latent_channels,
out_channels=out_channels,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
mid_block_add_attention=mid_block_add_attention,
)
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None
self.latents_mean = latents_mean
self.latents_std = latents_std
self.scaling_factor = scaling_factor
self.shift_factor = shift_factor
self.sample_size = sample_size
self.force_upcast = force_upcast
def _encode(self, x):
h = self.encoder(x)
if self.quant_conv is not None:
h = self.quant_conv(h)
return h
def encode(self, x):
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
return posterior
def _decode(self, z):
if self.post_quant_conv is not None:
z = self.post_quant_conv(z)
return self.decoder(z)
def decode(self, z):
return self._decode(z)
def forward(self, sample, sample_posterior=True, return_dict=True, generator=None):
posterior = self.encode(sample)
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
# Decode the raw latent; the scaled latent (z * scaling_factor) is what the diffusion model consumes
decode = self.decode(z)
latent_sample = z * self.scaling_factor
if return_dict:
return {"sample": decode, "posterior": posterior, "latent_sample": latent_sample}
return decode, posterior
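A minimal round-trip sketch for the autoencoder above, assuming randomly initialized weights and the SD v1.5 defaults defined in this file (in the pipeline, the weights come from the checkpoint loader):

```python
import torch

vae = StableDiffusionVAE().eval()                      # 4 latent channels, 8x downsampling
image = torch.randn(1, 3, 256, 256)                    # images are expected in [-1, 1]
with torch.no_grad():
    posterior = vae.encode(image)
    latents = posterior.sample() * vae.scaling_factor  # (1, 4, 32, 32), scaled for the UNet
    recon = vae.decode(latents / vae.scaling_factor)   # (1, 3, 256, 256)
print(latents.shape, recon.shape)
```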

View File

@@ -0,0 +1,62 @@
import torch
class SDXLTextEncoder2(torch.nn.Module):
def __init__(
self,
hidden_size=1280,
intermediate_size=5120,
num_hidden_layers=32,
num_attention_heads=20,
max_position_embeddings=77,
vocab_size=49408,
layer_norm_eps=1e-05,
hidden_act="gelu",
initializer_factor=1.0,
initializer_range=0.02,
bos_token_id=0,
eos_token_id=2,
pad_token_id=1,
projection_dim=1280,
):
super().__init__()
from transformers import CLIPTextConfig, CLIPTextModelWithProjection
config = CLIPTextConfig(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
max_position_embeddings=max_position_embeddings,
vocab_size=vocab_size,
layer_norm_eps=layer_norm_eps,
hidden_act=hidden_act,
initializer_factor=initializer_factor,
initializer_range=initializer_range,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
projection_dim=projection_dim,
)
self.model = CLIPTextModelWithProjection(config)
self.config = config
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
output_hidden_states=True,
**kwargs,
):
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_hidden_states=output_hidden_states,
return_dict=True,
**kwargs,
)
if output_hidden_states:
return outputs.text_embeds, outputs.hidden_states
return outputs.text_embeds
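A quick sketch of the forward contract, assuming random weights and dummy token ids (in the pipeline, the ids come from the second CLIP tokenizer, and the penultimate hidden state is the usual SDXL cross-attention context):

```python
import torch

encoder = SDXLTextEncoder2().eval()
input_ids = torch.randint(0, 49408, (1, 77))    # dummy token ids for one 77-token prompt
with torch.no_grad():
    text_embeds, hidden_states = encoder(input_ids)
print(text_embeds.shape)        # torch.Size([1, 1280]), pooled and projected embedding
print(hidden_states[-2].shape)  # torch.Size([1, 77, 1280]), penultimate layer
```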

View File

@@ -0,0 +1,922 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional
# ===== Time Embedding =====
class Timesteps(nn.Module):
def __init__(self, num_channels, flip_sin_to_cos=True, freq_shift=0):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.freq_shift = freq_shift
def forward(self, timesteps):
half_dim = self.num_channels // 2
exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
exponent = exponent / (half_dim - self.freq_shift)  # freq_shift downscales the frequency denominator (0 for SD/SDXL)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
sin_emb = torch.sin(emb)
cos_emb = torch.cos(emb)
if self.flip_sin_to_cos:
emb = torch.cat([cos_emb, sin_emb], dim=-1)
else:
emb = torch.cat([sin_emb, cos_emb], dim=-1)
return emb
class TimestepEmbedding(nn.Module):
def __init__(self, in_channels, time_embed_dim, act_fn="silu", out_dim=None):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
self.act = nn.SiLU() if act_fn == "silu" else nn.GELU()
out_dim = out_dim if out_dim is not None else time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, out_dim)
def forward(self, sample):
sample = self.linear_1(sample)
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
# ===== ResNet Blocks =====
class ResnetBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
groups_out=None,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
output_scale_factor=1.0,
use_in_shortcut=None,
):
super().__init__()
self.pre_norm = pre_norm
self.time_embedding_norm = time_embedding_norm
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps)
self.conv1 = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
if temb_channels is not None:
if self.time_embedding_norm == "default":
self.time_emb_proj = nn.Linear(temb_channels, out_channels or in_channels)
elif self.time_embedding_norm == "scale_shift":
self.time_emb_proj = nn.Linear(temb_channels, 2 * (out_channels or in_channels))
self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels or in_channels, eps=eps)
self.dropout = nn.Dropout(dropout)
self.conv2 = nn.Conv2d(out_channels or in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = nn.SiLU()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
elif non_linearity == "gelu":
self.nonlinearity = nn.GELU()
elif non_linearity == "relu":
self.nonlinearity = nn.ReLU()
else:
raise ValueError(f"Unsupported non_linearity: {non_linearity}")
self.use_conv_shortcut = conv_shortcut
self.conv_shortcut = None
if conv_shortcut:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0)
else:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0) if in_channels != (out_channels or in_channels) else None
def forward(self, input_tensor, temb=None):
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
if temb is not None:
temb = self.nonlinearity(temb)
temb = self.time_emb_proj(temb).unsqueeze(-1).unsqueeze(-1)
if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb
hidden_states = self.norm2(hidden_states)
if temb is not None and self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
return output_tensor
# ===== Transformer Blocks =====
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, hidden_states):
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
return hidden_states * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, dropout=0.0):
super().__init__()
self.net = nn.ModuleList([
GEGLU(dim, dim * 4),
nn.Dropout(dropout),
nn.Linear(dim * 4, dim if dim_out is None else dim_out),
])
def forward(self, hidden_states):
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
class Attention(nn.Module):
def __init__(
self,
query_dim,
heads=8,
dim_head=64,
dropout=0.0,
bias=False,
upcast_attention=False,
cross_attention_dim=None,
):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.inner_dim = inner_dim
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
self.to_k = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias)
self.to_v = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias)
self.to_out = nn.ModuleList([
nn.Linear(inner_dim, query_dim, bias=True),
nn.Dropout(dropout),
])
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
query = self.to_q(hidden_states)
batch_size, seq_len, _ = query.shape
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
head_dim = self.inner_dim // self.heads
query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.inner_dim)
hidden_states = hidden_states.to(query.dtype)
hidden_states = self.to_out[0](hidden_states)
hidden_states = self.to_out[1](hidden_states)
return hidden_states
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
cross_attention_dim=None,
upcast_attention=False,
):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn1 = Attention(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
bias=False,
upcast_attention=upcast_attention,
)
self.norm2 = nn.LayerNorm(dim)
self.attn2 = Attention(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
bias=False,
upcast_attention=upcast_attention,
cross_attention_dim=cross_attention_dim,
)
self.norm3 = nn.LayerNorm(dim)
self.ff = FeedForward(dim, dropout=dropout)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
attn_output = self.attn1(self.norm1(hidden_states))
hidden_states = attn_output + hidden_states
attn_output = self.attn2(self.norm2(hidden_states), encoder_hidden_states=encoder_hidden_states)
hidden_states = attn_output + hidden_states
ff_output = self.ff(self.norm3(hidden_states))
hidden_states = ff_output + hidden_states
return hidden_states
class Transformer2DModel(nn.Module):
def __init__(
self,
num_attention_heads=16,
attention_head_dim=64,
in_channels=320,
num_layers=1,
dropout=0.0,
norm_num_groups=32,
cross_attention_dim=768,
upcast_attention=False,
use_linear_projection=False,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
self.use_linear_projection = use_linear_projection
self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6)
if use_linear_projection:
self.proj_in = nn.Linear(in_channels, inner_dim, bias=True)
else:
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, bias=True)
self.transformer_blocks = nn.ModuleList([
BasicTransformerBlock(
dim=inner_dim,
n_heads=num_attention_heads,
d_head=attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
)
for _ in range(num_layers)
])
if use_linear_projection:
self.proj_out = nn.Linear(inner_dim, in_channels, bias=True)
else:
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, bias=True)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch, channel, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
if self.use_linear_projection:
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, -1, channel)
hidden_states = self.proj_in(hidden_states)
else:
hidden_states = self.proj_in(hidden_states)
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, -1, channel)
for block in self.transformer_blocks:
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
if self.use_linear_projection:
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, width, channel).permute(0, 3, 1, 2).contiguous()
else:
hidden_states = hidden_states.reshape(batch, height, width, channel).permute(0, 3, 1, 2).contiguous()
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states + residual
return hidden_states
# ===== Down/Up Blocks =====
class CrossAttnDownBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
temb_channels=1280,
dropout=0.0,
num_layers=1,
transformer_layers_per_block=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
cross_attention_dim=768,
attention_head_dim=1,
downsample=True,
use_linear_projection=False,
):
super().__init__()
self.has_cross_attention = True
resnets = []
attentions = []
for i in range(num_layers):
in_channels_i = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels_i,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
attentions.append(
Transformer2DModel(
num_attention_heads=attention_head_dim,
attention_head_dim=out_channels // attention_head_dim,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
dropout=dropout,
norm_num_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
use_linear_projection=use_linear_projection,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if downsample:
self.downsamplers = nn.ModuleList([
Downsample2D(out_channels, out_channels, padding=1)
])
else:
self.downsamplers = None
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = []
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
output_states.append(hidden_states)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states.append(hidden_states)
return hidden_states, tuple(output_states)
class DownBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
temb_channels=1280,
dropout=0.0,
num_layers=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
downsample=True,
):
super().__init__()
self.has_cross_attention = False
resnets = []
for i in range(num_layers):
in_channels_i = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels_i,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if downsample:
self.downsamplers = nn.ModuleList([
Downsample2D(out_channels, out_channels, padding=1)
])
else:
self.downsamplers = None
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = []
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb)
output_states.append(hidden_states)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states.append(hidden_states)
return hidden_states, tuple(output_states)
class CrossAttnUpBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
prev_output_channel,
temb_channels=1280,
dropout=0.0,
num_layers=1,
transformer_layers_per_block=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
cross_attention_dim=768,
attention_head_dim=1,
upsample=True,
use_linear_projection=False,
):
super().__init__()
self.has_cross_attention = True
resnets = []
attentions = []
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
attentions.append(
Transformer2DModel(
num_attention_heads=attention_head_dim,
attention_head_dim=out_channels // attention_head_dim,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
dropout=dropout,
norm_num_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
use_linear_projection=use_linear_projection,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if upsample:
self.upsamplers = nn.ModuleList([
Upsample2D(out_channels, out_channels)
])
else:
self.upsamplers = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None):
for resnet, attn in zip(self.resnets, self.attentions):
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size=upsample_size)
return hidden_states
class UpBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
prev_output_channel,
temb_channels=1280,
dropout=0.0,
num_layers=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
upsample=True,
):
super().__init__()
self.has_cross_attention = False
resnets = []
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if upsample:
self.upsamplers = nn.ModuleList([
Upsample2D(out_channels, out_channels)
])
else:
self.upsamplers = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None):
for resnet in self.resnets:
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size=upsample_size)
return hidden_states
# ===== UNet Mid Block =====
class UNetMidBlock2DCrossAttn(nn.Module):
def __init__(
self,
in_channels,
temb_channels=1280,
dropout=0.0,
num_layers=1,
transformer_layers_per_block=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
cross_attention_dim=768,
attention_head_dim=1,
use_linear_projection=False,
):
super().__init__()
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
resnets = [
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
]
attentions = []
for _ in range(num_layers):
attentions.append(
Transformer2DModel(
num_attention_heads=attention_head_dim,
attention_head_dim=in_channels // attention_head_dim,
in_channels=in_channels,
num_layers=transformer_layers_per_block,
dropout=dropout,
norm_num_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
use_linear_projection=use_linear_projection,
)
)
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
hidden_states = resnet(hidden_states, temb)
return hidden_states
# ===== Downsample / Upsample =====
class Downsample2D(nn.Module):
def __init__(self, in_channels, out_channels, padding=1):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=padding)
self.padding = padding
def forward(self, hidden_states):
if self.padding == 0:
hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
return self.conv(hidden_states)
class Upsample2D(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
def forward(self, hidden_states, upsample_size=None):
if upsample_size is not None:
hidden_states = F.interpolate(hidden_states, size=upsample_size, mode="nearest")
else:
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
return self.conv(hidden_states)
# ===== SDXL UNet2DConditionModel =====
class SDXLUNet2DConditionModel(nn.Module):
def __init__(
self,
sample_size=128,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"),
block_out_channels=(320, 640, 1280),
layers_per_block=2,
cross_attention_dim=2048,
attention_head_dim=5,
transformer_layers_per_block=1,
norm_num_groups=32,
norm_eps=1e-5,
dropout=0.0,
act_fn="silu",
time_embedding_type="positional",
flip_sin_to_cos=True,
freq_shift=0,
time_embedding_dim=None,
resnet_time_scale_shift="default",
upcast_attention=False,
use_linear_projection=False,
addition_embed_type=None,
addition_time_embed_dim=None,
projection_class_embeddings_input_dim=None,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.sample_size = sample_size
self.addition_embed_type = addition_embed_type
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
timestep_embedding_dim = time_embedding_dim or block_out_channels[0]
self.time_proj = Timesteps(timestep_embedding_dim, flip_sin_to_cos=flip_sin_to_cos, freq_shift=freq_shift)
time_embed_dim = block_out_channels[0] * 4
self.time_embedding = TimestepEmbedding(timestep_embedding_dim, time_embed_dim)
if addition_embed_type == "text_time":
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=flip_sin_to_cos, freq_shift=freq_shift)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1)
self.down_blocks = nn.ModuleList()
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
if "CrossAttn" in down_block_type:
down_block = CrossAttnDownBlock2D(
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=layers_per_block,
transformer_layers_per_block=transformer_layers_per_block[i],
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim[i],
downsample=not is_final_block,
use_linear_projection=use_linear_projection,
)
else:
down_block = DownBlock2D(
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=layers_per_block,
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
downsample=not is_final_block,
)
self.down_blocks.append(down_block)
self.mid_block = UNetMidBlock2DCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=1,
transformer_layers_per_block=transformer_layers_per_block[-1],
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim[-1],
use_linear_projection=use_linear_projection,
)
self.up_blocks = nn.ModuleList()
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
if "CrossAttn" in up_block_type:
up_block = CrossAttnUpBlock2D(
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=layers_per_block + 1,
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attention_head_dim=reversed_attention_head_dim[i],
upsample=not is_final_block,
use_linear_projection=use_linear_projection,
)
else:
up_block = UpBlock2D(
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=layers_per_block + 1,
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
upsample=not is_final_block,
)
self.up_blocks.append(up_block)
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
def forward(self, sample, timestep, encoder_hidden_states, cross_attention_kwargs=None, timestep_cond=None, added_cond_kwargs=None, return_dict=True):
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
t_emb = self.time_proj(timesteps)
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb)
if self.addition_embed_type == "text_time":
text_embeds = added_cond_kwargs.get("text_embeds")
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)
emb = emb + aug_emb
sample = self.conv_in(sample)
down_block_res_samples = (sample,)
for down_block in self.down_blocks:
sample, res_samples = down_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
)
down_block_res_samples += res_samples
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
for up_block in self.up_blocks:
res_samples = down_block_res_samples[-len(up_block.resnets):]
down_block_res_samples = down_block_res_samples[:-len(up_block.resnets)]
upsample_size = down_block_res_samples[-1].shape[2:] if down_block_res_samples else None
sample = up_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
res_hidden_states_tuple=res_samples,
upsample_size=upsample_size,
)
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if not return_dict:
return (sample,)
return sample
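A scaled-down smoke test of the UNet above. The channel widths are shrunk to keep it cheap (an assumption for the sketch, not the SDXL configuration), while the "text_time" conditioning keeps the real SDXL dimensions: a 1280-dim pooled text embedding concatenated with six size/crop ids embedded at 256 dims each, i.e. a 2816-dim input to `add_embedding`:

```python
import torch

unet = SDXLUNet2DConditionModel(
    block_out_channels=(64, 128, 256),           # shrunk from (320, 640, 1280) for the sketch
    attention_head_dim=4,
    addition_embed_type="text_time",
    addition_time_embed_dim=256,
    projection_class_embeddings_input_dim=2816,  # 1280 + 6 * 256
).eval()
latents = torch.randn(1, 4, 64, 64)
context = torch.randn(1, 77, 2048)               # concatenated hidden states of both text encoders
added = {
    "text_embeds": torch.randn(1, 1280),
    "time_ids": torch.tensor([[1024, 1024, 0, 0, 1024, 1024]]),  # original size, crop, target size
}
with torch.no_grad():
    out = unet(latents, torch.tensor([10]), context, added_cond_kwargs=added)
print(out.shape)  # torch.Size([1, 4, 64, 64])
```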

View File

@@ -74,7 +74,7 @@ class AnimaImagePipeline(BasePipeline):
def __call__(
self,
# Prompt
-prompt: str = "",
+prompt: str,
negative_prompt: str = "",
cfg_scale: float = 4.0,
# Image

View File

@@ -75,7 +75,7 @@ class Flux2ImagePipeline(BasePipeline):
def __call__(
self,
# Prompt
-prompt: str = "",
+prompt: str,
negative_prompt: str = "",
cfg_scale: float = 1.0,
embedded_guidance: float = 4.0,
@@ -83,7 +83,7 @@ class Flux2ImagePipeline(BasePipeline):
input_image: Image.Image = None,
denoising_strength: float = 1.0,
# Edit
-edit_image: List[Image.Image] = None,
+edit_image: Union[Image.Image, List[Image.Image]] = None,
edit_image_auto_resize: bool = True,
# Shape
height: int = 1024,

View File

@@ -181,7 +181,7 @@ class FluxImagePipeline(BasePipeline):
def __call__(
self,
# Prompt
-prompt: str = "",
+prompt: str,
negative_prompt: str = "",
cfg_scale: float = 1.0,
embedded_guidance: float = 3.5,
@@ -199,6 +199,10 @@ class FluxImagePipeline(BasePipeline):
sigma_shift: float = None,
# Steps
num_inference_steps: int = 30,
+# local prompts
+multidiffusion_prompts=(),
+multidiffusion_masks=(),
+multidiffusion_scales=(),
# Kontext
kontext_images: Union[list[Image.Image], Image.Image] = None,
# ControlNet
@@ -253,6 +257,7 @@
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"sigma_shift": sigma_shift, "num_inference_steps": num_inference_steps,
+"multidiffusion_prompts": multidiffusion_prompts, "multidiffusion_masks": multidiffusion_masks, "multidiffusion_scales": multidiffusion_scales,
"kontext_images": kontext_images,
"controlnet_inputs": controlnet_inputs,
"ipadapter_images": ipadapter_images, "ipadapter_scale": ipadapter_scale,

View File

@@ -169,46 +169,46 @@ class LTX2AudioVideoPipeline(BasePipeline):
def __call__(
self,
# Prompt
-prompt: str = "",
-negative_prompt: str = "",
+prompt: str,
+negative_prompt: Optional[str] = "",
denoising_strength: float = 1.0,
# Image-to-video
-input_images: list[Image.Image] = None,
-input_images_indexes: list[int] = [0],
-input_images_strength: float = 1.0,
+input_images: Optional[list[Image.Image]] = None,
+input_images_indexes: Optional[list[int]] = [0],
+input_images_strength: Optional[float] = 1.0,
# In-Context Video Control
-in_context_videos: list[list[Image.Image]] = None,
-in_context_downsample_factor: int = 2,
+in_context_videos: Optional[list[list[Image.Image]]] = None,
+in_context_downsample_factor: Optional[int] = 2,
# Video-to-video
-retake_video: list[Image.Image] = None,
-retake_video_regions: list[tuple[float, float]] = None,
+retake_video: Optional[list[Image.Image]] = None,
+retake_video_regions: Optional[list[tuple[float, float]]] = None,
# Audio-to-video
-retake_audio: torch.Tensor = None,
-audio_sample_rate: int = 48000,
-retake_audio_regions: list[tuple[float, float]] = None,
+retake_audio: Optional[torch.Tensor] = None,
+audio_sample_rate: Optional[int] = 48000,
+retake_audio_regions: Optional[list[tuple[float, float]]] = None,
# Randomness
-seed: int = None,
-rand_device: str = "cpu",
+seed: Optional[int] = None,
+rand_device: Optional[str] = "cpu",
# Shape
-height: int = 512,
-width: int = 768,
-num_frames: int = 121,
-frame_rate: int = 24,
+height: Optional[int] = 512,
+width: Optional[int] = 768,
+num_frames: Optional[int] = 121,
+frame_rate: Optional[int] = 24,
# Classifier-free guidance
-cfg_scale: float = 3.0,
+cfg_scale: Optional[float] = 3.0,
# Scheduler
-num_inference_steps: int = 30,
+num_inference_steps: Optional[int] = 30,
# VAE tiling
-tiled: bool = True,
-tile_size_in_pixels: int = 512,
-tile_overlap_in_pixels: int = 128,
-tile_size_in_frames: int = 128,
-tile_overlap_in_frames: int = 24,
+tiled: Optional[bool] = True,
+tile_size_in_pixels: Optional[int] = 512,
+tile_overlap_in_pixels: Optional[int] = 128,
+tile_size_in_frames: Optional[int] = 128,
+tile_overlap_in_frames: Optional[int] = 24,
# Special Pipelines
-use_two_stage_pipeline: bool = False,
-stage2_spatial_upsample_factor: int = 2,
-clear_lora_before_state_two: bool = False,
-use_distilled_pipeline: bool = False,
+use_two_stage_pipeline: Optional[bool] = False,
+stage2_spatial_upsample_factor: Optional[int] = 2,
+clear_lora_before_state_two: Optional[bool] = False,
+use_distilled_pipeline: Optional[bool] = False,
# progress_bar
progress_bar_cmd=tqdm,
):

View File

@@ -115,33 +115,33 @@ class MovaAudioVideoPipeline(BasePipeline):
def __call__(
self,
# Prompt
-prompt: str = "",
-negative_prompt: str = "",
+prompt: str,
+negative_prompt: Optional[str] = "",
# Image-to-video
-input_image: Image.Image = None,
+input_image: Optional[Image.Image] = None,
# First-last-frame-to-video
-end_image: Image.Image = None,
+end_image: Optional[Image.Image] = None,
# Video-to-video
-denoising_strength: float = 1.0,
+denoising_strength: Optional[float] = 1.0,
# Randomness
-seed: int = None,
-rand_device: str = "cpu",
+seed: Optional[int] = None,
+rand_device: Optional[str] = "cpu",
# Shape
-height: int = 352,
-width: int = 640,
-num_frames: int = 81,
-frame_rate: int = 24,
+height: Optional[int] = 352,
+width: Optional[int] = 640,
+num_frames: Optional[int] = 81,
+frame_rate: Optional[int] = 24,
# Classifier-free guidance
-cfg_scale: float = 5.0,
+cfg_scale: Optional[float] = 5.0,
# Boundary
-switch_DiT_boundary: float = 0.9,
+switch_DiT_boundary: Optional[float] = 0.9,
# Scheduler
-num_inference_steps: int = 50,
-sigma_shift: float = 5.0,
+num_inference_steps: Optional[int] = 50,
+sigma_shift: Optional[float] = 5.0,
# VAE tiling
-tiled: bool = True,
-tile_size: tuple[int, int] = (30, 52),
-tile_stride: tuple[int, int] = (15, 26),
+tiled: Optional[bool] = True,
+tile_size: Optional[tuple[int, int]] = (30, 52),
+tile_stride: Optional[tuple[int, int]] = (15, 26),
# progress_bar
progress_bar_cmd=tqdm,
):

View File

@@ -100,7 +100,7 @@ class QwenImagePipeline(BasePipeline):
def __call__(
self,
# Prompt
-prompt: str = "",
+prompt: str,
negative_prompt: str = "",
cfg_scale: float = 4.0,
# Image

View File

@@ -0,0 +1,230 @@
import torch
from PIL import Image
from tqdm import tqdm
from typing import Union
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion.ddim_scheduler import DDIMScheduler
from ..core import ModelConfig
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
from transformers import AutoTokenizer, CLIPTextModel
from ..models.stable_diffusion_text_encoder import SDTextEncoder
from ..models.stable_diffusion_unet import UNet2DConditionModel
from ..models.stable_diffusion_vae import StableDiffusionVAE
class StableDiffusionPipeline(BasePipeline):
def __init__(self, device=get_device_type(), torch_dtype=torch.float16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=8, width_division_factor=8,
)
self.scheduler = DDIMScheduler()
self.text_encoder: SDTextEncoder = None
self.unet: UNet2DConditionModel = None
self.vae: StableDiffusionVAE = None
self.tokenizer: AutoTokenizer = None
self.in_iteration_models = ("unet",)
self.units = [
SDUnit_ShapeChecker(),
SDUnit_PromptEmbedder(),
SDUnit_NoiseInitializer(),
SDUnit_InputImageEmbedder(),
]
self.model_fn = model_fn_stable_diffusion
self.compilable_models = ["unet"]
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.float16,
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = None,
vram_limit: float = None,
):
pipe = StableDiffusionPipeline(device=device, torch_dtype=torch_dtype)
# Override vram_config to use the specified torch_dtype for all models
for mc in model_configs:
mc._vram_config_override = {
'onload_dtype': torch_dtype,
'computation_dtype': torch_dtype,
}
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
pipe.text_encoder = model_pool.fetch_model("stable_diffusion_text_encoder")
pipe.unet = model_pool.fetch_model("stable_diffusion_unet")
pipe.vae = model_pool.fetch_model("stable_diffusion_vae")
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@torch.no_grad()
def __call__(
self,
prompt: str,
negative_prompt: str = "",
cfg_scale: float = 7.5,
height: int = 512,
width: int = 512,
seed: int = None,
rand_device: str = "cpu",
num_inference_steps: int = 50,
eta: float = 0.0,
guidance_rescale: float = 0.0,
progress_bar_cmd=tqdm,
):
# 1. Scheduler
self.scheduler.set_timesteps(
num_inference_steps, eta=eta,
)
# 2. Three-dict input preparation
inputs_posi = {"prompt": prompt}
inputs_nega = {"negative_prompt": negative_prompt}
inputs_shared = {
"cfg_scale": cfg_scale,
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"guidance_rescale": guidance_rescale,
}
# 3. Unit chain execution
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
unit, self, inputs_shared, inputs_posi, inputs_nega
)
# 4. Denoise loop
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["latents"] = self.step(
self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared
)
# 5. VAE decode
self.load_models_to_device(['vae'])
latents = inputs_shared["latents"] / self.vae.scaling_factor
image = self.vae.decode(latents)
image = self.vae_output_to_image(image)
self.load_models_to_device([])
return image
class SDUnit_ShapeChecker(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width"),
output_params=("height", "width"),
)
def process(self, pipe: StableDiffusionPipeline, height, width):
height, width = pipe.check_resize_height_width(height, width)
return {"height": height, "width": width}
class SDUnit_PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"},
output_params=("prompt_embeds",),
onload_model_names=("text_encoder",)
)
def encode_prompt(
self,
pipe: StableDiffusionPipeline,
prompt: str,
device: torch.device,
) -> torch.Tensor:
text_inputs = pipe.tokenizer(
prompt,
padding="max_length",
max_length=pipe.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(device)
prompt_embeds = pipe.text_encoder(text_input_ids)
# TextEncoder returns (last_hidden_state, hidden_states) or just last_hidden_state.
# last_hidden_state is the post-final-layer-norm output, matching diffusers encode_prompt.
if isinstance(prompt_embeds, tuple):
prompt_embeds = prompt_embeds[0]
return prompt_embeds
def process(self, pipe: StableDiffusionPipeline, prompt):
pipe.load_models_to_device(self.onload_model_names)
prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device)
return {"prompt_embeds": prompt_embeds}
class SDUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "seed", "rand_device"),
output_params=("noise",),
)
def process(self, pipe: StableDiffusionPipeline, height, width, seed, rand_device):
noise = pipe.generate_noise(
(1, pipe.unet.in_channels, height // 8, width // 8),
seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype
)
return {"noise": noise}
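# The latent is 8x smaller than the image on each side (height // 8, width // 8), matching the
# three stride-2 downsamples in the VAE encoder; in_channels is 4 for Stable Diffusion v1.5.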
class SDUnit_InputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "noise"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",),
)
def process(self, pipe: StableDiffusionPipeline, input_image, noise):
if input_image is None:
return {"latents": noise}
pipe.load_models_to_device(self.onload_model_names)
input_tensor = pipe.preprocess_image(input_image)
input_latents = pipe.vae.encode(input_tensor).sample() * pipe.vae.scaling_factor
latents = pipe.scheduler.add_noise(input_latents, noise, pipe.scheduler.timesteps[0])
if pipe.scheduler.training:
return {"latents": latents, "input_latents": input_latents}
else:
return {"latents": latents}
def model_fn_stable_diffusion(
unet: UNet2DConditionModel,
latents=None,
timestep=None,
prompt_embeds=None,
cross_attention_kwargs=None,
timestep_cond=None,
added_cond_kwargs=None,
**kwargs,
):
# SD timestep is already in 0-999 range, no scaling needed
noise_pred = unet(
latents,
timestep,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
timestep_cond=timestep_cond,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
return noise_pred

View File

@@ -0,0 +1,331 @@
import torch
from PIL import Image
from tqdm import tqdm
from typing import Union
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion.ddim_scheduler import DDIMScheduler
from ..core import ModelConfig
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
from transformers import AutoTokenizer, CLIPTextModel
from ..models.stable_diffusion_text_encoder import SDTextEncoder
from ..models.stable_diffusion_xl_unet import SDXLUNet2DConditionModel
from ..models.stable_diffusion_xl_text_encoder import SDXLTextEncoder2
from ..models.stable_diffusion_vae import StableDiffusionVAE
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""Rescale noise_cfg based on guidance_rescale to prevent overexposure.
Based on Section 3.4 from "Common Diffusion Noise Schedules and Sample Steps are Flawed"
https://huggingface.co/papers/2305.08891
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
return noise_cfg
class StableDiffusionXLPipeline(BasePipeline):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=8, width_division_factor=8,
)
self.scheduler = DDIMScheduler()
self.text_encoder: SDTextEncoder = None
self.text_encoder_2: SDXLTextEncoder2 = None
self.unet: SDXLUNet2DConditionModel = None
self.vae: StableDiffusionVAE = None
self.tokenizer: AutoTokenizer = None
self.tokenizer_2: AutoTokenizer = None
self.in_iteration_models = ("unet",)
self.units = [
SDXLUnit_ShapeChecker(),
SDXLUnit_PromptEmbedder(),
SDXLUnit_NoiseInitializer(),
SDXLUnit_InputImageEmbedder(),
SDXLUnit_AddTimeIdsComputer(),
]
self.model_fn = model_fn_stable_diffusion_xl
self.compilable_models = ["unet"]
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = None,
tokenizer_2_config: ModelConfig = None,
vram_limit: float = None,
):
pipe = StableDiffusionXLPipeline(device=device, torch_dtype=torch_dtype)
# Override vram_config to use the specified torch_dtype for all models
for mc in model_configs:
mc._vram_config_override = {
'onload_dtype': torch_dtype,
'computation_dtype': torch_dtype,
}
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
pipe.text_encoder = model_pool.fetch_model("stable_diffusion_text_encoder")
pipe.text_encoder_2 = model_pool.fetch_model("stable_diffusion_xl_text_encoder")
pipe.unet = model_pool.fetch_model("stable_diffusion_xl_unet")
pipe.vae = model_pool.fetch_model("stable_diffusion_xl_vae")
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
if tokenizer_2_config is not None:
tokenizer_2_config.download_if_necessary()
pipe.tokenizer_2 = AutoTokenizer.from_pretrained(tokenizer_2_config.path)
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@torch.no_grad()
def __call__(
self,
prompt: str,
negative_prompt: str = "",
cfg_scale: float = 5.0,
height: int = 1024,
width: int = 1024,
seed: int = None,
rand_device: str = "cpu",
num_inference_steps: int = 50,
guidance_rescale: float = 0.0,
progress_bar_cmd=tqdm,
):
# 1. Scheduler
self.scheduler.set_timesteps(num_inference_steps)
# 2. Three-dict input preparation
inputs_posi = {
"prompt": prompt,
}
inputs_nega = {
"prompt": negative_prompt,
}
inputs_shared = {
"cfg_scale": cfg_scale,
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"guidance_rescale": guidance_rescale,
"crops_coords_top_left": (0, 0),
}
# 3. Unit chain execution
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
unit, self, inputs_shared, inputs_posi, inputs_nega
)
# 4. Denoise loop
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
# Apply guidance_rescale
if guidance_rescale > 0.0:
# cfg_guided_model_fn already applied CFG, now apply rescale
# We need the text-only prediction for rescale
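# Note: this costs one extra UNet forward pass per denoising step when guidance_rescale > 0.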
noise_pred_text = self.model_fn(
self.unet,
inputs_shared["latents"],
timestep,
inputs_posi["prompt_embeds"],
pooled_prompt_embeds=inputs_posi["pooled_prompt_embeds"],
add_time_ids=inputs_posi["add_time_ids"],
)
noise_pred = rescale_noise_cfg(
noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
)
inputs_shared["latents"] = self.step(
self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared
)
# 5. VAE decode
self.load_models_to_device(['vae'])
latents = inputs_shared["latents"] / self.vae.scaling_factor
image = self.vae.decode(latents)
image = self.vae_output_to_image(image)
self.load_models_to_device([])
return image
class SDXLUnit_ShapeChecker(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width"),
output_params=("height", "width"),
)
def process(self, pipe: StableDiffusionXLPipeline, height, width):
height, width = pipe.check_resize_height_width(height, width)
return {"height": height, "width": width}
class SDXLUnit_PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "prompt"},
output_params=("prompt_embeds", "pooled_prompt_embeds"),
onload_model_names=("text_encoder", "text_encoder_2")
)
def encode_prompt(
self,
pipe: StableDiffusionXLPipeline,
prompt: str,
device: torch.device,
) -> tuple:
"""Encode prompt using both text encoders (same prompt for both).
Returns (prompt_embeds, pooled_prompt_embeds):
- prompt_embeds: concat(encoder1_output, encoder2_output) -> (B, 77, 2048)
- pooled_prompt_embeds: encoder2 pooled output -> (B, 1280)
"""
# Text Encoder 1 (CLIP-L, 768-dim)
text_input_ids_1 = pipe.tokenizer(
prompt,
padding="max_length",
max_length=pipe.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
).input_ids.to(device)
prompt_embeds_1 = pipe.text_encoder(text_input_ids_1)
if isinstance(prompt_embeds_1, tuple):
prompt_embeds_1 = prompt_embeds_1[0]
# Text Encoder 2 (CLIP-bigG, 1280-dim) — uses penultimate hidden states + pooled
text_input_ids_2 = pipe.tokenizer_2(
prompt,
padding="max_length",
max_length=pipe.tokenizer_2.model_max_length,
truncation=True,
return_tensors="pt",
).input_ids.to(device)
# SDXLTextEncoder2 forward returns (text_embeds/pooled, hidden_states_tuple)
pooled_prompt_embeds, hidden_states = pipe.text_encoder_2(text_input_ids_2, output_hidden_states=True)
# Use penultimate hidden state (same as diffusers: hidden_states[-2])
prompt_embeds_2 = hidden_states[-2]
# Concatenate both encoder outputs along feature dimension
prompt_embeds = torch.cat([prompt_embeds_1, prompt_embeds_2], dim=-1)
return prompt_embeds, pooled_prompt_embeds
def process(self, pipe: StableDiffusionXLPipeline, prompt):
pipe.load_models_to_device(self.onload_model_names)
prompt_embeds, pooled_prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device)
return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds}
class SDXLUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "seed", "rand_device"),
output_params=("noise",),
)
def process(self, pipe: StableDiffusionXLPipeline, height, width, seed, rand_device):
noise = pipe.generate_noise(
(1, pipe.unet.in_channels, height // 8, width // 8),
seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype
)
return {"noise": noise}
class SDXLUnit_InputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "noise"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",),
)
def process(self, pipe: StableDiffusionXLPipeline, input_image, noise):
if input_image is None:
return {"latents": noise}
pipe.load_models_to_device(self.onload_model_names)
input_tensor = pipe.preprocess_image(input_image)
input_latents = pipe.vae.encode(input_tensor).sample() * pipe.vae.scaling_factor
latents = pipe.scheduler.add_noise(input_latents, noise, pipe.scheduler.timesteps[0])
if pipe.scheduler.training:
return {"latents": latents, "input_latents": input_latents}
else:
return {"latents": latents}
class SDXLUnit_AddTimeIdsComputer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width"),
output_params=("add_time_ids",),
)
def _get_add_time_ids(self, pipe, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim):
add_time_ids = list(original_size + crops_coords_top_left + target_size)
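# e.g. (1024, 1024) + (0, 0) + (1024, 1024) -> [1024, 1024, 0, 0, 1024, 1024]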
expected_add_embed_dim = pipe.unet.add_embedding.linear_1.in_features
addition_time_embed_dim = pipe.unet.add_time_proj.num_channels
passed_add_embed_dim = addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
if expected_add_embed_dim != passed_add_embed_dim:
raise ValueError(
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, "
f"but a vector of {passed_add_embed_dim} was created."
)
add_time_ids = torch.tensor([add_time_ids], dtype=dtype, device=pipe.device)
return add_time_ids
def process(self, pipe: StableDiffusionXLPipeline, height, width):
original_size = (height, width)
target_size = (height, width)
crops_coords_top_left = (0, 0)
text_encoder_projection_dim = pipe.text_encoder_2.config.projection_dim
add_time_ids = self._get_add_time_ids(
pipe, original_size, crops_coords_top_left, target_size,
dtype=pipe.torch_dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
return {"add_time_ids": add_time_ids}
def model_fn_stable_diffusion_xl(
unet: SDXLUNet2DConditionModel,
latents=None,
timestep=None,
prompt_embeds=None,
pooled_prompt_embeds=None,
add_time_ids=None,
cross_attention_kwargs=None,
timestep_cond=None,
**kwargs,
):
"""SDXL model forward with added_cond_kwargs for micro-conditioning."""
added_cond_kwargs = {
"text_embeds": pooled_prompt_embeds,
"time_ids": add_time_ids,
}
noise_pred = unet(
latents,
timestep,
encoder_hidden_states=prompt_embeds,
added_cond_kwargs=added_cond_kwargs,
cross_attention_kwargs=cross_attention_kwargs,
timestep_cond=timestep_cond,
return_dict=False,
)[0]
return noise_pred

View File

@@ -190,82 +190,82 @@ class WanVideoPipeline(BasePipeline):
def __call__(
self,
# Prompt
prompt: str = "",
negative_prompt: str = "",
prompt: str,
negative_prompt: Optional[str] = "",
# Image-to-video
input_image: Image.Image = None,
input_image: Optional[Image.Image] = None,
# First-last-frame-to-video
end_image: Image.Image = None,
end_image: Optional[Image.Image] = None,
# Video-to-video
input_video: list[Image.Image] = None,
denoising_strength: float = 1.0,
input_video: Optional[list[Image.Image]] = None,
denoising_strength: Optional[float] = 1.0,
# Speech-to-video
input_audio: np.array = None,
audio_embeds: torch.Tensor = None,
audio_sample_rate: int = 16000,
s2v_pose_video: list[Image.Image] = None,
s2v_pose_latents: torch.Tensor = None,
motion_video: list[Image.Image] = None,
input_audio: Optional[np.array] = None,
audio_embeds: Optional[torch.Tensor] = None,
audio_sample_rate: Optional[int] = 16000,
s2v_pose_video: Optional[list[Image.Image]] = None,
s2v_pose_latents: Optional[torch.Tensor] = None,
motion_video: Optional[list[Image.Image]] = None,
# ControlNet
control_video: list[Image.Image] = None,
reference_image: Image.Image = None,
control_video: Optional[list[Image.Image]] = None,
reference_image: Optional[Image.Image] = None,
# Camera control
camera_control_direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"] = None,
camera_control_speed: float = 1/54,
camera_control_origin: tuple = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0),
camera_control_direction: Optional[Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"]] = None,
camera_control_speed: Optional[float] = 1/54,
camera_control_origin: Optional[tuple] = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0),
# VACE
vace_video: list[Image.Image] = None,
vace_video_mask: Image.Image = None,
vace_reference_image: Image.Image = None,
vace_scale: float = 1.0,
vace_video: Optional[list[Image.Image]] = None,
vace_video_mask: Optional[Image.Image] = None,
vace_reference_image: Optional[Image.Image] = None,
vace_scale: Optional[float] = 1.0,
# Animate
animate_pose_video: list[Image.Image] = None,
animate_face_video: list[Image.Image] = None,
animate_inpaint_video: list[Image.Image] = None,
animate_mask_video: list[Image.Image] = None,
animate_pose_video: Optional[list[Image.Image]] = None,
animate_face_video: Optional[list[Image.Image]] = None,
animate_inpaint_video: Optional[list[Image.Image]] = None,
animate_mask_video: Optional[list[Image.Image]] = None,
# VAP
vap_video: list[Image.Image] = None,
vap_prompt: str = " ",
negative_vap_prompt: str = " ",
vap_video: Optional[list[Image.Image]] = None,
vap_prompt: Optional[str] = " ",
negative_vap_prompt: Optional[str] = " ",
# Randomness
seed: int = None,
rand_device: str = "cpu",
seed: Optional[int] = None,
rand_device: Optional[str] = "cpu",
# Shape
height: int = 480,
width: int = 832,
num_frames: int = 81,
height: Optional[int] = 480,
width: Optional[int] = 832,
num_frames=81,
# Classifier-free guidance
cfg_scale: float = 5.0,
cfg_merge: bool = False,
cfg_scale: Optional[float] = 5.0,
cfg_merge: Optional[bool] = False,
# Boundary
switch_DiT_boundary: float = 0.875,
switch_DiT_boundary: Optional[float] = 0.875,
# Scheduler
num_inference_steps: int = 50,
sigma_shift: float = 5.0,
num_inference_steps: Optional[int] = 50,
sigma_shift: Optional[float] = 5.0,
# Speed control
motion_bucket_id: int = None,
motion_bucket_id: Optional[int] = None,
# LongCat-Video
longcat_video: list[Image.Image] = None,
longcat_video: Optional[list[Image.Image]] = None,
# VAE tiling
tiled: bool = True,
tile_size: tuple[int, int] = (30, 52),
tile_stride: tuple[int, int] = (15, 26),
tiled: Optional[bool] = True,
tile_size: Optional[tuple[int, int]] = (30, 52),
tile_stride: Optional[tuple[int, int]] = (15, 26),
# Sliding window
sliding_window_size: int = None,
sliding_window_stride: int = None,
sliding_window_size: Optional[int] = None,
sliding_window_stride: Optional[int] = None,
# Teacache
tea_cache_l1_thresh: float = None,
tea_cache_model_id: str = "",
tea_cache_l1_thresh: Optional[float] = None,
tea_cache_model_id: Optional[str] = "",
# WanToDance
wantodance_music_path: str = None,
wantodance_reference_image: Image.Image = None,
wantodance_fps: float = 30,
wantodance_keyframes: list[Image.Image] = None,
wantodance_keyframes_mask: list[int] = None,
wantodance_music_path: Optional[str] = None,
wantodance_reference_image: Optional[Image.Image] = None,
wantodance_fps: Optional[float] = 30,
wantodance_keyframes: Optional[list[Image.Image]] = None,
wantodance_keyframes_mask: Optional[list[int]] = None,
framewise_decoding: bool = False,
# progress_bar
progress_bar_cmd=tqdm,
output_type: Literal["quantized", "floatpoint"] = "quantized",
output_type: Optional[Literal["quantized", "floatpoint"]] = "quantized",
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)

View File

@@ -95,7 +95,7 @@ class ZImagePipeline(BasePipeline):
def __call__(
self,
# Prompt
prompt: str = "",
prompt: str,
negative_prompt: str = "",
cfg_scale: float = 1.0,
# Image
@@ -109,7 +109,7 @@ class ZImagePipeline(BasePipeline):
width: int = 1024,
# Randomness
seed: int = None,
rand_device: Union[str, torch.device] = "cpu",
rand_device: str = "cpu",
# Steps
num_inference_steps: int = 8,
sigma_shift: float = None,

View File

@@ -1,9 +0,0 @@
def DINOv3StateDictConverter(state_dict):
new_state_dict = {}
for key in state_dict:
value = state_dict[key]
if key.startswith("layer"):
new_state_dict["model." + key] = value
else:
new_state_dict[key] = value
return new_state_dict

View File

@@ -0,0 +1,7 @@
def SDTextEncoderStateDictConverter(state_dict):
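# Keep only text_model.* weights (dropping position_ids buffers) and prefix them with "model.".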
new_state_dict = {}
for key in state_dict:
if key.startswith("text_model.") and "position_ids" not in key:
new_key = "model." + key
new_state_dict[new_key] = state_dict[key]
return new_state_dict

View File

@@ -0,0 +1,18 @@
def SDVAEStateDictConverter(state_dict):
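# Rename legacy VAE attention keys (query/key/value/proj_attn) to the to_q/to_k/to_v/to_out.0 layout.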
new_state_dict = {}
for key in state_dict:
if ".query." in key:
new_key = key.replace(".query.", ".to_q.")
new_state_dict[new_key] = state_dict[key]
elif ".key." in key:
new_key = key.replace(".key.", ".to_k.")
new_state_dict[new_key] = state_dict[key]
elif ".value." in key:
new_key = key.replace(".value.", ".to_v.")
new_state_dict[new_key] = state_dict[key]
elif ".proj_attn." in key:
new_key = key.replace(".proj_attn.", ".to_out.0.")
new_state_dict[new_key] = state_dict[key]
else:
new_state_dict[key] = state_dict[key]
return new_state_dict

View File

@@ -0,0 +1,13 @@
import torch
def SDXLTextEncoder2StateDictConverter(state_dict):
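# Prefix text_model.* keys with "model." and upcast fp16 tensors (including text_projection) to fp32.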
new_state_dict = {}
for key in state_dict:
if key == "text_projection.weight":
val = state_dict[key]
new_state_dict["model.text_projection.weight"] = val.float() if val.dtype == torch.float16 else val
elif key.startswith("text_model.") and "position_ids" not in key:
new_key = "model." + key
val = state_dict[key]
new_state_dict[new_key] = val.float() if val.dtype == torch.float16 else val
return new_state_dict

View File

@@ -0,0 +1,141 @@
# Stable Diffusion XL
Stable Diffusion XL (SDXL) is an open-source diffusion-based text-to-image generation model developed by Stability AI. It supports high-quality text-to-image generation at 1024x1024 resolution and uses a dual text encoder (CLIP-L + CLIP-bigG) architecture.
## Installation
Before performing model inference and training, please install DiffSynth-Studio first.
```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
For more information on installation, please refer to [Setup Dependencies](../Pipeline_Usage/Setup.md).
## Quick Start
Running the following code will quickly load the [stabilityai/stable-diffusion-xl-base-1.0](https://www.modelscope.cn/models/stabilityai/stable-diffusion-xl-base-1.0) model for inference. VRAM management is enabled; the framework automatically controls parameter loading based on available VRAM, and a minimum of 6GB VRAM is required.
```python
import torch
from diffsynth.core import ModelConfig
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
vram_config = {
"offload_dtype": torch.float32,
"offload_device": "cpu",
"onload_dtype": torch.float32,
"onload_device": "cpu",
"preparing_dtype": torch.float32,
"preparing_device": "cuda",
"computation_dtype": torch.float32,
"computation_device": "cuda",
}
pipe = StableDiffusionXLPipeline.from_pretrained(
torch_dtype=torch.float32,
model_configs=[
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors", **vram_config),
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
tokenizer_2_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
prompt="a photo of an astronaut riding a horse on mars",
negative_prompt="",
cfg_scale=5.0,
height=1024,
width=1024,
seed=42,
num_inference_steps=50,
)
image.save("image.jpg")
```
## Model Overview
|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|-|-|-|-|-|-|-|
|[stabilityai/stable-diffusion-xl-base-1.0](https://www.modelscope.cn/models/stabilityai/stable-diffusion-xl-base-1.0)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_inference/stable-diffusion-xl-base-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_inference_low_vram/stable-diffusion-xl-base-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/full/stable-diffusion-xl-base-1.0.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/validate_full/stable-diffusion-xl-base-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/lora/stable-diffusion-xl-base-1.0.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/validate_lora/stable-diffusion-xl-base-1.0.py)|
## Model Inference
The model is loaded via `StableDiffusionXLPipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models) for details.
The input parameters for `StableDiffusionXLPipeline` inference include:
* `prompt`: Text prompt.
* `negative_prompt`: Negative prompt, defaults to an empty string.
* `cfg_scale`: Classifier-Free Guidance scale factor, default 5.0.
* `height`: Output image height, default 1024.
* `width`: Output image width, default 1024.
* `seed`: Random seed, defaults to a random value if not set.
* `rand_device`: Noise generation device, defaults to "cpu".
* `num_inference_steps`: Number of inference steps, default 50.
* `guidance_rescale`: Guidance rescale factor, default 0.0.
* `progress_bar_cmd`: Progress bar callback function.
> `StableDiffusionXLPipeline` requires dual tokenizer configurations (`tokenizer_config` and `tokenizer_2_config`), corresponding to the CLIP-L and CLIP-bigG text encoders.
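As a minimal illustration of these parameters (the prompt, seed, and `guidance_rescale` value below are arbitrary choices, and `pipe` is assumed to be the pipeline created in the Quick Start above):
```python
# Illustrative call; `pipe` is the StableDiffusionXLPipeline from the Quick Start.
image = pipe(
    prompt="a cinematic photo of a red fox in a snowy forest",
    negative_prompt="low quality, watermark",
    cfg_scale=5.0,
    height=1024,
    width=1024,
    seed=7,
    rand_device="cuda",    # generate the initial noise on the GPU instead of the CPU
    num_inference_steps=50,
    guidance_rescale=0.7,  # 0.0 disables rescaling of the CFG prediction
)
image.save("sdxl_example.jpg")
```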
## Model Training
Models in the stable_diffusion_xl series are trained via `examples/stable_diffusion_xl/model_training/train.py`. The script parameters include:
* General Training Parameters
* Dataset Configuration
* `--dataset_base_path`: Root directory of the dataset.
* `--dataset_metadata_path`: Path to the dataset metadata file.
* `--dataset_repeat`: Number of dataset repeats per epoch.
* `--dataset_num_workers`: Number of processes per DataLoader.
* `--data_file_keys`: Field names to load from metadata, typically paths to image or video files, separated by `,`.
* Model Loading Configuration
* `--model_paths`: Paths to load models from, in JSON format.
* `--model_id_with_origin_paths`: Model IDs with original paths, separated by commas.
* `--extra_inputs`: Additional input parameters required by the model Pipeline, separated by `,`.
* `--fp8_models`: Models to load in FP8 format, currently only supported for models whose parameters are not updated by gradients.
* Basic Training Configuration
* `--learning_rate`: Learning rate.
* `--num_epochs`: Number of epochs.
* `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.
* `--find_unused_parameters`: Whether unused parameters exist in DDP training.
* `--weight_decay`: Weight decay magnitude.
* `--task`: Training task, defaults to `sft`.
* Output Configuration
* `--output_path`: Path to save the model.
* `--remove_prefix_in_ckpt`: Remove prefix in the model's state dict.
* `--save_steps`: Interval in training steps to save the model.
* LoRA Configuration
* `--lora_base_model`: Which model to add LoRA to.
* `--lora_target_modules`: Which layers to add LoRA to.
* `--lora_rank`: Rank of LoRA.
* `--lora_checkpoint`: Path to LoRA checkpoint.
* `--preset_lora_path`: Path to preset LoRA checkpoint for LoRA differential training.
* `--preset_lora_model`: Which model to integrate preset LoRA into, e.g., `dit`.
* Gradient Configuration
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
* Resolution Configuration
* `--height`: Height of the image/video. Leave empty to enable dynamic resolution.
* `--width`: Width of the image/video. Leave empty to enable dynamic resolution.
* `--max_pixels`: Maximum pixel area, images larger than this will be scaled down during dynamic resolution.
* `--num_frames`: Number of frames for video (video generation models only).
* Stable Diffusion XL Specific Parameters
* `--tokenizer_path`: Path to the first tokenizer.
* `--tokenizer_2_path`: Path to the second tokenizer, defaults to `stabilityai/stable-diffusion-xl-base-1.0:tokenizer_2/`.
Example dataset download:
```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion_xl/*" --local_dir ./data/diffsynth_example_dataset
```
[stable-diffusion-xl-base-1.0 training scripts](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/lora/stable-diffusion-xl-base-1.0.sh)
We provide recommended training scripts for each model, please refer to the table in "Model Overview" above. For guidance on writing model training scripts, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see [Training Framework Overview](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).

View File

@@ -0,0 +1,138 @@
# Stable Diffusion
Stable Diffusion is an open-source diffusion-based text-to-image generation model developed by Stability AI, supporting 512x512 resolution text-to-image generation.
## Installation
Before performing model inference and training, please install DiffSynth-Studio first.
```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
For more information on installation, please refer to [Setup Dependencies](../Pipeline_Usage/Setup.md).
## Quick Start
Running the following code will quickly load the [AI-ModelScope/stable-diffusion-v1-5](https://www.modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5) model for inference. VRAM management is enabled; the framework automatically controls parameter loading based on available VRAM, and a minimum of 2GB VRAM is required.
```python
import torch
from diffsynth.core import ModelConfig
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline
vram_config = {
"offload_dtype": torch.float32,
"offload_device": "cpu",
"onload_dtype": torch.float32,
"onload_device": "cpu",
"preparing_dtype": torch.float32,
"preparing_device": "cuda",
"computation_dtype": torch.float32,
"computation_device": "cuda",
}
pipe = StableDiffusionPipeline.from_pretrained(
torch_dtype=torch.float32,
model_configs=[
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
prompt="a photo of an astronaut riding a horse on mars, high quality, detailed",
negative_prompt="blurry, low quality, deformed",
cfg_scale=7.5,
height=512,
width=512,
seed=42,
rand_device="cuda",
num_inference_steps=50,
)
image.save("image.jpg")
```
## Model Overview
|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|-|-|-|-|-|-|-|
|[AI-ModelScope/stable-diffusion-v1-5](https://www.modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_inference/stable-diffusion-v1-5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_inference_low_vram/stable-diffusion-v1-5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/full/stable-diffusion-v1-5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/validate_full/stable-diffusion-v1-5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/lora/stable-diffusion-v1-5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/validate_lora/stable-diffusion-v1-5.py)|
## Model Inference
The model is loaded via `StableDiffusionPipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models) for details.
The input parameters for `StableDiffusionPipeline` inference include:
* `prompt`: Text prompt.
* `negative_prompt`: Negative prompt, defaults to an empty string.
* `cfg_scale`: Classifier-Free Guidance scale factor, default 7.5.
* `height`: Output image height, default 512.
* `width`: Output image width, default 512.
* `seed`: Random seed, defaults to a random value if not set.
* `rand_device`: Noise generation device, defaults to "cpu".
* `num_inference_steps`: Number of inference steps, default 50.
* `eta`: DDIM scheduler eta parameter, default 0.0.
* `guidance_rescale`: Guidance rescale factor, default 0.0.
* `progress_bar_cmd`: Progress bar callback function.
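As a minimal illustration of the `eta` and `guidance_rescale` parameters (the values below are arbitrary, and `pipe` is assumed to be the pipeline created in the Quick Start above):
```python
# Illustrative call; `pipe` is the StableDiffusionPipeline from the Quick Start.
image = pipe(
    prompt="a watercolor painting of a lighthouse at dusk",
    negative_prompt="blurry, low quality",
    cfg_scale=7.5,
    height=512,
    width=512,
    seed=0,
    num_inference_steps=50,
    eta=0.0,               # DDIM eta; 0.0 keeps sampling deterministic
    guidance_rescale=0.7,  # 0.0 disables rescaling of the CFG prediction
)
image.save("sd_example.jpg")
```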
## Model Training
Models in the stable_diffusion series are trained via `examples/stable_diffusion/model_training/train.py`. The script parameters include:
* General Training Parameters
* Dataset Configuration
* `--dataset_base_path`: Root directory of the dataset.
* `--dataset_metadata_path`: Path to the dataset metadata file.
* `--dataset_repeat`: Number of dataset repeats per epoch.
* `--dataset_num_workers`: Number of processes per DataLoader.
* `--data_file_keys`: Field names to load from metadata, typically paths to image or video files, separated by `,`.
* Model Loading Configuration
* `--model_paths`: Paths to load models from, in JSON format.
* `--model_id_with_origin_paths`: Model IDs with original paths, separated by commas.
* `--extra_inputs`: Additional input parameters required by the model Pipeline, separated by `,`.
* `--fp8_models`: Models to load in FP8 format, currently only supported for models whose parameters are not updated by gradients.
* Basic Training Configuration
* `--learning_rate`: Learning rate.
* `--num_epochs`: Number of epochs.
* `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.
* `--find_unused_parameters`: Whether unused parameters exist in DDP training.
* `--weight_decay`: Weight decay magnitude.
* `--task`: Training task, defaults to `sft`.
* Output Configuration
* `--output_path`: Path to save the model.
* `--remove_prefix_in_ckpt`: Remove prefix in the model's state dict.
* `--save_steps`: Interval in training steps to save the model.
* LoRA Configuration
* `--lora_base_model`: Which model to add LoRA to.
* `--lora_target_modules`: Which layers to add LoRA to.
* `--lora_rank`: Rank of LoRA.
* `--lora_checkpoint`: Path to LoRA checkpoint.
* `--preset_lora_path`: Path to preset LoRA checkpoint for LoRA differential training.
* `--preset_lora_model`: Which model to integrate preset LoRA into, e.g., `dit`.
* Gradient Configuration
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
* Resolution Configuration
* `--height`: Height of the image/video. Leave empty to enable dynamic resolution.
* `--width`: Width of the image/video. Leave empty to enable dynamic resolution.
* `--max_pixels`: Maximum pixel area, images larger than this will be scaled down during dynamic resolution.
* `--num_frames`: Number of frames for video (video generation models only).
* Stable Diffusion Specific Parameters
* `--tokenizer_path`: Tokenizer path, defaults to `AI-ModelScope/stable-diffusion-v1-5:tokenizer/`.
Example dataset download:
```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion/*" --local_dir ./data/diffsynth_example_dataset
```
[stable-diffusion-v1-5 training scripts](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/lora/stable-diffusion-v1-5.sh)
We provide recommended training scripts for each model, please refer to the table in "Model Overview" above. For guidance on writing model training scripts, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see [Training Framework Overview](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).

View File

@@ -32,6 +32,8 @@ Welcome to DiffSynth-Studio's Documentation
Model_Details/LTX-2
Model_Details/ERNIE-Image
Model_Details/JoyAI-Image
Model_Details/Stable-Diffusion
Model_Details/Stable-Diffusion-XL
.. toctree::
:maxdepth: 2

View File

@@ -0,0 +1,141 @@
# Stable Diffusion XL
Stable Diffusion XL (SDXL) is an open-source diffusion-based text-to-image generation model developed by Stability AI. It supports high-quality text-to-image generation at 1024x1024 resolution and uses a dual text encoder (CLIP-L + CLIP-bigG) architecture.
## Installation
Before using this project for model inference and training, please install DiffSynth-Studio first.
```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
For more information on installation, please refer to [Setup Dependencies](../Pipeline_Usage/Setup.md).
## Quick Start
Running the following code will quickly load the [stabilityai/stable-diffusion-xl-base-1.0](https://www.modelscope.cn/models/stabilityai/stable-diffusion-xl-base-1.0) model for inference. VRAM management is enabled; the framework automatically controls parameter loading based on the remaining VRAM, and a minimum of 6GB VRAM is required.
```python
import torch
from diffsynth.core import ModelConfig
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
vram_config = {
"offload_dtype": torch.float32,
"offload_device": "cpu",
"onload_dtype": torch.float32,
"onload_device": "cpu",
"preparing_dtype": torch.float32,
"preparing_device": "cuda",
"computation_dtype": torch.float32,
"computation_device": "cuda",
}
pipe = StableDiffusionXLPipeline.from_pretrained(
torch_dtype=torch.float32,
model_configs=[
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors", **vram_config),
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
tokenizer_2_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
prompt="a photo of an astronaut riding a horse on mars",
negative_prompt="",
cfg_scale=5.0,
height=1024,
width=1024,
seed=42,
num_inference_steps=50,
)
image.save("image.jpg")
```
## Model Overview
|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|-|-|-|-|-|-|-|
|[stabilityai/stable-diffusion-xl-base-1.0](https://www.modelscope.cn/models/stabilityai/stable-diffusion-xl-base-1.0)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_inference/stable-diffusion-xl-base-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_inference_low_vram/stable-diffusion-xl-base-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/full/stable-diffusion-xl-base-1.0.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/validate_full/stable-diffusion-xl-base-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/lora/stable-diffusion-xl-base-1.0.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/validate_lora/stable-diffusion-xl-base-1.0.py)|
## Model Inference
The model is loaded via `StableDiffusionXLPipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#加载模型) for details.
The input parameters for `StableDiffusionXLPipeline` inference include:
* `prompt`: Text prompt.
* `negative_prompt`: Negative prompt, defaults to an empty string.
* `cfg_scale`: Classifier-Free Guidance scale factor, default 5.0.
* `height`: Output image height, default 1024.
* `width`: Output image width, default 1024.
* `seed`: Random seed; a random value is used when not set.
* `rand_device`: Noise generation device, defaults to "cpu".
* `num_inference_steps`: Number of inference steps, default 50.
* `guidance_rescale`: Guidance rescale factor, default 0.0.
* `progress_bar_cmd`: Progress bar callback function.
> `StableDiffusionXLPipeline` requires dual tokenizer configurations (`tokenizer_config` and `tokenizer_2_config`), corresponding to the CLIP-L and CLIP-bigG text encoders respectively.
## Model Training
Models in the stable_diffusion_xl series are trained via `examples/stable_diffusion_xl/model_training/train.py`. The script parameters include:
* General Training Parameters
* Dataset Configuration
* `--dataset_base_path`: Root directory of the dataset.
* `--dataset_metadata_path`: Path to the dataset metadata file.
* `--dataset_repeat`: Number of times the dataset is repeated per epoch.
* `--dataset_num_workers`: Number of worker processes per DataLoader.
* `--data_file_keys`: Field names to load from the metadata, typically paths to image or video files, separated by `,`.
* Model Loading Configuration
* `--model_paths`: Paths of the models to load, in JSON format.
* `--model_id_with_origin_paths`: Model IDs with original paths, separated by commas.
* `--extra_inputs`: Additional input parameters required by the model Pipeline, separated by `,`.
* `--fp8_models`: Models loaded in FP8 format; currently only supported for models whose parameters are not updated by gradients.
* Basic Training Configuration
* `--learning_rate`: Learning rate.
* `--num_epochs`: Number of epochs.
* `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.
* `--find_unused_parameters`: Whether unused parameters exist in DDP training.
* `--weight_decay`: Weight decay magnitude.
* `--task`: Training task, defaults to `sft`.
* Output Configuration
* `--output_path`: Model save path.
* `--remove_prefix_in_ckpt`: Remove a prefix from the model file's state dict.
* `--save_steps`: Interval in training steps between model saves.
* LoRA Configuration
* `--lora_base_model`: Which model to add LoRA to.
* `--lora_target_modules`: Which layers to add LoRA to.
* `--lora_rank`: Rank of the LoRA.
* `--lora_checkpoint`: Path to the LoRA checkpoint.
* `--preset_lora_path`: Path to a preset LoRA checkpoint, used for LoRA differential training.
* `--preset_lora_model`: Which model the preset LoRA is merged into, e.g., `dit`.
* Gradient Configuration
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
* Resolution Configuration
* `--height`: Height of the image/video. Leave empty to enable dynamic resolution.
* `--width`: Width of the image/video. Leave empty to enable dynamic resolution.
* `--max_pixels`: Maximum pixel area; under dynamic resolution, images larger than this will be scaled down.
* `--num_frames`: Number of video frames (video generation models only).
* Stable Diffusion XL Specific Parameters
* `--tokenizer_path`: Path to the first tokenizer.
* `--tokenizer_2_path`: Path to the second tokenizer, defaults to `stabilityai/stable-diffusion-xl-base-1.0:tokenizer_2/`.
Example dataset download:
```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion_xl/*" --local_dir ./data/diffsynth_example_dataset
```
[stable-diffusion-xl-base-1.0 training scripts](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/lora/stable-diffusion-xl-base-1.0.sh)
We provide recommended training scripts for each model; please refer to the table in "Model Overview" above. For guidance on writing model training scripts, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see [Training Framework Overview](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/).

View File

@@ -0,0 +1,138 @@
# Stable Diffusion
Stable Diffusion is an open-source diffusion-based text-to-image generation model developed by Stability AI, supporting text-to-image generation at 512x512 resolution.
## Installation
Before using this project for model inference and training, please install DiffSynth-Studio first.
```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
For more information on installation, please refer to [Setup Dependencies](../Pipeline_Usage/Setup.md).
## Quick Start
Running the following code will quickly load the [AI-ModelScope/stable-diffusion-v1-5](https://www.modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5) model for inference. VRAM management is enabled; the framework automatically controls parameter loading based on the remaining VRAM, and a minimum of 2GB VRAM is required.
```python
import torch
from diffsynth.core import ModelConfig
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline
vram_config = {
"offload_dtype": torch.float32,
"offload_device": "cpu",
"onload_dtype": torch.float32,
"onload_device": "cpu",
"preparing_dtype": torch.float32,
"preparing_device": "cuda",
"computation_dtype": torch.float32,
"computation_device": "cuda",
}
pipe = StableDiffusionPipeline.from_pretrained(
torch_dtype=torch.float32,
model_configs=[
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
prompt="a photo of an astronaut riding a horse on mars, high quality, detailed",
negative_prompt="blurry, low quality, deformed",
cfg_scale=7.5,
height=512,
width=512,
seed=42,
rand_device="cuda",
num_inference_steps=50,
)
image.save("image.jpg")
```
## Model Overview
|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|-|-|-|-|-|-|-|
|[AI-ModelScope/stable-diffusion-v1-5](https://www.modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_inference/stable-diffusion-v1-5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_inference_low_vram/stable-diffusion-v1-5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/full/stable-diffusion-v1-5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/validate_full/stable-diffusion-v1-5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/lora/stable-diffusion-v1-5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/validate_lora/stable-diffusion-v1-5.py)|
## Model Inference
The model is loaded via `StableDiffusionPipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#加载模型) for details.
The input parameters for `StableDiffusionPipeline` inference include:
* `prompt`: Text prompt.
* `negative_prompt`: Negative prompt, defaults to an empty string.
* `cfg_scale`: Classifier-Free Guidance scale factor, default 7.5.
* `height`: Output image height, default 512.
* `width`: Output image width, default 512.
* `seed`: Random seed; a random value is used when not set.
* `rand_device`: Noise generation device, defaults to "cpu".
* `num_inference_steps`: Number of inference steps, default 50.
* `eta`: Eta parameter of the DDIM scheduler, default 0.0.
* `guidance_rescale`: Guidance rescale factor, default 0.0.
* `progress_bar_cmd`: Progress bar callback function.
## Model Training
Models in the stable_diffusion series are trained via `examples/stable_diffusion/model_training/train.py`. The script parameters include:
* General Training Parameters
* Dataset Configuration
* `--dataset_base_path`: Root directory of the dataset.
* `--dataset_metadata_path`: Path to the dataset metadata file.
* `--dataset_repeat`: Number of times the dataset is repeated per epoch.
* `--dataset_num_workers`: Number of worker processes per DataLoader.
* `--data_file_keys`: Field names to load from the metadata, typically paths to image or video files, separated by `,`.
* Model Loading Configuration
* `--model_paths`: Paths of the models to load, in JSON format.
* `--model_id_with_origin_paths`: Model IDs with original paths, separated by commas.
* `--extra_inputs`: Additional input parameters required by the model Pipeline, separated by `,`.
* `--fp8_models`: Models loaded in FP8 format; currently only supported for models whose parameters are not updated by gradients.
* Basic Training Configuration
* `--learning_rate`: Learning rate.
* `--num_epochs`: Number of epochs.
* `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.
* `--find_unused_parameters`: Whether unused parameters exist in DDP training.
* `--weight_decay`: Weight decay magnitude.
* `--task`: Training task, defaults to `sft`.
* Output Configuration
* `--output_path`: Model save path.
* `--remove_prefix_in_ckpt`: Remove a prefix from the model file's state dict.
* `--save_steps`: Interval in training steps between model saves.
* LoRA Configuration
* `--lora_base_model`: Which model to add LoRA to.
* `--lora_target_modules`: Which layers to add LoRA to.
* `--lora_rank`: Rank of the LoRA.
* `--lora_checkpoint`: Path to the LoRA checkpoint.
* `--preset_lora_path`: Path to a preset LoRA checkpoint, used for LoRA differential training.
* `--preset_lora_model`: Which model the preset LoRA is merged into, e.g., `dit`.
* Gradient Configuration
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
* Resolution Configuration
* `--height`: Height of the image/video. Leave empty to enable dynamic resolution.
* `--width`: Width of the image/video. Leave empty to enable dynamic resolution.
* `--max_pixels`: Maximum pixel area; under dynamic resolution, images larger than this will be scaled down.
* `--num_frames`: Number of video frames (video generation models only).
* Stable Diffusion Specific Parameters
* `--tokenizer_path`: Tokenizer path, defaults to `AI-ModelScope/stable-diffusion-v1-5:tokenizer/`.
Example dataset download:
```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion/*" --local_dir ./data/diffsynth_example_dataset
```
[stable-diffusion-v1-5 training scripts](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/lora/stable-diffusion-v1-5.sh)
We provide recommended training scripts for each model; please refer to the table in "Model Overview" above. For guidance on writing model training scripts, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see [Training Framework Overview](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/).

View File

@@ -32,6 +32,8 @@
Model_Details/LTX-2
Model_Details/ERNIE-Image
Model_Details/JoyAI-Image
Model_Details/Stable-Diffusion
Model_Details/Stable-Diffusion-XL
.. toctree::
:maxdepth: 2

View File

@@ -1,331 +0,0 @@
import importlib, inspect, pkgutil, traceback, torch, os, re, typing
from typing import Union, List, Optional, Tuple, Iterable, Dict, Literal
from contextlib import contextmanager
from diffsynth.utils.data import VideoData
import streamlit as st
from diffsynth import ModelConfig
from diffsynth.diffusion.base_pipeline import ControlNetInput
from PIL import Image
from tqdm import tqdm
st.set_page_config(layout="wide")
class StreamlitTqdmWrapper:
"""Wrapper class that combines tqdm and streamlit progress bar"""
def __init__(self, iterable, st_progress_bar=None):
self.iterable = iterable
self.st_progress_bar = st_progress_bar
self.tqdm_bar = tqdm(iterable)
self.total = len(iterable) if hasattr(iterable, '__len__') else None
self.current = 0
def __iter__(self):
for item in self.tqdm_bar:
if self.st_progress_bar is not None and self.total is not None:
self.current += 1
self.st_progress_bar.progress(self.current / self.total)
yield item
def __enter__(self):
return self
def __exit__(self, *args):
if hasattr(self.tqdm_bar, '__exit__'):
self.tqdm_bar.__exit__(*args)
@contextmanager
def catch_error(error_value):
try:
yield
except Exception as e:
error_message = traceback.format_exc()
print(f"Error {error_value}:\n{error_message}")
def parse_model_configs_from_an_example(path):
model_configs = []
with open(path, "r") as f:
for code in f.readlines():
code = code.strip()
if not code.startswith("ModelConfig"):
continue
pairs = re.findall(r'(\w+)\s*=\s*["\']([^"\']+)["\']', code)
config_dict = {k: v for k, v in pairs}
model_configs.append(ModelConfig(model_id=config_dict["model_id"], origin_file_pattern=config_dict["origin_file_pattern"]))
return model_configs
def list_examples(path, keyword=None):
examples = []
if os.path.isdir(path):
for file_name in os.listdir(path):
examples.extend(list_examples(os.path.join(path, file_name), keyword=keyword))
elif path.endswith(".py"):
with open(path, "r") as f:
code = f.read()
if keyword is None or keyword in code:
examples.extend([path])
return examples
def parse_available_pipelines():
from diffsynth.diffusion.base_pipeline import BasePipeline
import diffsynth.pipelines as _pipelines_pkg
available_pipelines = {}
for _, name, _ in pkgutil.iter_modules(_pipelines_pkg.__path__):
with catch_error(f"Failed: import diffsynth.pipelines.{name}"):
mod = importlib.import_module(f"diffsynth.pipelines.{name}")
classes = {
cls_name: cls for cls_name, cls in inspect.getmembers(mod, inspect.isclass)
if issubclass(cls, BasePipeline) and cls is not BasePipeline and cls.__module__ == mod.__name__
}
available_pipelines.update(classes)
return available_pipelines
def parse_available_examples(path, available_pipelines):
available_examples = {}
for pipeline_name in available_pipelines:
examples = ["None"] + list_examples(path, keyword=f"{pipeline_name}.from_pretrained")
available_examples[pipeline_name] = examples
return available_examples
def draw_selectbox(label, options, option_map, value=None, disabled=False):
default_index = 0 if value is None else tuple(options).index([option for option in option_map if option_map[option]==value][0])
option = st.selectbox(label=label, options=tuple(options), index=default_index, disabled=disabled)
return option_map.get(option)
def parse_params(fn):
params = []
for name, param in inspect.signature(fn).parameters.items():
annotation = param.annotation if param.annotation is not inspect.Parameter.empty else None
default = param.default if param.default is not inspect.Parameter.empty else None
params.append({"name": name, "dtype": annotation, "value": default})
return params
def draw_model_config(model_config=None, key_suffix="", disabled=False):
with st.container(border=True):
if model_config is None:
model_config = ModelConfig()
path = st.text_input(label="path", key="path" + key_suffix, value=model_config.path, disabled=disabled)
col1, col2 = st.columns(2)
with col1:
model_id = st.text_input(label="model_id", key="model_id" + key_suffix, value=model_config.model_id, disabled=disabled)
with col2:
origin_file_pattern = st.text_input(label="origin_file_pattern", key="origin_file_pattern" + key_suffix, value=model_config.origin_file_pattern, disabled=disabled)
model_config = ModelConfig(
path=None if path == "" else path,
model_id=model_id,
origin_file_pattern=origin_file_pattern,
)
return model_config
def draw_multi_model_config(name="", value=None, disabled=False):
model_configs = []
with st.container(border=True):
st.markdown(name)
num = st.number_input(f"num_{name}", min_value=0, max_value=20, value=0 if value is None else len(value), disabled=disabled)
for i in range(num):
model_config = draw_model_config(key_suffix=f"_{name}_{i}", model_config=None if value is None else value[i], disabled=disabled)
model_configs.append(model_config)
return model_configs
def draw_single_model_config(name="", value=None, disabled=False):
with st.container(border=True):
st.markdown(name)
model_config = draw_model_config(value, key_suffix=f"_{name}", disabled=disabled)
return model_config
def draw_multi_images(name="", value=None, disabled=False):
images = []
with st.container(border=True):
st.markdown(name)
num = st.number_input(f"num_{name}", min_value=0, max_value=20, value=0 if value is None else len(value), disabled=disabled)
for i in range(num):
image = st.file_uploader(name, type=["png", "jpg", "jpeg", "webp"], key=f"{name}_{i}", disabled=disabled)
if image is not None: images.append(Image.open(image))
return images
def draw_multi_elements(st_element, name="", value=None, disabled=False, kwargs=None):
if kwargs is None:
kwargs = {}
elements = []
with st.container(border=True):
st.markdown(name)
num = st.number_input(f"num_{name}", min_value=0, max_value=20, value=0 if value is None else len(value), disabled=disabled)
for i in range(num):
element = st_element(name, key=f"{name}_{i}", disabled=disabled, value=None if value is None else value[i], **kwargs)
elements.append(element)
return elements
def draw_controlnet_input(name="", value=None, disabled=False):
with st.container(border=True):
st.markdown(name)
controlnet_id = st.number_input("controlnet_id", value=0, min_value=0, max_value=20, step=1, key=f"{name}_controlnet_id")
scale = st.number_input("scale", value=1.0, min_value=0.0, max_value=10.0, key=f"{name}_scale")
image = st.file_uploader("image", type=["png", "jpg", "jpeg", "webp"], disabled=disabled, key=f"{name}_image")
if image is not None: image = Image.open(image)
inpaint_image = st.file_uploader("inpaint_image", type=["png", "jpg", "jpeg", "webp"], disabled=disabled, key=f"{name}_inpaint_image")
if inpaint_image is not None: inpaint_image = Image.open(inpaint_image)
inpaint_mask = st.file_uploader("inpaint_mask", type=["png", "jpg", "jpeg", "webp"], disabled=disabled, key=f"{name}_inpaint_mask")
if inpaint_mask is not None: inpaint_mask = Image.open(inpaint_mask)
return ControlNetInput(controlnet_id=controlnet_id, scale=scale, image=image, inpaint_image=inpaint_image, inpaint_mask=inpaint_mask)
def draw_controlnet_inputs(name, value=None, disabled=False):
controlnet_inputs = []
with st.container(border=True):
st.markdown(name)
num = st.number_input(f"num_{name}", min_value=0, max_value=20, value=0 if value is None else len(value), disabled=disabled)
for i in range(num):
controlnet_input = draw_controlnet_input(name=f"{name}_{i}", value=None, disabled=disabled)
controlnet_inputs.append(controlnet_input)
return controlnet_inputs
def draw_ui_element(name, dtype, value):
unsupported_dtype = [
Dict[str, torch.Tensor],
torch.Tensor,
]
if dtype in unsupported_dtype:
return
if value is None:
with st.container(border=True):
enable = st.checkbox(f"Enable {name}", value=False)
ui = draw_ui_element_safely(name, dtype, value=value, disabled=not enable)
if enable:
return ui
else:
return None
else:
return draw_ui_element_safely(name, dtype, value)
def draw_video(name, value=None, disabled=False):
ui = st.file_uploader(name, type=["mp4"], disabled=disabled)
if ui is not None:
ui = VideoData(ui)
ui = [ui[i] for i in range(len(ui))]
return ui
def draw_ui_element_safely(name, dtype, value, disabled=False):
if dtype == torch.dtype:
option_map = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
ui = draw_selectbox(name, option_map.keys(), option_map, value=value, disabled=disabled)
elif dtype == Union[str, torch.device]:
option_map = {"cuda": "cuda", "cpu": "cpu"}
ui = draw_selectbox(name, option_map.keys(), option_map, value=value, disabled=disabled)
elif dtype == bool:
ui = st.checkbox(name, value=value, disabled=disabled)
elif dtype == ModelConfig:
ui = draw_single_model_config(name, value=value, disabled=disabled)
elif dtype in [list[ModelConfig], List[ModelConfig], Union[list[ModelConfig], ModelConfig, str]]:
if name == "model_configs" and "model_configs_from_example" in st.session_state:
model_configs = st.session_state["model_configs_from_example"]
del st.session_state["model_configs_from_example"]
ui = draw_multi_model_config(name, model_configs, disabled=disabled)
else:
ui = draw_multi_model_config(name, disabled=disabled)
elif dtype == str:
if "prompt" in name:
ui = st.text_area(name, value=value, height=3, disabled=disabled)
else:
ui = st.text_input(name, value=value, disabled=disabled)
elif dtype == float:
ui = st.number_input(name, value=value, disabled=disabled)
elif dtype == int:
ui = st.number_input(name, value=value, step=1, disabled=disabled)
elif dtype == Image.Image:
ui = st.file_uploader(name, type=["png", "jpg", "jpeg", "webp"], disabled=disabled)
if ui is not None: ui = Image.open(ui)
elif dtype in [List[Image.Image], list[Image.Image], Union[list[Image.Image], Image.Image], Union[List[Image.Image], Image.Image]]:
if "video" in name:
ui = draw_video(name, value=value, disabled=disabled)
else:
ui = draw_multi_images(name, value=value, disabled=disabled)
elif dtype in [List[ControlNetInput], list[ControlNetInput]]:
ui = draw_controlnet_inputs(name, value=value, disabled=disabled)
elif dtype in [List[str], list[str]]:
ui = draw_multi_elements(st.text_input, name, value=value, disabled=disabled)
elif dtype in [List[float], list[float], Union[list[float], float], Union[List[float], float]]:
ui = draw_multi_elements(st.number_input, name, value=value, disabled=disabled)
elif dtype in [List[int], list[int]]:
ui = draw_multi_elements(st.number_input, name, value=value, disabled=disabled, kwargs={"step": 1})
elif dtype in [List[List[Image.Image]], list[list[Image.Image]]]:
ui = draw_multi_elements(draw_video, name, value=value, disabled=disabled)
elif dtype in [tuple[int, int], Tuple[int, int]]:
with st.container(border=True):
st.markdown(name)
ui = (st.text_input(f"{name}_0", value=value[0], disabled=disabled), st.text_input(f"{name}_1", value=value[1], disabled=disabled))
elif isinstance(dtype, typing._LiteralGenericAlias):
with st.container(border=True):
st.markdown(f"{name} ({dtype})")
ui = st.text_input(name, value=value, disabled=disabled, label_visibility="hidden")
elif dtype is None:
if name == "progress_bar_cmd":
ui = value
else:
st.markdown(f"(`{name}` is not not configurable in WebUI). dtype: `{dtype}`.")
ui = value
return ui
def launch_webui():
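# Three-step WebUI flow: parse the selected pipeline, load its models, then fill in the call parameters and generate.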
input_col, output_col = st.columns(2)
with input_col:
if "available_pipelines" not in st.session_state:
st.session_state["available_pipelines"] = parse_available_pipelines()
if "available_examples" not in st.session_state:
st.session_state["available_examples"] = parse_available_examples("./examples", st.session_state["available_pipelines"])
with st.expander("Pipeline", expanded=True):
pipeline_class = draw_selectbox("Pipeline Class", st.session_state["available_pipelines"].keys(), st.session_state["available_pipelines"], value=st.session_state["available_pipelines"]["ZImagePipeline"])
example = st.selectbox("Parse model configs from an example (optional)", st.session_state["available_examples"][pipeline_class.__name__])
# Clear if pipeline is changed
if "prev_pipeline_class" in st.session_state and st.session_state["prev_pipeline_class"] != pipeline_class:
if "pipeline_class" in st.session_state: del st.session_state["pipeline_class"]
if "model_configs_from_example" in st.session_state: del st.session_state["model_configs_from_example"]
if "prev_example" in st.session_state and st.session_state["prev_example"] != example:
if "model_configs_from_example" in st.session_state: del st.session_state["model_configs_from_example"]
st.session_state["prev_pipeline_class"] = pipeline_class
st.session_state["prev_example"] = example
if st.button("Step 1: Parse Pipeline", type="primary"):
st.session_state["pipeline_class"] = pipeline_class
if example != "None":
st.session_state["model_configs_from_example"] = parse_model_configs_from_an_example(example)
if "pipeline_class" not in st.session_state:
return
with st.expander("Model", expanded=True):
input_params = {}
params = parse_params(pipeline_class.from_pretrained)
for param in params:
input_params[param["name"]] = draw_ui_element(**param)
if st.button("Step 2: Load Models", type="primary"):
with st.spinner("Loading models", show_time=True):
if "pipe" in st.session_state:
del st.session_state["pipe"]
torch.cuda.empty_cache()
st.session_state["pipe"] = pipeline_class.from_pretrained(**input_params)
if "pipe" not in st.session_state:
return
with st.expander("Input", expanded=True):
pipe = st.session_state["pipe"]
input_params = {}
params = parse_params(pipeline_class.__call__)
for param in params:
if param["name"] in ["self"]:
continue
input_params[param["name"]] = draw_ui_element(**param)
with output_col:
if st.button("Step 3: Generate", type="primary"):
if "progress_bar_cmd" in input_params:
input_params["progress_bar_cmd"] = lambda iterable: StreamlitTqdmWrapper(iterable, st.progress(0))
result = pipe(**input_params)
st.session_state["result"] = result
if "result" in st.session_state:
result = st.session_state["result"]
if isinstance(result, Image.Image):
st.image(result)
else:
print(f"unsupported result format: {result}")
launch_webui()

View File

@@ -0,0 +1,25 @@
import torch
from diffsynth.core import ModelConfig
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained(
torch_dtype=torch.float32,
model_configs=[
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors"),
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
)
image = pipe(
prompt="a photo of an astronaut riding a horse on mars, high quality, detailed",
negative_prompt="blurry, low quality, deformed",
cfg_scale=7.5,
height=512,
width=512,
seed=42,
rand_device="cuda",
num_inference_steps=50,
)
image.save("image.jpg")

View File

@@ -0,0 +1,36 @@
import torch
from diffsynth.core import ModelConfig
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline
vram_config = {
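# Keep parameters in float32 on the CPU while idle and move them to the GPU only for preparation and computation.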
"offload_dtype": torch.float32,
"offload_device": "cpu",
"onload_dtype": torch.float32,
"onload_device": "cpu",
"preparing_dtype": torch.float32,
"preparing_device": "cuda",
"computation_dtype": torch.float32,
"computation_device": "cuda",
}
pipe = StableDiffusionPipeline.from_pretrained(
torch_dtype=torch.float32,
model_configs=[
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
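# torch.cuda.mem_get_info returns (free, total) in bytes; budget the card's total VRAM minus ~0.5 GB.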
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
prompt="a photo of an astronaut riding a horse on mars, high quality, detailed",
negative_prompt="blurry, low quality, deformed",
cfg_scale=7.5,
height=512,
width=512,
seed=42,
rand_device="cuda",
num_inference_steps=50,
)
image.save("image.jpg")

View File

@@ -0,0 +1,15 @@
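# Downloads the example dataset, then fully fine-tunes the Stable Diffusion v1.5 UNet (512x512, 2 epochs) with gradient checkpointing.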
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion/stable-diffusion-v1-5/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/stable_diffusion/model_training/train.py \
--dataset_base_path data/diffsynth_example_dataset/stable_diffusion/stable-diffusion-v1-5 \
--dataset_metadata_path data/diffsynth_example_dataset/stable_diffusion/stable-diffusion-v1-5/metadata.csv \
--height 512 \
--width 512 \
--dataset_repeat 50 \
--model_id_with_origin_paths "AI-ModelScope/stable-diffusion-v1-5:text_encoder/model.safetensors,AI-ModelScope/stable-diffusion-v1-5:unet/diffusion_pytorch_model.safetensors,AI-ModelScope/stable-diffusion-v1-5:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-5 \
--num_epochs 2 \
--trainable_models "unet" \
--remove_prefix_in_ckpt "pipe.unet." \
--output_path "./models/train/stable-diffusion-v1-5_full" \
--use_gradient_checkpointing

View File

@@ -0,0 +1,17 @@
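# Downloads the example dataset, then trains a rank-32 LoRA on the Stable Diffusion v1.5 UNet (512x512, 5 epochs).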
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion/stable-diffusion-v1-5/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/stable_diffusion/model_training/train.py \
--dataset_base_path data/diffsynth_example_dataset/stable_diffusion/stable-diffusion-v1-5 \
--dataset_metadata_path data/diffsynth_example_dataset/stable_diffusion/stable-diffusion-v1-5/metadata.csv \
--height 512 \
--width 512 \
--dataset_repeat 50 \
--model_id_with_origin_paths "AI-ModelScope/stable-diffusion-v1-5:text_encoder/model.safetensors,AI-ModelScope/stable-diffusion-v1-5:unet/diffusion_pytorch_model.safetensors,AI-ModelScope/stable-diffusion-v1-5:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.unet." \
--output_path "./models/train/stable-diffusion-v1-5_lora" \
--lora_base_model "unet" \
--lora_target_modules "" \
--lora_rank 32 \
--use_gradient_checkpointing

View File

@@ -0,0 +1,142 @@
import torch, os, argparse, accelerate
from diffsynth.core import UnifiedDataset
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline, ModelConfig
from diffsynth.diffusion import *
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class StableDiffusionTrainingModule(DiffusionTrainingModule):
def __init__(
self,
model_paths=None, model_id_with_origin_paths=None,
tokenizer_path=None,
trainable_models=None,
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
preset_lora_path=None, preset_lora_model=None,
use_gradient_checkpointing=True,
use_gradient_checkpointing_offload=False,
extra_inputs=None,
fp8_models=None,
offload_models=None,
device="cpu",
task="sft",
):
super().__init__()
# Load models
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
tokenizer_config = self.parse_path_or_model_id(tokenizer_path, ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"))
self.pipe = StableDiffusionPipeline.from_pretrained(torch_dtype=torch.float32, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
# Training mode
self.switch_pipe_to_training_mode(
self.pipe, trainable_models,
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
preset_lora_path, preset_lora_model,
task=task,
)
# Other configs
self.use_gradient_checkpointing = use_gradient_checkpointing
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
self.fp8_models = fp8_models
self.task = task
self.task_to_loss = {
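# Map each training task to its loss; the ":data_process" variants simply pass the cached pipeline inputs through.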
"sft:data_process": lambda pipe, *args: args,
"direct_distill:data_process": lambda pipe, *args: args,
"sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
"sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
"direct_distill": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
"direct_distill:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
}
def get_pipeline_inputs(self, data):
inputs_posi = {"prompt": data["prompt"]}
inputs_nega = {"negative_prompt": ""}
inputs_shared = {
# Assuming you are using this pipeline for inference,
# fill in the input parameters below.
"input_image": data["image"],
"height": data["image"].size[1],
"width": data["image"].size[0],
# Please do not modify the following parameters
# unless you clearly understand their effect.
"cfg_scale": 1,
"rand_device": self.pipe.device,
"use_gradient_checkpointing": self.use_gradient_checkpointing,
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
}
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
return inputs_shared, inputs_posi, inputs_nega
def forward(self, data, inputs=None):
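# Inputs may be pre-computed by a data-process task; otherwise build them from the raw sample, run all pipeline units, then compute the task-specific loss.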
if inputs is None: inputs = self.get_pipeline_inputs(data)
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
for unit in self.pipe.units:
inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)
loss = self.task_to_loss[self.task](self.pipe, *inputs)
return loss
def parser():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser = add_general_config(parser)
parser = add_image_size_config(parser)
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
return parser
if __name__ == "__main__":
parser = parser()
args = parser.parse_args()
accelerator = accelerate.Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
)
dataset = UnifiedDataset(
base_path=args.dataset_base_path,
metadata_path=args.dataset_metadata_path,
repeat=args.dataset_repeat,
data_file_keys=args.data_file_keys.split(","),
main_data_operator=UnifiedDataset.default_image_operator(
base_path=args.dataset_base_path,
max_pixels=args.max_pixels,
height=args.height,
width=args.width,
height_division_factor=32,
width_division_factor=32,
)
)
model = StableDiffusionTrainingModule(
model_paths=args.model_paths,
model_id_with_origin_paths=args.model_id_with_origin_paths,
tokenizer_path=args.tokenizer_path,
trainable_models=args.trainable_models,
lora_base_model=args.lora_base_model,
lora_target_modules=args.lora_target_modules,
lora_rank=args.lora_rank,
lora_checkpoint=args.lora_checkpoint,
preset_lora_path=args.preset_lora_path,
preset_lora_model=args.preset_lora_model,
use_gradient_checkpointing=args.use_gradient_checkpointing,
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
extra_inputs=args.extra_inputs,
fp8_models=args.fp8_models,
offload_models=args.offload_models,
task=args.task,
device=accelerator.device,
)
model_logger = ModelLogger(
args.output_path,
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
)
launcher_map = {
"sft:data_process": launch_data_process_task,
"direct_distill:data_process": launch_data_process_task,
"sft": launch_training_task,
"sft:train": launch_training_task,
"direct_distill": launch_training_task,
"direct_distill:train": launch_training_task,
}
launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)

View File

@@ -0,0 +1,27 @@
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline, ModelConfig
from diffsynth.core import load_state_dict
import torch
pipe = StableDiffusionPipeline.from_pretrained(
torch_dtype=torch.float32,
model_configs=[
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors"),
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
)
state_dict = load_state_dict("./models/train/stable-diffusion-v1-5_full/epoch-1.safetensors", torch_dtype=torch.float32)
pipe.unet.load_state_dict(state_dict)
image = pipe(
prompt="a dog",
negative_prompt="blurry, low quality, deformed",
cfg_scale=7.5,
height=512,
width=512,
seed=42,
rand_device="cuda",
num_inference_steps=50,
)
image.save("image_stable-diffusion-v1-5_full.jpg")

View File

@@ -0,0 +1,26 @@
import torch
from diffsynth.core import ModelConfig
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained(
torch_dtype=torch.float32,
model_configs=[
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors"),
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.unet, "models/train/stable-diffusion-v1-5_lora/epoch-4.safetensors")
image = pipe(
prompt="a dog",
negative_prompt="blurry, low quality, deformed",
cfg_scale=7.5,
height=512,
width=512,
seed=42,
rand_device="cuda",
num_inference_steps=50,
)
image.save("image_stable-diffusion-v1-5.jpg")

View File

@@ -0,0 +1,26 @@
import torch
from diffsynth.core import ModelConfig
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
pipe = StableDiffusionXLPipeline.from_pretrained(
torch_dtype=torch.float32,
model_configs=[
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors"),
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors"),
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
tokenizer_2_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
)
image = pipe(
prompt="a photo of an astronaut riding a horse on mars",
negative_prompt="",
cfg_scale=5.0,
height=1024,
width=1024,
seed=42,
num_inference_steps=50,
)
image.save("image.jpg")

View File

@@ -0,0 +1,37 @@
import torch
from diffsynth.core import ModelConfig
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
vram_config = {
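# Keep every model in float32 on the CPU while idle; weights are moved to the GPU only for preparation and computation.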
"offload_dtype": torch.float32,
"offload_device": "cpu",
"onload_dtype": torch.float32,
"onload_device": "cpu",
"preparing_dtype": torch.float32,
"preparing_device": "cuda",
"computation_dtype": torch.float32,
"computation_device": "cuda",
}
pipe = StableDiffusionXLPipeline.from_pretrained(
torch_dtype=torch.float32,
model_configs=[
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors", **vram_config),
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
tokenizer_2_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
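# torch.cuda.mem_get_info returns (free, total) in bytes, so this budgets total VRAM minus ~0.5 GB.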
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
prompt="a photo of an astronaut riding a horse on mars",
negative_prompt="",
cfg_scale=5.0,
height=1024,
width=1024,
seed=42,
num_inference_steps=50,
)
image.save("image.jpg")

View File

@@ -0,0 +1,15 @@
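# Downloads the SDXL example dataset, then fully fine-tunes the SDXL base 1.0 UNet (1024x1024, 2 epochs) with gradient checkpointing.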
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion_xl/stable-diffusion-xl-base-1.0/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/stable_diffusion_xl/model_training/train.py \
--dataset_base_path data/diffsynth_example_dataset/stable_diffusion_xl/stable-diffusion-xl-base-1.0 \
--dataset_metadata_path data/diffsynth_example_dataset/stable_diffusion_xl/stable-diffusion-xl-base-1.0/metadata.csv \
--height 1024 \
--width 1024 \
--dataset_repeat 10 \
--model_id_with_origin_paths "stabilityai/stable-diffusion-xl-base-1.0:text_encoder/model.safetensors,stabilityai/stable-diffusion-xl-base-1.0:text_encoder_2/model.safetensors,stabilityai/stable-diffusion-xl-base-1.0:unet/diffusion_pytorch_model.safetensors,stabilityai/stable-diffusion-xl-base-1.0:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-5 \
--num_epochs 2 \
--trainable_models "unet" \
--remove_prefix_in_ckpt "pipe.unet." \
--output_path "./models/train/stable-diffusion-xl-base-1.0_full" \
--use_gradient_checkpointing

View File

@@ -0,0 +1,17 @@
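# Downloads the SDXL example dataset, then trains a rank-32 LoRA on the SDXL base 1.0 UNet (1024x1024, 5 epochs).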
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion_xl/stable-diffusion-xl-base-1.0/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/stable_diffusion_xl/model_training/train.py \
--dataset_base_path data/diffsynth_example_dataset/stable_diffusion_xl/stable-diffusion-xl-base-1.0 \
--dataset_metadata_path data/diffsynth_example_dataset/stable_diffusion_xl/stable-diffusion-xl-base-1.0/metadata.csv \
--height 1024 \
--width 1024 \
--dataset_repeat 10 \
--model_id_with_origin_paths "stabilityai/stable-diffusion-xl-base-1.0:text_encoder/model.safetensors,stabilityai/stable-diffusion-xl-base-1.0:text_encoder_2/model.safetensors,stabilityai/stable-diffusion-xl-base-1.0:unet/diffusion_pytorch_model.safetensors,stabilityai/stable-diffusion-xl-base-1.0:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.unet." \
--output_path "./models/train/stable-diffusion-xl-base-1.0_lora" \
--lora_base_model "unet" \
--lora_target_modules "" \
--lora_rank 32 \
--use_gradient_checkpointing

View File

@@ -0,0 +1,144 @@
import torch, os, argparse, accelerate
from diffsynth.core import UnifiedDataset
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline, ModelConfig
from diffsynth.diffusion import *
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class StableDiffusionXLTrainingModule(DiffusionTrainingModule):
def __init__(
self,
model_paths=None, model_id_with_origin_paths=None,
tokenizer_path=None, tokenizer_2_path=None,
trainable_models=None,
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
preset_lora_path=None, preset_lora_model=None,
use_gradient_checkpointing=True,
use_gradient_checkpointing_offload=False,
extra_inputs=None,
fp8_models=None,
offload_models=None,
device="cpu",
task="sft",
):
super().__init__()
# Load models
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
tokenizer_config = self.parse_path_or_model_id(tokenizer_path, ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"))
tokenizer_2_config = self.parse_path_or_model_id(tokenizer_2_path, ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"))
self.pipe = StableDiffusionXLPipeline.from_pretrained(torch_dtype=torch.float32, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config, tokenizer_2_config=tokenizer_2_config)
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
# Training mode
self.switch_pipe_to_training_mode(
self.pipe, trainable_models,
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
preset_lora_path, preset_lora_model,
task=task,
)
# Other configs
self.use_gradient_checkpointing = use_gradient_checkpointing
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
self.fp8_models = fp8_models
self.task = task
self.task_to_loss = {
"sft:data_process": lambda pipe, *args: args,
"direct_distill:data_process": lambda pipe, *args: args,
"sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
"sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
"direct_distill": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
"direct_distill:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
}
def get_pipeline_inputs(self, data):
inputs_posi = {"prompt": data["prompt"]}
inputs_nega = {"negative_prompt": ""}
inputs_shared = {
# Assuming you are using this pipeline for inference,
# fill in the input parameters below.
"input_image": data["image"],
"height": data["image"].size[1],
"width": data["image"].size[0],
# Please do not modify the following parameters
# unless you clearly understand their effect.
"cfg_scale": 1,
"rand_device": self.pipe.device,
"use_gradient_checkpointing": self.use_gradient_checkpointing,
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
}
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
return inputs_shared, inputs_posi, inputs_nega
def forward(self, data, inputs=None):
if inputs is None: inputs = self.get_pipeline_inputs(data)
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
for unit in self.pipe.units:
inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)
loss = self.task_to_loss[self.task](self.pipe, *inputs)
return loss
def parser():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser = add_general_config(parser)
parser = add_image_size_config(parser)
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
parser.add_argument("--tokenizer_2_path", type=str, default=None, help="Path to tokenizer 2.")
return parser
if __name__ == "__main__":
parser = parser()
args = parser.parse_args()
accelerator = accelerate.Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
)
dataset = UnifiedDataset(
base_path=args.dataset_base_path,
metadata_path=args.dataset_metadata_path,
repeat=args.dataset_repeat,
data_file_keys=args.data_file_keys.split(","),
main_data_operator=UnifiedDataset.default_image_operator(
base_path=args.dataset_base_path,
max_pixels=args.max_pixels,
height=args.height,
width=args.width,
height_division_factor=32,
width_division_factor=32,
)
)
model = StableDiffusionXLTrainingModule(
model_paths=args.model_paths,
model_id_with_origin_paths=args.model_id_with_origin_paths,
tokenizer_path=args.tokenizer_path,
tokenizer_2_path=args.tokenizer_2_path,
trainable_models=args.trainable_models,
lora_base_model=args.lora_base_model,
lora_target_modules=args.lora_target_modules,
lora_rank=args.lora_rank,
lora_checkpoint=args.lora_checkpoint,
preset_lora_path=args.preset_lora_path,
preset_lora_model=args.preset_lora_model,
use_gradient_checkpointing=args.use_gradient_checkpointing,
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
extra_inputs=args.extra_inputs,
fp8_models=args.fp8_models,
offload_models=args.offload_models,
task=args.task,
device=accelerator.device,
)
model_logger = ModelLogger(
args.output_path,
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
)
launcher_map = {
"sft:data_process": launch_data_process_task,
"direct_distill:data_process": launch_data_process_task,
"sft": launch_training_task,
"sft:train": launch_training_task,
"direct_distill": launch_training_task,
"direct_distill:train": launch_training_task,
}
launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)

View File

@@ -0,0 +1,28 @@
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline, ModelConfig
from diffsynth.core import load_state_dict
import torch
pipe = StableDiffusionXLPipeline.from_pretrained(
torch_dtype=torch.float32,
model_configs=[
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors"),
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors"),
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
tokenizer_2_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
)
state_dict = load_state_dict("./models/train/stable-diffusion-xl-base-1.0_full/epoch-1.safetensors", torch_dtype=torch.float32)
pipe.unet.load_state_dict(state_dict)
image = pipe(
prompt="a dog",
negative_prompt="",
cfg_scale=7.0,
height=1024,
width=1024,
seed=42,
num_inference_steps=50,
)
image.save("image_stable-diffusion-xl-base-1.0_full.jpg")

View File

@@ -0,0 +1,27 @@
import torch
from diffsynth.core import ModelConfig
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
pipe = StableDiffusionXLPipeline.from_pretrained(
torch_dtype=torch.float32,
model_configs=[
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors"),
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors"),
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
tokenizer_2_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
)
pipe.load_lora(pipe.unet, "models/train/stable-diffusion-xl-base-1.0_lora/epoch-4.safetensors")
image = pipe(
prompt="a dog",
negative_prompt="",
cfg_scale=7.0,
height=1024,
width=1024,
seed=42,
num_inference_steps=50,
)
image.save("image_stable-diffusion-xl-base-1.0.jpg")