mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-24 15:06:17 +00:00
Compare commits
12 Commits
version2.0
...
sd
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d5934719f8 | ||
|
|
54345f8678 | ||
|
|
2d7d5137ea | ||
|
|
3799bdc23a | ||
|
|
5cdab9ed01 | ||
|
|
a8a0f082bb | ||
|
|
9453700a30 | ||
|
|
82e482286c | ||
|
|
5c89a15b9a | ||
|
|
079e51c9f3 | ||
|
|
8f18e24597 | ||
|
|
45d973e87d |
325
README.md
325
README.md
@@ -34,6 +34,10 @@ We believe that a well-developed open-source code framework can lower the thresh
|
|||||||
|
|
||||||
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher) and [mi804](https://github.com/mi804). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
|
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher) and [mi804](https://github.com/mi804). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
|
||||||
|
|
||||||
|
- **April 24, 2026** We added support for Stable Diffusion v1.5 and SDXL, including inference, low VRAM inference, and training capabilities. For details, please refer to the [Stable Diffusion documentation](/docs/en/Model_Details/Stable-Diffusion.md), the [Stable Diffusion XL documentation](/docs/en/Model_Details/Stable-Diffusion-XL.md), and the [example code](/examples/stable_diffusion/).
|
||||||
|
|
||||||
|
- **April 14, 2026** JoyAI-Image open-sourced, welcome a new member to the image editing model family! Support includes instruction-guided image editing, low VRAM inference, and training capabilities. For details, please refer to the [documentation](/docs/en/Model_Details/JoyAI-Image.md) and [example code](/examples/joyai_image/).
|
||||||
|
|
||||||
- **March 19, 2026**: Added support for [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) and [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/Wan.md) and [example code](/examples/mova/) are now available.
|
- **March 19, 2026**: Added support for [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) and [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/Wan.md) and [example code](/examples/mova/) are now available.
|
||||||
|
|
||||||
- **March 12, 2026**: We have added support for the [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) audio-video generation model. The features include text-to-audio/video, image-to-audio/video, IC-LoRA control, audio-to-video, and audio-video inpainting. We have supported the complete inference and training functionalities. For details, please refer to the [documentation](/docs/en/Model_Details/LTX-2.md) and [code](/examples/ltx2/).
|
- **March 12, 2026**: We have added support for the [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) audio-video generation model. The features include text-to-audio/video, image-to-audio/video, IC-LoRA control, audio-to-video, and audio-video inpainting. We have supported the complete inference and training functionalities. For details, please refer to the [documentation](/docs/en/Model_Details/LTX-2.md) and [code](/examples/ltx2/).
|
||||||
@@ -297,6 +301,129 @@ Example code for Z-Image is available at: [/examples/z_image/](/examples/z_image
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
#### Stable Diffusion: [/docs/en/Model_Details/Stable-Diffusion.md](/docs/en/Model_Details/Stable-Diffusion.md)
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>Quick Start</summary>
|
||||||
|
|
||||||
|
Running the following code will quickly load the [AI-ModelScope/stable-diffusion-v1-5](https://www.modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5) model for inference. VRAM management is enabled, and the framework automatically controls parameter loading based on available VRAM, requiring a minimum of 2GB VRAM.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from diffsynth.core import ModelConfig
|
||||||
|
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.float32,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.float32,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float32,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.float32,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt="a photo of an astronaut riding a horse on mars, high quality, detailed",
|
||||||
|
negative_prompt="blurry, low quality, deformed",
|
||||||
|
cfg_scale=7.5,
|
||||||
|
height=512,
|
||||||
|
width=512,
|
||||||
|
seed=42,
|
||||||
|
rand_device="cuda",
|
||||||
|
num_inference_steps=50,
|
||||||
|
)
|
||||||
|
image.save("image.jpg")
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>Examples</summary>
|
||||||
|
|
||||||
|
Example code for Stable Diffusion is available at: [/examples/stable_diffusion/](/examples/stable_diffusion/)
|
||||||
|
|
||||||
|
|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|
||||||
|
|-|-|-|-|-|-|-|
|
||||||
|
|[AI-ModelScope/stable-diffusion-v1-5](https://www.modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5)|[code](/examples/stable_diffusion/model_inference/stable-diffusion-v1-5.py)|[code](/examples/stable_diffusion/model_inference_low_vram/stable-diffusion-v1-5.py)|[code](/examples/stable_diffusion/model_training/full/stable-diffusion-v1-5.sh)|[code](/examples/stable_diffusion/model_training/validate_full/stable-diffusion-v1-5.py)|[code](/examples/stable_diffusion/model_training/lora/stable-diffusion-v1-5.sh)|[code](/examples/stable_diffusion/model_training/validate_lora/stable-diffusion-v1-5.py)|
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
#### Stable Diffusion XL: [/docs/en/Model_Details/Stable-Diffusion-XL.md](/docs/en/Model_Details/Stable-Diffusion-XL.md)
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>Quick Start</summary>
|
||||||
|
|
||||||
|
Running the following code will quickly load the [stabilityai/stable-diffusion-xl-base-1.0](https://www.modelscope.cn/models/stabilityai/stable-diffusion-xl-base-1.0) model for inference. VRAM management is enabled, and the framework automatically controls parameter loading based on available VRAM, requiring a minimum of 6GB VRAM.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from diffsynth.core import ModelConfig
|
||||||
|
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.float32,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.float32,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float32,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.float32,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
|
||||||
|
tokenizer_2_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt="a photo of an astronaut riding a horse on mars",
|
||||||
|
negative_prompt="",
|
||||||
|
cfg_scale=5.0,
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
seed=42,
|
||||||
|
num_inference_steps=50,
|
||||||
|
)
|
||||||
|
image.save("image.jpg")
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>Examples</summary>
|
||||||
|
|
||||||
|
Example code for Stable Diffusion XL is available at: [/examples/stable_diffusion_xl/](/examples/stable_diffusion_xl/)
|
||||||
|
|
||||||
|
|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|
||||||
|
|-|-|-|-|-|-|-|
|
||||||
|
|[stabilityai/stable-diffusion-xl-base-1.0](https://www.modelscope.cn/models/stabilityai/stable-diffusion-xl-base-1.0)|[code](/examples/stable_diffusion_xl/model_inference/stable-diffusion-xl-base-1.0.py)|[code](/examples/stable_diffusion_xl/model_inference_low_vram/stable-diffusion-xl-base-1.0.py)|[code](/examples/stable_diffusion_xl/model_training/full/stable-diffusion-xl-base-1.0.sh)|[code](/examples/stable_diffusion_xl/model_training/validate_full/stable-diffusion-xl-base-1.0.py)|[code](/examples/stable_diffusion_xl/model_training/lora/stable-diffusion-xl-base-1.0.sh)|[code](/examples/stable_diffusion_xl/model_training/validate_lora/stable-diffusion-xl-base-1.0.py)|
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
#### FLUX.2: [/docs/en/Model_Details/FLUX2.md](/docs/en/Model_Details/FLUX2.md)
|
#### FLUX.2: [/docs/en/Model_Details/FLUX2.md](/docs/en/Model_Details/FLUX2.md)
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
@@ -598,6 +725,143 @@ Example code for FLUX.1 is available at: [/examples/flux/](/examples/flux/)
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
#### ERNIE-Image: [/docs/en/Model_Details/ERNIE-Image.md](/docs/en/Model_Details/ERNIE-Image.md)
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>Quick Start</summary>
|
||||||
|
|
||||||
|
Running the following code will quickly load the [PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 3GB VRAM.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
|
||||||
|
import torch
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.bfloat16,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.bfloat16,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.bfloat16,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = ErnieImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device='cuda',
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt="一只黑白相间的中华田园犬",
|
||||||
|
negative_prompt="",
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
seed=42,
|
||||||
|
num_inference_steps=50,
|
||||||
|
cfg_scale=4.0,
|
||||||
|
)
|
||||||
|
image.save("output.jpg")
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>Examples</summary>
|
||||||
|
|
||||||
|
Example code for ERNIE-Image is available at: [/examples/ernie_image/](/examples/ernie_image/)
|
||||||
|
|
||||||
|
| Model ID | Inference | Low VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|
||||||
|
|-|-|-|-|-|-|-|
|
||||||
|
|[PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image)|[code](/examples/ernie_image/model_inference/ERNIE-Image.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/full/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_full/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/lora/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_lora/ERNIE-Image.py)|
|
||||||
|
|[PaddlePaddle/ERNIE-Image-Turbo](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image-Turbo)|[code](/examples/ernie_image/model_inference/ERNIE-Image-Turbo.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image-Turbo.py)|—|—|—|—|
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
#### JoyAI-Image: [/docs/en/Model_Details/JoyAI-Image.md](/docs/en/Model_Details/JoyAI-Image.md)
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>Quick Start</summary>
|
||||||
|
|
||||||
|
Running the following code will quickly load the [jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 4GB VRAM.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
|
|
||||||
|
# Download dataset
|
||||||
|
dataset_snapshot_download(
|
||||||
|
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
|
||||||
|
local_dir="data/diffsynth_example_dataset",
|
||||||
|
allow_file_pattern="joyai_image/JoyAI-Image-Edit/*"
|
||||||
|
)
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.bfloat16,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.bfloat16,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.bfloat16,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
|
||||||
|
pipe = JoyAIImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config),
|
||||||
|
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config),
|
||||||
|
],
|
||||||
|
processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use first sample from dataset
|
||||||
|
dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit"
|
||||||
|
prompt = "将裙子改为粉色"
|
||||||
|
edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB")
|
||||||
|
|
||||||
|
output = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
edit_image=edit_image,
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
seed=0,
|
||||||
|
num_inference_steps=30,
|
||||||
|
cfg_scale=5.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
output.save("output_joyai_edit_low_vram.png")
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>Examples</summary>
|
||||||
|
|
||||||
|
Example code for JoyAI-Image is available at: [/examples/joyai_image/](/examples/joyai_image/)
|
||||||
|
|
||||||
|
| Model ID | Inference | Low VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|
||||||
|
|-|-|-|-|-|-|-|
|
||||||
|
|[jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit)|[code](/examples/joyai_image/model_inference/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py)|
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
### Video Synthesis
|
### Video Synthesis
|
||||||
|
|
||||||
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
|
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
|
||||||
@@ -877,67 +1141,6 @@ Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/)
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
#### ERNIE-Image: [/docs/en/Model_Details/ERNIE-Image.md](/docs/en/Model_Details/ERNIE-Image.md)
|
|
||||||
|
|
||||||
<details>
|
|
||||||
|
|
||||||
<summary>Quick Start</summary>
|
|
||||||
|
|
||||||
Running the following code will quickly load the [PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 3GB VRAM.
|
|
||||||
|
|
||||||
```python
|
|
||||||
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
|
|
||||||
import torch
|
|
||||||
|
|
||||||
vram_config = {
|
|
||||||
"offload_dtype": torch.bfloat16,
|
|
||||||
"offload_device": "cpu",
|
|
||||||
"onload_dtype": torch.bfloat16,
|
|
||||||
"onload_device": "cpu",
|
|
||||||
"preparing_dtype": torch.bfloat16,
|
|
||||||
"preparing_device": "cuda",
|
|
||||||
"computation_dtype": torch.bfloat16,
|
|
||||||
"computation_device": "cuda",
|
|
||||||
}
|
|
||||||
pipe = ErnieImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device='cuda',
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
|
|
||||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
|
||||||
)
|
|
||||||
|
|
||||||
image = pipe(
|
|
||||||
prompt="一只黑白相间的中华田园犬",
|
|
||||||
negative_prompt="",
|
|
||||||
height=1024,
|
|
||||||
width=1024,
|
|
||||||
seed=42,
|
|
||||||
num_inference_steps=50,
|
|
||||||
cfg_scale=4.0,
|
|
||||||
)
|
|
||||||
image.save("output.jpg")
|
|
||||||
```
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
<details>
|
|
||||||
|
|
||||||
<summary>Examples</summary>
|
|
||||||
|
|
||||||
Example code for ERNIE-Image is available at: [/examples/ernie_image/](/examples/ernie_image/)
|
|
||||||
|
|
||||||
| Model ID | Inference | Low VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|
|
||||||
|-|-|-|-|-|-|-|
|
|
||||||
|[PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image)|[code](/examples/ernie_image/model_inference/ERNIE-Image.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/full/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_full/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/lora/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_lora/ERNIE-Image.py)|
|
|
||||||
|[PaddlePaddle/ERNIE-Image-Turbo](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image-Turbo)|[code](/examples/ernie_image/model_inference/ERNIE-Image-Turbo.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image-Turbo.py)|—|—|—|—|
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
## Innovative Achievements
|
## Innovative Achievements
|
||||||
|
|
||||||
DiffSynth-Studio is not just an engineered model framework, but also an incubator for innovative achievements.
|
DiffSynth-Studio is not just an engineered model framework, but also an incubator for innovative achievements.
|
||||||
|
|||||||
325
README_zh.md
325
README_zh.md
@@ -34,6 +34,10 @@ DiffSynth 目前包括两个开源项目:
|
|||||||
|
|
||||||
> 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 和 [mi804](https://github.com/mi804) 负责,因此新功能的开发进展会比较缓慢,issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。
|
> 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 和 [mi804](https://github.com/mi804) 负责,因此新功能的开发进展会比较缓慢,issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。
|
||||||
|
|
||||||
|
- **2026年4月24日** 我们新增对 Stable Diffusion v1.5 和 SDXL 的支持,包括推理、低显存推理和训练能力。详情请参考[文档](/docs/zh/Model_Details/Stable-Diffusion.md)和[示例代码](/examples/stable_diffusion/)。
|
||||||
|
|
||||||
|
- **2026年4月14日** JoyAI-Image 开源,欢迎加入图像编辑模型家族!支持指令引导的图像编辑推理、低显存推理和训练能力。详情请参考[文档](/docs/zh/Model_Details/JoyAI-Image.md)和[示例代码](/examples/joyai_image/)。
|
||||||
|
|
||||||
- **2026年3月19日** 新增对 [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) 和 [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) 模型的支持,包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/Wan.md)和[示例代码](/examples/mova/)现已可用。
|
- **2026年3月19日** 新增对 [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) 和 [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) 模型的支持,包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/Wan.md)和[示例代码](/examples/mova/)现已可用。
|
||||||
|
|
||||||
- **2026年3月12日** 我们新增了 [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) 音视频生成模型的支持,模型支持的功能包括文生音视频、图生音视频、IC-LoRA控制、音频生视频、音视频局部Inpainting,框架支持完整的推理和训练功能。详细信息请参考 [文档](/docs/zh/Model_Details/LTX-2.md) 和 [示例代码](/examples/ltx2/)。
|
- **2026年3月12日** 我们新增了 [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) 音视频生成模型的支持,模型支持的功能包括文生音视频、图生音视频、IC-LoRA控制、音频生视频、音视频局部Inpainting,框架支持完整的推理和训练功能。详细信息请参考 [文档](/docs/zh/Model_Details/LTX-2.md) 和 [示例代码](/examples/ltx2/)。
|
||||||
@@ -297,6 +301,129 @@ Z-Image 的示例代码位于:[/examples/z_image/](/examples/z_image/)
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
#### Stable Diffusion:[/docs/zh/Model_Details/Stable-Diffusion.md](/docs/zh/Model_Details/Stable-Diffusion.md)
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>快速开始</summary>
|
||||||
|
|
||||||
|
运行以下代码可以快速加载 [AI-ModelScope/stable-diffusion-v1-5](https://www.modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5) 模型并进行推理。显存管理已启用,框架会自动根据剩余显存控制模型参数的加载,最低 2GB 显存即可运行。
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from diffsynth.core import ModelConfig
|
||||||
|
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.float32,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.float32,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float32,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.float32,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt="a photo of an astronaut riding a horse on mars, high quality, detailed",
|
||||||
|
negative_prompt="blurry, low quality, deformed",
|
||||||
|
cfg_scale=7.5,
|
||||||
|
height=512,
|
||||||
|
width=512,
|
||||||
|
seed=42,
|
||||||
|
rand_device="cuda",
|
||||||
|
num_inference_steps=50,
|
||||||
|
)
|
||||||
|
image.save("image.jpg")
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>示例代码</summary>
|
||||||
|
|
||||||
|
Stable Diffusion 的示例代码位于:[/examples/stable_diffusion/](/examples/stable_diffusion/)
|
||||||
|
|
||||||
|
| 模型 ID | 推理 | 低显存推理 | 全量训练 | 全量训练后验证 | LoRA 训练 | LoRA 训练后验证 |
|
||||||
|
|-|-|-|-|-|-|-|
|
||||||
|
|[AI-ModelScope/stable-diffusion-v1-5](https://www.modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5)|[code](/examples/stable_diffusion/model_inference/stable-diffusion-v1-5.py)|[code](/examples/stable_diffusion/model_inference_low_vram/stable-diffusion-v1-5.py)|[code](/examples/stable_diffusion/model_training/full/stable-diffusion-v1-5.sh)|[code](/examples/stable_diffusion/model_training/validate_full/stable-diffusion-v1-5.py)|[code](/examples/stable_diffusion/model_training/lora/stable-diffusion-v1-5.sh)|[code](/examples/stable_diffusion/model_training/validate_lora/stable-diffusion-v1-5.py)|
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
#### Stable Diffusion XL:[/docs/zh/Model_Details/Stable-Diffusion-XL.md](/docs/zh/Model_Details/Stable-Diffusion-XL.md)
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>快速开始</summary>
|
||||||
|
|
||||||
|
运行以下代码可以快速加载 [stabilityai/stable-diffusion-xl-base-1.0](https://www.modelscope.cn/models/stabilityai/stable-diffusion-xl-base-1.0) 模型并进行推理。显存管理已启用,框架会自动根据剩余显存控制模型参数的加载,最低 6GB 显存即可运行。
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from diffsynth.core import ModelConfig
|
||||||
|
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.float32,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.float32,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float32,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.float32,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
|
||||||
|
tokenizer_2_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt="a photo of an astronaut riding a horse on mars",
|
||||||
|
negative_prompt="",
|
||||||
|
cfg_scale=5.0,
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
seed=42,
|
||||||
|
num_inference_steps=50,
|
||||||
|
)
|
||||||
|
image.save("image.jpg")
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>示例代码</summary>
|
||||||
|
|
||||||
|
Stable Diffusion XL 的示例代码位于:[/examples/stable_diffusion_xl/](/examples/stable_diffusion_xl/)
|
||||||
|
|
||||||
|
| 模型 ID | 推理 | 低显存推理 | 全量训练 | 全量训练后验证 | LoRA 训练 | LoRA 训练后验证 |
|
||||||
|
|-|-|-|-|-|-|-|
|
||||||
|
|[stabilityai/stable-diffusion-xl-base-1.0](https://www.modelscope.cn/models/stabilityai/stable-diffusion-xl-base-1.0)|[code](/examples/stable_diffusion_xl/model_inference/stable-diffusion-xl-base-1.0.py)|[code](/examples/stable_diffusion_xl/model_inference_low_vram/stable-diffusion-xl-base-1.0.py)|[code](/examples/stable_diffusion_xl/model_training/full/stable-diffusion-xl-base-1.0.sh)|[code](/examples/stable_diffusion_xl/model_training/validate_full/stable-diffusion-xl-base-1.0.py)|[code](/examples/stable_diffusion_xl/model_training/lora/stable-diffusion-xl-base-1.0.sh)|[code](/examples/stable_diffusion_xl/model_training/validate_lora/stable-diffusion-xl-base-1.0.py)|
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
#### FLUX.2: [/docs/zh/Model_Details/FLUX2.md](/docs/zh/Model_Details/FLUX2.md)
|
#### FLUX.2: [/docs/zh/Model_Details/FLUX2.md](/docs/zh/Model_Details/FLUX2.md)
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
@@ -598,6 +725,143 @@ FLUX.1 的示例代码位于:[/examples/flux/](/examples/flux/)
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
#### ERNIE-Image: [/docs/zh/Model_Details/ERNIE-Image.md](/docs/zh/Model_Details/ERNIE-Image.md)
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>快速开始</summary>
|
||||||
|
|
||||||
|
运行以下代码可以快速加载 [PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image) 模型并进行推理。显存管理已启用,框架会自动根据剩余显存控制模型参数的加载,最低 3GB 显存即可运行。
|
||||||
|
|
||||||
|
```python
|
||||||
|
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
|
||||||
|
import torch
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.bfloat16,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.bfloat16,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.bfloat16,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = ErnieImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device='cuda',
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt="一只黑白相间的中华田园犬",
|
||||||
|
negative_prompt="",
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
seed=42,
|
||||||
|
num_inference_steps=50,
|
||||||
|
cfg_scale=4.0,
|
||||||
|
)
|
||||||
|
image.save("output.jpg")
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>示例代码</summary>
|
||||||
|
|
||||||
|
ERNIE-Image 的示例代码位于:[/examples/ernie_image/](/examples/ernie_image/)
|
||||||
|
|
||||||
|
| 模型 ID | 推理 | 低显存推理 | 全量训练 | 全量训练后验证 | LoRA 训练 | LoRA 训练后验证 |
|
||||||
|
|-|-|-|-|-|-|-|
|
||||||
|
|[PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image)|[code](/examples/ernie_image/model_inference/ERNIE-Image.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/full/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_full/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/lora/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_lora/ERNIE-Image.py)|
|
||||||
|
|[PaddlePaddle/ERNIE-Image-Turbo](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image-Turbo)|[code](/examples/ernie_image/model_inference/ERNIE-Image-Turbo.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image-Turbo.py)|—|—|—|—|
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
#### JoyAI-Image: [/docs/zh/Model_Details/JoyAI-Image.md](/docs/zh/Model_Details/JoyAI-Image.md)
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>快速开始</summary>
|
||||||
|
|
||||||
|
运行以下代码可以快速加载 [jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 4G 显存即可运行。
|
||||||
|
|
||||||
|
```python
|
||||||
|
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
|
|
||||||
|
# Download dataset
|
||||||
|
dataset_snapshot_download(
|
||||||
|
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
|
||||||
|
local_dir="data/diffsynth_example_dataset",
|
||||||
|
allow_file_pattern="joyai_image/JoyAI-Image-Edit/*"
|
||||||
|
)
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.bfloat16,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.bfloat16,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.bfloat16,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
|
||||||
|
pipe = JoyAIImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config),
|
||||||
|
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config),
|
||||||
|
],
|
||||||
|
processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use first sample from dataset
|
||||||
|
dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit"
|
||||||
|
prompt = "将裙子改为粉色"
|
||||||
|
edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB")
|
||||||
|
|
||||||
|
output = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
edit_image=edit_image,
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
seed=0,
|
||||||
|
num_inference_steps=30,
|
||||||
|
cfg_scale=5.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
output.save("output_joyai_edit_low_vram.png")
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>示例代码</summary>
|
||||||
|
|
||||||
|
JoyAI-Image 的示例代码位于:[/examples/joyai_image/](/examples/joyai_image/)
|
||||||
|
|
||||||
|
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||||
|
|-|-|-|-|-|-|-|
|
||||||
|
|[jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit)|[code](/examples/joyai_image/model_inference/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py)|
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
### 视频生成模型
|
### 视频生成模型
|
||||||
|
|
||||||
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
|
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
|
||||||
@@ -877,67 +1141,6 @@ Wan 的示例代码位于:[/examples/wanvideo/](/examples/wanvideo/)
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
#### ERNIE-Image: [/docs/zh/Model_Details/ERNIE-Image.md](/docs/zh/Model_Details/ERNIE-Image.md)
|
|
||||||
|
|
||||||
<details>
|
|
||||||
|
|
||||||
<summary>快速开始</summary>
|
|
||||||
|
|
||||||
运行以下代码可以快速加载 [PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 3G 显存即可运行。
|
|
||||||
|
|
||||||
```python
|
|
||||||
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
|
|
||||||
import torch
|
|
||||||
|
|
||||||
vram_config = {
|
|
||||||
"offload_dtype": torch.bfloat16,
|
|
||||||
"offload_device": "cpu",
|
|
||||||
"onload_dtype": torch.bfloat16,
|
|
||||||
"onload_device": "cpu",
|
|
||||||
"preparing_dtype": torch.bfloat16,
|
|
||||||
"preparing_device": "cuda",
|
|
||||||
"computation_dtype": torch.bfloat16,
|
|
||||||
"computation_device": "cuda",
|
|
||||||
}
|
|
||||||
pipe = ErnieImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device='cuda',
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
|
|
||||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
|
||||||
)
|
|
||||||
|
|
||||||
image = pipe(
|
|
||||||
prompt="一只黑白相间的中华田园犬",
|
|
||||||
negative_prompt="",
|
|
||||||
height=1024,
|
|
||||||
width=1024,
|
|
||||||
seed=42,
|
|
||||||
num_inference_steps=50,
|
|
||||||
cfg_scale=4.0,
|
|
||||||
)
|
|
||||||
image.save("output.jpg")
|
|
||||||
```
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
<details>
|
|
||||||
|
|
||||||
<summary>示例代码</summary>
|
|
||||||
|
|
||||||
ERNIE-Image 的示例代码位于:[/examples/ernie_image/](/examples/ernie_image/)
|
|
||||||
|
|
||||||
| 模型 ID | 推理 | 低显存推理 | 全量训练 | 全量训练后验证 | LoRA 训练 | LoRA 训练后验证 |
|
|
||||||
|-|-|-|-|-|-|-|
|
|
||||||
|[PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image)|[code](/examples/ernie_image/model_inference/ERNIE-Image.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/full/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_full/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/lora/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_lora/ERNIE-Image.py)|
|
|
||||||
|[PaddlePaddle/ERNIE-Image-Turbo](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image-Turbo)|[code](/examples/ernie_image/model_inference/ERNIE-Image-Turbo.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image-Turbo.py)|—|—|—|—|
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
## 创新成果
|
## 创新成果
|
||||||
|
|
||||||
DiffSynth-Studio 不仅仅是一个工程化的模型框架,更是创新成果的孵化器。
|
DiffSynth-Studio 不仅仅是一个工程化的模型框架,更是创新成果的孵化器。
|
||||||
|
|||||||
@@ -900,4 +900,75 @@ mova_series = [
|
|||||||
"model_class": "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge",
|
"model_class": "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge",
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series + z_image_series + ltx2_series + anima_series + mova_series
|
stable_diffusion_xl_series = [
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors")
|
||||||
|
"model_hash": "142b114f67f5ab3a6d83fb5788f12ded",
|
||||||
|
"model_name": "stable_diffusion_xl_unet",
|
||||||
|
"model_class": "diffsynth.models.stable_diffusion_xl_unet.SDXLUNet2DConditionModel",
|
||||||
|
"extra_kwargs": {"attention_head_dim": [5, 10, 20], "transformer_layers_per_block": [1, 2, 10], "use_linear_projection": True, "addition_embed_type": "text_time", "addition_time_embed_dim": 256, "projection_class_embeddings_input_dim": 2816},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors")
|
||||||
|
"model_hash": "98cc34ccc5b54ae0e56bdea8688dcd5a",
|
||||||
|
"model_name": "stable_diffusion_xl_text_encoder",
|
||||||
|
"model_class": "diffsynth.models.stable_diffusion_xl_text_encoder.SDXLTextEncoder2",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_xl_text_encoder.SDXLTextEncoder2StateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors")
|
||||||
|
"model_hash": "94eefa3dac9cec93cb1ebaf1747d7b78",
|
||||||
|
"model_name": "stable_diffusion_text_encoder",
|
||||||
|
"model_class": "diffsynth.models.stable_diffusion_text_encoder.SDTextEncoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_text_encoder.SDTextEncoderStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
|
||||||
|
"model_hash": "13115dd45a6e1c39860f91ab073b8a78",
|
||||||
|
"model_name": "stable_diffusion_xl_vae",
|
||||||
|
"model_class": "diffsynth.models.stable_diffusion_vae.StableDiffusionVAE",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_vae.SDVAEStateDictConverter",
|
||||||
|
"extra_kwargs": {"scaling_factor": 0.13025, "sample_size": 1024, "force_upcast": True},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
stable_diffusion_series = [
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors")
|
||||||
|
"model_hash": "ffd1737ae9df7fd43f5fbed653bdad67",
|
||||||
|
"model_name": "stable_diffusion_text_encoder",
|
||||||
|
"model_class": "diffsynth.models.stable_diffusion_text_encoder.SDTextEncoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_text_encoder.SDTextEncoderStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
|
||||||
|
"model_hash": "f86d5683ed32433be8ca69969c67ba69",
|
||||||
|
"model_name": "stable_diffusion_vae",
|
||||||
|
"model_class": "diffsynth.models.stable_diffusion_vae.StableDiffusionVAE",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_vae.SDVAEStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors")
|
||||||
|
"model_hash": "025a4b86a84829399d89f613e580757b",
|
||||||
|
"model_name": "stable_diffusion_unet",
|
||||||
|
"model_class": "diffsynth.models.stable_diffusion_unet.UNet2DConditionModel",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
joyai_image_series = [
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth")
|
||||||
|
"model_hash": "56592ddfd7d0249d3aa527d24161a863",
|
||||||
|
"model_name": "joyai_image_dit",
|
||||||
|
"model_class": "diffsynth.models.joyai_image_dit.JoyAIImageDiT",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model-*.safetensors")
|
||||||
|
"model_hash": "2d11bf14bba8b4e87477c8199a895403",
|
||||||
|
"model_name": "joyai_image_text_encoder",
|
||||||
|
"model_class": "diffsynth.models.joyai_image_text_encoder.JoyAIImageTextEncoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.joyai_image_text_encoder.JoyAIImageTextEncoderStateDictConverter",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
MODEL_CONFIGS = stable_diffusion_xl_series + stable_diffusion_series + qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series + z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series
|
||||||
|
|||||||
@@ -279,6 +279,61 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
|
|||||||
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
"transformers.models.ministral3.modeling_ministral3.Ministral3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
"transformers.models.ministral3.modeling_ministral3.Ministral3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
},
|
},
|
||||||
|
"diffsynth.models.joyai_image_dit.Transformer3DModel": {
|
||||||
|
"diffsynth.models.joyai_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"diffsynth.models.joyai_image_dit.ModulateWan": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.joyai_image_text_encoder.JoyAIImageTextEncoder": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLVisionModel": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.stable_diffusion_unet.UNet2DConditionModel": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.stable_diffusion_vae.StableDiffusionVAE": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"diffsynth.models.stable_diffusion_vae.Upsample2D": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"diffsynth.models.stable_diffusion_vae.Downsample2D": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.stable_diffusion_text_encoder.SDTextEncoder": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"transformers.models.clip.modeling_clip.CLIPTextTransformer": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"transformers.models.clip.modeling_clip.CLIPEncoderLayer": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"transformers.models.clip.modeling_clip.CLIPAttention": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.stable_diffusion_xl_unet.SDXLUNet2DConditionModel": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.stable_diffusion_xl_text_encoder.SDXLTextEncoder2": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"transformers.models.clip.modeling_clip.CLIPTextTransformer": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"transformers.models.clip.modeling_clip.CLIPEncoderLayer": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"transformers.models.clip.modeling_clip.CLIPAttention": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def QwenImageTextEncoder_Module_Map_Updater():
|
def QwenImageTextEncoder_Module_Map_Updater():
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import math
|
import math, warnings
|
||||||
import torch, torchvision, imageio, os
|
import torch, torchvision, imageio, os
|
||||||
import imageio.v3 as iio
|
import imageio.v3 as iio
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@@ -260,15 +260,19 @@ class LoadAudioWithTorchaudio(DataProcessingOperator, FrameSamplerByRateMixin):
|
|||||||
FrameSamplerByRateMixin.__init__(self, num_frames, time_division_factor, time_division_remainder, frame_rate, fix_frame_rate)
|
FrameSamplerByRateMixin.__init__(self, num_frames, time_division_factor, time_division_remainder, frame_rate, fix_frame_rate)
|
||||||
|
|
||||||
def __call__(self, data: str):
|
def __call__(self, data: str):
|
||||||
reader = self.get_reader(data)
|
try:
|
||||||
num_frames = self.get_num_frames(reader)
|
reader = self.get_reader(data)
|
||||||
duration = num_frames / self.frame_rate
|
num_frames = self.get_num_frames(reader)
|
||||||
waveform, sample_rate = torchaudio.load(data)
|
duration = num_frames / self.frame_rate
|
||||||
target_samples = int(duration * sample_rate)
|
waveform, sample_rate = torchaudio.load(data)
|
||||||
current_samples = waveform.shape[-1]
|
target_samples = int(duration * sample_rate)
|
||||||
if current_samples > target_samples:
|
current_samples = waveform.shape[-1]
|
||||||
waveform = waveform[..., :target_samples]
|
if current_samples > target_samples:
|
||||||
elif current_samples < target_samples:
|
waveform = waveform[..., :target_samples]
|
||||||
padding = target_samples - current_samples
|
elif current_samples < target_samples:
|
||||||
waveform = torch.nn.functional.pad(waveform, (0, padding))
|
padding = target_samples - current_samples
|
||||||
return waveform, sample_rate
|
waveform = torch.nn.functional.pad(waveform, (0, padding))
|
||||||
|
return waveform, sample_rate
|
||||||
|
except:
|
||||||
|
warnings.warn(f"Cannot load audio in {data}. The audio will be `None`.")
|
||||||
|
return None
|
||||||
|
|||||||
107
diffsynth/diffusion/ddim_scheduler.py
Normal file
107
diffsynth/diffusion/ddim_scheduler.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
import torch, math
|
||||||
|
|
||||||
|
|
||||||
|
class DDIMScheduler():
|
||||||
|
|
||||||
|
def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", prediction_type="epsilon", rescale_zero_terminal_snr=False):
|
||||||
|
self.num_train_timesteps = num_train_timesteps
|
||||||
|
if beta_schedule == "scaled_linear":
|
||||||
|
betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32))
|
||||||
|
elif beta_schedule == "linear":
|
||||||
|
betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"{beta_schedule} is not implemented")
|
||||||
|
self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
|
||||||
|
if rescale_zero_terminal_snr:
|
||||||
|
self.alphas_cumprod = self.rescale_zero_terminal_snr(self.alphas_cumprod)
|
||||||
|
self.alphas_cumprod = self.alphas_cumprod.tolist()
|
||||||
|
self.set_timesteps(10)
|
||||||
|
self.prediction_type = prediction_type
|
||||||
|
self.training = False
|
||||||
|
|
||||||
|
|
||||||
|
def rescale_zero_terminal_snr(self, alphas_cumprod):
|
||||||
|
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
||||||
|
|
||||||
|
# Store old values.
|
||||||
|
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
||||||
|
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
||||||
|
|
||||||
|
# Shift so the last timestep is zero.
|
||||||
|
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
||||||
|
|
||||||
|
# Scale so the first timestep is back to the old value.
|
||||||
|
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
||||||
|
|
||||||
|
# Convert alphas_bar_sqrt to betas
|
||||||
|
alphas_bar = alphas_bar_sqrt.square() # Revert sqrt
|
||||||
|
|
||||||
|
return alphas_bar
|
||||||
|
|
||||||
|
|
||||||
|
def set_timesteps(self, num_inference_steps, denoising_strength=1.0, training=False, **kwargs):
|
||||||
|
# The timesteps are aligned to 999...0, which is different from other implementations,
|
||||||
|
# but I think this implementation is more reasonable in theory.
|
||||||
|
max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0)
|
||||||
|
num_inference_steps = min(num_inference_steps, max_timestep + 1)
|
||||||
|
if num_inference_steps == 1:
|
||||||
|
self.timesteps = torch.Tensor([max_timestep])
|
||||||
|
else:
|
||||||
|
step_length = max_timestep / (num_inference_steps - 1)
|
||||||
|
self.timesteps = torch.Tensor([round(max_timestep - i*step_length) for i in range(num_inference_steps)])
|
||||||
|
self.training = training
|
||||||
|
|
||||||
|
|
||||||
|
def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev):
|
||||||
|
if self.prediction_type == "epsilon":
|
||||||
|
weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t)
|
||||||
|
weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t)
|
||||||
|
prev_sample = sample * weight_x + model_output * weight_e
|
||||||
|
elif self.prediction_type == "v_prediction":
|
||||||
|
weight_e = -math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t)) + math.sqrt(alpha_prod_t * (1 - alpha_prod_t_prev))
|
||||||
|
weight_x = math.sqrt(alpha_prod_t * alpha_prod_t_prev) + math.sqrt((1 - alpha_prod_t) * (1 - alpha_prod_t_prev))
|
||||||
|
prev_sample = sample * weight_x + model_output * weight_e
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"{self.prediction_type} is not implemented")
|
||||||
|
return prev_sample
|
||||||
|
|
||||||
|
|
||||||
|
def step(self, model_output, timestep, sample, to_final=False):
|
||||||
|
alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
|
||||||
|
if isinstance(timestep, torch.Tensor):
|
||||||
|
timestep = timestep.cpu()
|
||||||
|
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
||||||
|
if to_final or timestep_id + 1 >= len(self.timesteps):
|
||||||
|
alpha_prod_t_prev = 1.0
|
||||||
|
else:
|
||||||
|
timestep_prev = int(self.timesteps[timestep_id + 1])
|
||||||
|
alpha_prod_t_prev = self.alphas_cumprod[timestep_prev]
|
||||||
|
|
||||||
|
return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev)
|
||||||
|
|
||||||
|
|
||||||
|
def return_to_timestep(self, timestep, sample, sample_stablized):
|
||||||
|
alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
|
||||||
|
noise_pred = (sample - math.sqrt(alpha_prod_t) * sample_stablized) / math.sqrt(1 - alpha_prod_t)
|
||||||
|
return noise_pred
|
||||||
|
|
||||||
|
|
||||||
|
def add_noise(self, original_samples, noise, timestep):
|
||||||
|
sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
|
||||||
|
sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
|
||||||
|
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||||
|
return noisy_samples
|
||||||
|
|
||||||
|
|
||||||
|
def training_target(self, sample, noise, timestep):
|
||||||
|
if self.prediction_type == "epsilon":
|
||||||
|
return noise
|
||||||
|
else:
|
||||||
|
sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
|
||||||
|
sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
|
||||||
|
target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
|
||||||
|
return target
|
||||||
|
|
||||||
|
|
||||||
|
def training_weight(self, timestep):
|
||||||
|
return 1.0
|
||||||
@@ -159,6 +159,18 @@ class FlowMatchScheduler():
|
|||||||
timesteps[timestep_id] = timestep
|
timesteps[timestep_id] = timestep
|
||||||
return sigmas, timesteps
|
return sigmas, timesteps
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_timesteps_joyai_image(num_inference_steps=100, denoising_strength=1.0, shift=None):
|
||||||
|
sigma_min = 0.0
|
||||||
|
sigma_max = 1.0
|
||||||
|
shift = 4.0 if shift is None else shift
|
||||||
|
num_train_timesteps = 1000
|
||||||
|
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||||
|
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
|
||||||
|
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
||||||
|
timesteps = sigmas * num_train_timesteps
|
||||||
|
return sigmas, timesteps
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def set_timesteps_ltx2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None, terminal=0.1, special_case=None):
|
def set_timesteps_ltx2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None, terminal=0.1, special_case=None):
|
||||||
num_train_timesteps = 1000
|
num_train_timesteps = 1000
|
||||||
|
|||||||
@@ -33,15 +33,15 @@ def launch_training_task(
|
|||||||
for epoch_id in range(num_epochs):
|
for epoch_id in range(num_epochs):
|
||||||
for data in tqdm(dataloader):
|
for data in tqdm(dataloader):
|
||||||
with accelerator.accumulate(model):
|
with accelerator.accumulate(model):
|
||||||
optimizer.zero_grad()
|
|
||||||
if dataset.load_from_cache:
|
if dataset.load_from_cache:
|
||||||
loss = model({}, inputs=data)
|
loss = model({}, inputs=data)
|
||||||
else:
|
else:
|
||||||
loss = model(data)
|
loss = model(data)
|
||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
|
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
|
||||||
if save_steps is None:
|
if save_steps is None:
|
||||||
model_logger.on_epoch_end(accelerator, model, epoch_id)
|
model_logger.on_epoch_end(accelerator, model, epoch_id)
|
||||||
model_logger.on_training_end(accelerator, model, save_steps)
|
model_logger.on_training_end(accelerator, model, save_steps)
|
||||||
|
|||||||
636
diffsynth/models/joyai_image_dit.py
Normal file
636
diffsynth/models/joyai_image_dit.py
Normal file
@@ -0,0 +1,636 @@
|
|||||||
|
import math
|
||||||
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
from ..core.attention import attention_forward
|
||||||
|
from ..core.gradient import gradient_checkpoint_forward
|
||||||
|
|
||||||
|
|
||||||
|
def get_timestep_embedding(
|
||||||
|
timesteps: torch.Tensor,
|
||||||
|
embedding_dim: int,
|
||||||
|
flip_sin_to_cos: bool = False,
|
||||||
|
downscale_freq_shift: float = 1,
|
||||||
|
scale: float = 1,
|
||||||
|
max_period: int = 10000,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
||||||
|
half_dim = embedding_dim // 2
|
||||||
|
exponent = -math.log(max_period) * torch.arange(
|
||||||
|
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
||||||
|
)
|
||||||
|
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||||
|
emb = torch.exp(exponent)
|
||||||
|
emb = timesteps[:, None].float() * emb[None, :]
|
||||||
|
emb = scale * emb
|
||||||
|
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
||||||
|
if flip_sin_to_cos:
|
||||||
|
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
||||||
|
if embedding_dim % 2 == 1:
|
||||||
|
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
class Timesteps(nn.Module):
    """Module wrapper around :func:`get_timestep_embedding` with fixed settings."""

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        # Stateless: just forwards the stored configuration.
        return get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEmbedding(nn.Module):
    """Two-layer MLP projecting a sinusoidal timestep embedding to the model width.

    NOTE(review): the activation is hard-wired to ``nn.SiLU`` regardless of ``act_fn``,
    and ``post_act_fn`` only recognizes ``"silu"`` — confirm callers never pass other
    values before generalizing.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
    ):
        super().__init__()
        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
        # Optional extra conditioning signal projected into the input space.
        if cond_proj_dim is None:
            self.cond_proj = None
        else:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        self.act = nn.SiLU()
        out_features = out_dim if out_dim is not None else time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, out_features, sample_proj_bias)
        self.post_act = nn.SiLU() if post_act_fn == "silu" else None

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)
        if self.act is not None:
            sample = self.act(sample)
        sample = self.linear_2(sample)
        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample
|
||||||
|
|
||||||
|
|
||||||
|
class PixArtAlphaTextProjection(nn.Module):
    """Caption projection MLP: Linear -> activation -> Linear."""

    def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
        super().__init__()
        out_features = hidden_size if out_features is None else out_features
        self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
        if act_fn == "silu":
            self.act_1 = nn.SiLU()
        else:
            # "gelu_tanh" and any unrecognized value both use tanh-approximated GELU.
            self.act_1 = nn.GELU(approximate="tanh")
        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)

    def forward(self, caption):
        return self.linear_2(self.act_1(self.linear_1(caption)))
|
||||||
|
|
||||||
|
|
||||||
|
class GELU(nn.Module):
    """Linear projection followed by GELU (exact or tanh-approximated)."""

    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
        self.approximate = approximate

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return F.gelu(self.proj(hidden_states), approximate=self.approximate)
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
    """Transformer MLP: GELU projection -> dropout -> output projection (optional final dropout)."""

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
        inner_dim=None,
        bias: bool = True,
    ):
        super().__init__()
        inner_dim = int(dim * mult) if inner_dim is None else inner_dim
        dim_out = dim if dim_out is None else dim_out

        # Only GELU variants are implemented; anything else falls back to plain GELU.
        if activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
        else:
            act_fn = GELU(dim, inner_dim, bias=bias)

        layers = [act_fn, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out, bias=bias)]
        if final_dropout:
            layers.append(nn.Dropout(dropout))
        self.net = nn.ModuleList(layers)

    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # Extra positional/keyword args are accepted for interface compatibility and ignored.
        for layer in self.net:
            hidden_states = layer(hidden_states)
        return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
def _to_tuple(x, dim=2):
|
||||||
|
if isinstance(x, int):
|
||||||
|
return (x,) * dim
|
||||||
|
elif len(x) == dim:
|
||||||
|
return x
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Expected length {dim} or int, but got {x}")
|
||||||
|
|
||||||
|
|
||||||
|
def get_meshgrid_nd(start, *args, dim=2):
    """Build a ``(dim, *sizes)`` meshgrid of linearly spaced coordinates.

    Call patterns: ``(num)``, ``(start, stop)``, or ``(start, stop, num)``;
    each argument may be an int or a ``dim``-length sequence.
    """
    if len(args) == 0:
        # Only sizes were given: grid covers [0, size) on each axis.
        num = _to_tuple(start, dim=dim)
        start = (0,) * dim
        stop = num
    elif len(args) == 1:
        start = _to_tuple(start, dim=dim)
        stop = _to_tuple(args[0], dim=dim)
        num = [stop[i] - start[i] for i in range(dim)]
    elif len(args) == 2:
        start = _to_tuple(start, dim=dim)
        stop = _to_tuple(args[0], dim=dim)
        num = _to_tuple(args[1], dim=dim)
    else:
        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

    axes = []
    for axis in range(dim):
        lo, hi, n = start[axis], stop[axis], num[axis]
        # n+1 points, endpoint dropped -> n evenly spaced samples in [lo, hi).
        axes.append(torch.linspace(lo, hi, n + 1, dtype=torch.float32)[:n])
    return torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
def reshape_for_broadcast(freqs_cis, x, head_first=False):
    """Reshape RoPE frequency tensor(s) so they broadcast against ``x``.

    ``freqs_cis`` is either a (cos, sin) tuple or a single tensor; ``head_first``
    selects whether the sequence axis of ``x`` is dim 1 or dim -2.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim

    def _broadcast_shape(ref):
        if head_first:
            assert ref.shape == (x.shape[-2], x.shape[-1])
            keep = {ndim - 2, ndim - 1}
        else:
            assert ref.shape == (x.shape[1], x.shape[-1])
            keep = {1, ndim - 1}
        # Keep the matched axes, collapse every other axis to 1.
        return [d if i in keep else 1 for i, d in enumerate(x.shape)]

    if isinstance(freqs_cis, tuple):
        shape = _broadcast_shape(freqs_cis[0])
        return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
    shape = _broadcast_shape(freqs_cis)
    return freqs_cis.view(*shape)
|
||||||
|
|
||||||
|
|
||||||
|
def rotate_half(x):
    """Pairwise rotation for RoPE: (a, b) -> (-b, a) on consecutive channel pairs."""
    real, imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
    return torch.stack([-imag, real], dim=-1).flatten(3)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_emb(xq, xk, freqs_cis, head_first=False):
    """Apply rotary position embedding to query and key tensors.

    Computation runs in fp32 and the results are cast back to the input dtypes.
    """
    cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)
    cos = cos.to(xq.device)
    sin = sin.to(xq.device)
    rotated_q = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
    rotated_k = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
    return rotated_q, rotated_k
|
||||||
|
|
||||||
|
|
||||||
|
def get_1d_rotary_pos_embed(dim, pos, theta=10000.0, use_real=False, theta_rescale_factor=1.0, interpolation_factor=1.0):
    """1D RoPE table for positions ``pos``: (cos, sin) pair if ``use_real`` else complex tensor."""
    if isinstance(pos, int):
        pos = torch.arange(pos).float()
    if theta_rescale_factor != 1.0:
        # NTK-style rescaling of the base period.
        theta *= theta_rescale_factor ** (dim / (dim - 2))
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    angles = torch.outer(pos * interpolation_factor, inv_freq)
    if use_real:
        # Duplicate each frequency so the table matches interleaved channel pairs.
        return angles.cos().repeat_interleave(2, dim=1), angles.sin().repeat_interleave(2, dim=1)
    return torch.polar(torch.ones_like(angles), angles)
|
||||||
|
|
||||||
|
|
||||||
|
def get_nd_rotary_pos_embed(rope_dim_list, start, *args, theta=10000.0, use_real=False,
                            txt_rope_size=None, theta_rescale_factor=1.0, interpolation_factor=1.0):
    """N-D RoPE tables for a visual grid, plus optional text positions appended after it.

    Returns ``(vis_emb, txt_emb)``; ``txt_emb`` is None unless ``txt_rope_size`` is given.
    """
    n_axes = len(rope_dim_list)
    grid = get_meshgrid_nd(start, *args, dim=n_axes)

    # Normalize per-axis factors: scalar or single-item list broadcasts to every axis.
    if isinstance(theta_rescale_factor, (int, float)):
        theta_rescale_factor = [theta_rescale_factor] * n_axes
    elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
        theta_rescale_factor = [theta_rescale_factor[0]] * n_axes
    if isinstance(interpolation_factor, (int, float)):
        interpolation_factor = [interpolation_factor] * n_axes
    elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
        interpolation_factor = [interpolation_factor[0]] * n_axes

    def _combine(positions):
        # One 1D table per axis, concatenated along the channel dimension.
        per_axis = [
            get_1d_rotary_pos_embed(
                rope_dim_list[i], positions[i], theta,
                use_real=use_real, theta_rescale_factor=theta_rescale_factor[i],
                interpolation_factor=interpolation_factor[i],
            )
            for i in range(n_axes)
        ]
        if use_real:
            return (
                torch.cat([e[0] for e in per_axis], dim=1),
                torch.cat([e[1] for e in per_axis], dim=1),
            )
        return torch.cat(per_axis, dim=1)

    vis_emb = _combine([grid[i].reshape(-1) for i in range(n_axes)])

    txt_emb = None
    if txt_rope_size is not None:
        # Text tokens take 1D positions starting right after the largest visual
        # index, replicated across every rope axis.
        grid_txt = torch.arange(txt_rope_size) + grid.view(-1).max().item() + 1
        txt_emb = _combine([grid_txt] * n_axes)
    return vis_emb, txt_emb
|
||||||
|
|
||||||
|
|
||||||
|
class ModulateWan(nn.Module):
    """Wan-style learned modulation: add a per-slot table to the conditioning vector
    and split it into ``factor`` chunks (e.g. shift/scale/gate triples)."""

    def __init__(self, hidden_size: int, factor: int, dtype=None, device=None):
        super().__init__()
        self.factor = factor
        # NOTE(review): dividing a zeros tensor by sqrt(hidden_size) is a no-op;
        # kept for checkpoint/initialization parity with the original.
        table = torch.zeros(1, factor, hidden_size, dtype=dtype, device=device) / hidden_size**0.5
        self.modulate_table = nn.Parameter(table, requires_grad=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if len(x.shape) != 3:
            # Promote (B, H) conditioning to (B, 1, H) so the table broadcasts.
            x = x.unsqueeze(1)
        chunks = (self.modulate_table + x).chunk(self.factor, dim=1)
        return [chunk.squeeze(1) for chunk in chunks]
|
||||||
|
|
||||||
|
|
||||||
|
def modulate(x, shift=None, scale=None):
    """AdaLN-style modulation: ``x * (1 + scale) + shift`` with either part optional.

    ``shift``/``scale`` are per-sample (B, H) and broadcast over the sequence axis.
    """
    if scale is not None:
        x = x * (1 + scale.unsqueeze(1))
    if shift is not None:
        x = x + shift.unsqueeze(1)
    return x
|
||||||
|
|
||||||
|
|
||||||
|
def apply_gate(x, gate=None, tanh=False):
    """Scale ``x`` by a per-sample gate (broadcast over the sequence axis).

    With ``tanh=True`` the gate is squashed through tanh first; ``gate=None`` is identity.
    """
    if gate is None:
        return x
    gate = gate.unsqueeze(1)
    return x * gate.tanh() if tanh else x * gate
|
||||||
|
|
||||||
|
|
||||||
|
def load_modulation(modulate_type: str, hidden_size: int, factor: int, dtype=None, device=None):
    """Factory for modulation modules; currently only the 'wanx' variant exists."""
    if modulate_type == 'wanx':
        return ModulateWan(hidden_size, factor, dtype=dtype, device=device)
    raise ValueError(f"Unknown modulation type: {modulate_type}. Only 'wanx' is supported.")
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(nn.Module):
    """Root-mean-square layer norm (no mean subtraction) with optional learned scale."""

    def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6, device=None, dtype=None):
        super().__init__()
        self.eps = eps
        if elementwise_affine:
            # Registered only when affine; forward checks for its presence via hasattr.
            self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalize in fp32 for numerical stability, then cast back to the input dtype.
        out = self._norm(x.float()).type_as(x)
        if hasattr(self, "weight"):
            out = out * self.weight
        return out
|
||||||
|
|
||||||
|
|
||||||
|
class MMDoubleStreamBlock(nn.Module):
    """
    A multimodal DiT block with separate modulation for
    text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206
    (Flux.1): https://github.com/black-forest-labs/flux

    Both streams run their own AdaLN modulation, QK-normed attention projections,
    and MLP; attention itself is computed jointly over the concatenated sequence.
    """

    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        mlp_width_ratio: float,
        mlp_act_type: str = "gelu_tanh",
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        dit_modulation_type: Optional[str] = "wanx",
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.dit_modulation_type = dit_modulation_type
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

        # Image stream: modulation, pre-attn norm, fused QKV, QK-norm, out proj, MLP.
        self.img_mod = load_modulation(
            modulate_type=self.dit_modulation_type,
            hidden_size=hidden_size, factor=6, **factory_kwargs,
        )
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True, **factory_kwargs)
        self.img_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
        self.img_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
        self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=True, **factory_kwargs)
        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.img_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate")

        # Text stream: identical structure, independent parameters.
        self.txt_mod = load_modulation(
            modulate_type=self.dit_modulation_type,
            hidden_size=hidden_size, factor=6, **factory_kwargs,
        )
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True, **factory_kwargs)
        self.txt_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
        self.txt_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
        self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=True, **factory_kwargs)
        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.txt_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate")

    def forward(
        self,
        img: torch.Tensor,
        txt: torch.Tensor,
        vec: torch.Tensor,
        vis_freqs_cis: tuple = None,
        txt_freqs_cis: tuple = None,
        # BUGFIX: was a mutable default `= {}`; use None to avoid sharing one dict
        # across calls. The value is currently unused but kept for interface parity.
        attn_kwargs: Optional[dict] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one joint image/text transformer block; returns updated (img, txt)."""
        (
            img_mod1_shift, img_mod1_scale, img_mod1_gate,
            img_mod2_shift, img_mod2_scale, img_mod2_gate,
        ) = self.img_mod(vec)
        (
            txt_mod1_shift, txt_mod1_scale, txt_mod1_gate,
            txt_mod2_shift, txt_mod2_scale, txt_mod2_gate,
        ) = self.txt_mod(vec)

        # Image stream: modulated norm -> QKV -> per-head QK norm -> RoPE.
        img_modulated = self.img_norm1(img)
        img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale)
        img_qkv = self.img_attn_qkv(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
        img_q = self.img_attn_q_norm(img_q).to(img_v)
        img_k = self.img_attn_k_norm(img_k).to(img_v)

        if vis_freqs_cis is not None:
            img_q, img_k = apply_rotary_emb(img_q, img_k, vis_freqs_cis, head_first=False)

        # Text stream: same pipeline, no RoPE supported here.
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale)
        txt_qkv = self.txt_attn_qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
        txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
        txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)

        if txt_freqs_cis is not None:
            raise NotImplementedError("RoPE text is not supported for inference")

        # Joint attention over the concatenated image+text sequence.
        q = torch.cat((img_q, txt_q), dim=1)
        k = torch.cat((img_k, txt_k), dim=1)
        v = torch.cat((img_v, txt_v), dim=1)

        # Use DiffSynth unified attention
        attn_out = attention_forward(
            q, k, v,
            q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
        )

        attn_out = attn_out.flatten(2, 3)
        img_attn, txt_attn = attn_out[:, : img.shape[1]], attn_out[:, img.shape[1]:]

        # Gated residual updates: attention then MLP, per stream.
        img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
        img = img + apply_gate(
            self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)),
            gate=img_mod2_gate,
        )

        txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
        txt = txt + apply_gate(
            self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)),
            gate=txt_mod2_gate,
        )

        return img, txt
|
||||||
|
|
||||||
|
|
||||||
|
class WanTimeTextImageEmbedding(nn.Module):
    """Conditioning embedder: sinusoidal time -> MLP, plus a text projection.

    Returns (time embedding, per-block modulation projection, projected text states).
    NOTE(review): ``image_embed_dim`` and ``pos_embed_seq_len`` are accepted but unused
    here — confirm against checkpoints/callers before removing them.
    """

    def __init__(
        self,
        dim: int,
        time_freq_dim: int,
        time_proj_dim: int,
        text_embed_dim: int,
        image_embed_dim: Optional[int] = None,
        pos_embed_seq_len: Optional[int] = None,
    ):
        super().__init__()
        self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
        self.act_fn = nn.SiLU()
        self.time_proj = nn.Linear(dim, time_proj_dim)
        self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")

    def forward(self, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor):
        timestep = self.timesteps_proj(timestep)
        # Match the embedder parameter dtype (int8 excluded: quantized weights).
        embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
        if timestep.dtype != embedder_dtype and embedder_dtype != torch.int8:
            timestep = timestep.to(embedder_dtype)
        temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
        timestep_proj = self.time_proj(self.act_fn(temb))
        encoder_hidden_states = self.text_embedder(encoder_hidden_states)
        return temb, timestep_proj, encoder_hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class JoyAIImageDiT(nn.Module):
    """JoyAI image DiT: dual-stream (image/text) MM-DiT with RoPE and Wan-style modulation."""

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        patch_size: list = [1, 2, 2],
        in_channels: int = 16,
        out_channels: int = 16,
        hidden_size: int = 4096,
        heads_num: int = 32,
        text_states_dim: int = 4096,
        mlp_width_ratio: float = 4.0,
        mm_double_blocks_depth: int = 40,
        rope_dim_list: List[int] = [16, 56, 56],
        rope_type: str = 'rope',
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        dit_modulation_type: str = "wanx",
        theta: int = 10000,
    ):
        super().__init__()
        self.out_channels = out_channels or in_channels
        self.patch_size = patch_size
        self.hidden_size = hidden_size
        self.heads_num = heads_num
        self.rope_dim_list = rope_dim_list
        self.dit_modulation_type = dit_modulation_type
        self.mm_double_blocks_depth = mm_double_blocks_depth
        self.rope_type = rope_type
        self.theta = theta

        factory_kwargs = {"device": device, "dtype": dtype}

        if hidden_size % heads_num != 0:
            raise ValueError(f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}")

        # Patchify: non-overlapping 3D conv turns latent voxels into tokens.
        self.img_in = nn.Conv3d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

        self.condition_embedder = WanTimeTextImageEmbedding(
            dim=hidden_size,
            time_freq_dim=256,
            time_proj_dim=hidden_size * 6,
            text_embed_dim=text_states_dim,
        )

        self.double_blocks = nn.ModuleList([
            MMDoubleStreamBlock(
                self.hidden_size, self.heads_num,
                mlp_width_ratio=mlp_width_ratio,
                dit_modulation_type=self.dit_modulation_type,
                **factory_kwargs,
            )
            for _ in range(mm_double_blocks_depth)
        ])

        self.norm_out = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(hidden_size, self.out_channels * math.prod(patch_size), **factory_kwargs)

    def get_rotary_pos_embed(self, vis_rope_size, txt_rope_size=None):
        """Build RoPE tables for a (t, h, w) token grid; text table only in 'mrope' mode."""
        target_ndim = 3
        if len(vis_rope_size) != target_ndim:
            # assumes vis_rope_size is a list when shorter than 3 dims — TODO confirm
            vis_rope_size = [1] * (target_ndim - len(vis_rope_size)) + vis_rope_size
        head_dim = self.hidden_size // self.heads_num
        rope_dim_list = self.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert sum(rope_dim_list) == head_dim
        vis_freqs, txt_freqs = get_nd_rotary_pos_embed(
            rope_dim_list, vis_rope_size,
            txt_rope_size=txt_rope_size if self.rope_type == 'mrope' else None,
            theta=self.theta, use_real=True, theta_rescale_factor=1,
        )
        return vis_freqs, txt_freqs

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        encoder_hidden_states_mask: torch.Tensor = None,
        return_dict: bool = True,
        use_gradient_checkpointing: bool = False,
        use_gradient_checkpointing_offload: bool = False,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Denoise latents conditioned on timestep and text states.

        ``hidden_states`` is (B, C, T, H, W) or (B, N, C, T, H, W) for multi-item
        editing input. NOTE(review): ``encoder_hidden_states_mask`` and ``return_dict``
        are currently unused beyond construction — kept for interface compatibility.
        """
        is_multi_item = (len(hidden_states.shape) == 6)
        num_items = 0
        if is_multi_item:
            num_items = hidden_states.shape[1]
            if num_items > 1:
                assert self.patch_size[0] == 1, "For multi-item input, patch_size[0] must be 1"
                # Move the target item first, then fold items into the temporal axis.
                hidden_states = torch.cat([hidden_states[:, -1:], hidden_states[:, :-1]], dim=1)
            hidden_states = rearrange(hidden_states, 'b n c t h w -> b c (n t) h w')

        batch_size, _, ot, oh, ow = hidden_states.shape
        tt, th, tw = ot // self.patch_size[0], oh // self.patch_size[1], ow // self.patch_size[2]

        if encoder_hidden_states_mask is None:
            encoder_hidden_states_mask = torch.ones(
                (encoder_hidden_states.shape[0], encoder_hidden_states.shape[1]),
                dtype=torch.bool,
            ).to(encoder_hidden_states.device)

        img = self.img_in(hidden_states).flatten(2).transpose(1, 2)
        temb, vec, txt = self.condition_embedder(timestep, encoder_hidden_states)
        if vec.shape[-1] > self.hidden_size:
            # time_proj emits 6*hidden_size; reshape into the 6 modulation slots.
            vec = vec.unflatten(1, (6, -1))

        txt_seq_len = txt.shape[1]
        img_seq_len = img.shape[1]

        vis_freqs_cis, txt_freqs_cis = self.get_rotary_pos_embed(
            vis_rope_size=(tt, th, tw),
            txt_rope_size=txt_seq_len if self.rope_type == 'mrope' else None,
        )

        for block in self.double_blocks:
            img, txt = gradient_checkpoint_forward(
                block,
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
                img=img, txt=txt, vec=vec,
                vis_freqs_cis=vis_freqs_cis, txt_freqs_cis=txt_freqs_cis,
                attn_kwargs={},
            )

        # BUGFIX(perf): removed a dead `torch.cat((img, txt), 1)` followed by slicing
        # `img` straight back out — a pure no-op that allocated a large tensor.
        img = self.proj_out(self.norm_out(img))
        img = self.unpatchify(img, tt, th, tw)

        if is_multi_item:
            img = rearrange(img, 'b c (n t) h w -> b n c t h w', n=num_items)
            if num_items > 1:
                # Undo the item rotation applied on input.
                img = torch.cat([img[:, 1:], img[:, :1]], dim=1)

        return img

    def unpatchify(self, x, t, h, w):
        """Invert patchification: (B, T*H*W, C*pt*ph*pw) -> (B, C, T*pt, H*ph, W*pw)."""
        c = self.out_channels
        pt, ph, pw = self.patch_size
        assert t * h * w == x.shape[1]
        x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
        x = torch.einsum("nthwopqc->nctohpwq", x)
        return x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
|
||||||
82
diffsynth/models/joyai_image_text_encoder.py
Normal file
82
diffsynth/models/joyai_image_text_encoder.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
import torch
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class JoyAIImageTextEncoder(torch.nn.Module):
    """Qwen3-VL-based encoder returning the hidden states that feed the final norm layer."""

    def __init__(self):
        super().__init__()
        from transformers import Qwen3VLConfig, Qwen3VLForConditionalGeneration

        config = Qwen3VLConfig(
            text_config={
                "attention_bias": False,
                "attention_dropout": 0.0,
                "bos_token_id": 151643,
                "eos_token_id": 151645,
                "head_dim": 128,
                "hidden_act": "silu",
                "hidden_size": 4096,
                "initializer_range": 0.02,
                "intermediate_size": 12288,
                "max_position_embeddings": 262144,
                "model_type": "qwen3_vl_text",
                "num_attention_heads": 32,
                "num_hidden_layers": 36,
                "num_key_value_heads": 8,
                "rms_norm_eps": 1e-6,
                "rope_scaling": {
                    "mrope_interleaved": True,
                    "mrope_section": [24, 20, 20],
                    "rope_type": "default",
                },
                "rope_theta": 5000000,
                "use_cache": True,
                "vocab_size": 151936,
            },
            vision_config={
                "deepstack_visual_indexes": [8, 16, 24],
                "depth": 27,
                "hidden_act": "gelu_pytorch_tanh",
                "hidden_size": 1152,
                "in_channels": 3,
                "initializer_range": 0.02,
                "intermediate_size": 4304,
                "model_type": "qwen3_vl",
                "num_heads": 16,
                "num_position_embeddings": 2304,
                "out_hidden_size": 4096,
                "patch_size": 16,
                "spatial_merge_size": 2,
                "temporal_patch_size": 2,
            },
            image_token_id=151655,
            video_token_id=151656,
            vision_start_token_id=151652,
            vision_end_token_id=151653,
            tie_word_embeddings=False,
        )

        self.model = Qwen3VLForConditionalGeneration(config)
        self.config = config

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """Run the VLM and return the hidden states entering the final RMS norm."""
        captured = [None]

        def _capture_pre_norm(module, inputs, output):
            # Forward hooks receive (module, inputs, output); inputs[0] is the
            # pre-norm hidden-state tensor we want.
            captured[0] = inputs[0]

        # BUGFIX: the hook used to be registered on every forward call and never
        # removed, accumulating one extra hook per call. Register once per call
        # and always detach it, even if the model raises.
        handle = self.model.model.language_model.norm.register_forward_hook(_capture_pre_norm)
        try:
            _ = self.model(
                input_ids=input_ids,
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
                attention_mask=attention_mask,
                output_hidden_states=True,
                **kwargs,
            )
        finally:
            handle.remove()
        return captured[0]
|
||||||
78
diffsynth/models/stable_diffusion_text_encoder.py
Normal file
78
diffsynth/models/stable_diffusion_text_encoder.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class SDTextEncoder(torch.nn.Module):
    """Stable Diffusion CLIP text encoder wrapper built from an explicit config."""

    def __init__(
        self,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        max_position_embeddings=77,
        vocab_size=49408,
        layer_norm_eps=1e-05,
        hidden_act="quick_gelu",
        initializer_factor=1.0,
        initializer_range=0.02,
        bos_token_id=0,
        eos_token_id=2,
        pad_token_id=1,
        projection_dim=768,
    ):
        super().__init__()
        from transformers import CLIPConfig, CLIPTextModel

        # Settings shared by the text and (unused) vision halves of the CLIPConfig.
        shared = {
            "hidden_size": hidden_size,
            "intermediate_size": intermediate_size,
            "num_hidden_layers": num_hidden_layers,
            "num_attention_heads": num_attention_heads,
            "max_position_embeddings": max_position_embeddings,
            "layer_norm_eps": layer_norm_eps,
            "hidden_act": hidden_act,
            "initializer_factor": initializer_factor,
            "initializer_range": initializer_range,
            "projection_dim": projection_dim,
        }
        text_config = dict(
            shared,
            vocab_size=vocab_size,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            dropout=0.0,
        )
        config = CLIPConfig(
            text_config=text_config,
            vision_config=dict(shared),
            projection_dim=projection_dim,
        )
        # Only the text tower is instantiated; the vision config exists for completeness.
        self.model = CLIPTextModel(config.text_config)
        self.config = config

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        output_hidden_states=True,
        **kwargs,
    ):
        """Encode token ids; returns last hidden state, plus all hidden states when requested."""
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            **kwargs,
        )
        if output_hidden_states:
            return outputs.last_hidden_state, outputs.hidden_states
        return outputs.last_hidden_state
|
||||||
912
diffsynth/models/stable_diffusion_unet.py
Normal file
912
diffsynth/models/stable_diffusion_unet.py
Normal file
@@ -0,0 +1,912 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import math
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
# ===== Time Embedding =====
|
||||||
|
|
||||||
|
class Timesteps(nn.Module):
    """Sinusoidal timestep embedding (transformer-style positional encoding).

    Produces a (batch, num_channels) tensor whose first half is cos and second
    half is sin when ``flip_sin_to_cos`` is True, and the reverse otherwise.
    """

    def __init__(self, num_channels, flip_sin_to_cos=True, freq_shift=0):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.freq_shift = freq_shift

    def forward(self, timesteps):
        half = self.num_channels // 2
        # Geometric frequency schedule from 1 down to 1/10000.
        exponent = torch.arange(half, dtype=torch.float32, device=timesteps.device) * -math.log(10000)
        exponent = exponent / half + self.freq_shift
        freqs = torch.exp(exponent)
        angles = timesteps[:, None].float() * freqs[None, :]
        sin_part = torch.sin(angles)
        cos_part = torch.cos(angles)
        parts = [cos_part, sin_part] if self.flip_sin_to_cos else [sin_part, cos_part]
        return torch.cat(parts, dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEmbedding(nn.Module):
    """Two-layer MLP that projects a sinusoidal timestep embedding to the
    model's time-conditioning width (linear -> activation -> linear)."""

    def __init__(self, in_channels, time_embed_dim, act_fn="silu", out_dim=None):
        super().__init__()
        self.linear_1 = nn.Linear(in_channels, time_embed_dim)
        # Any act_fn other than "silu" falls back to GELU.
        self.act = nn.SiLU() if act_fn == "silu" else nn.GELU()
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim if out_dim is None else out_dim)

    def forward(self, sample):
        return self.linear_2(self.act(self.linear_1(sample)))
|
||||||
|
|
||||||
|
|
||||||
|
# ===== ResNet Blocks =====
|
||||||
|
|
||||||
|
class ResnetBlock2D(nn.Module):
    """ResNet block with GroupNorm and time-embedding conditioning.

    Submodule names (norm1/conv1/time_emb_proj/norm2/conv2/conv_shortcut)
    match the diffusers checkpoint key layout.

    FIX: previously an unrecognized ``non_linearity`` left ``self.nonlinearity``
    unset (and an unrecognized ``time_embedding_norm`` left ``time_emb_proj``
    unset), failing only later inside forward() with an AttributeError.
    Both now raise ValueError at construction time.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        output_scale_factor=1.0,
        use_in_shortcut=None,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.time_embedding_norm = time_embedding_norm
        self.output_scale_factor = output_scale_factor

        # Resolve the effective output width once instead of repeating
        # `out_channels or in_channels` throughout.
        out_channels = in_channels if out_channels is None else out_channels
        if groups_out is None:
            groups_out = groups

        self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                self.time_emb_proj = nn.Linear(temb_channels, out_channels)
            elif self.time_embedding_norm == "scale_shift":
                # scale_shift needs two chunks (scale, shift) per channel.
                self.time_emb_proj = nn.Linear(temb_channels, 2 * out_channels)
            else:
                raise ValueError(f"unknown time_embedding_norm: {time_embedding_norm!r}")

        self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if non_linearity in ("swish", "silu"):
            self.nonlinearity = nn.SiLU()
        elif non_linearity == "gelu":
            self.nonlinearity = nn.GELU()
        elif non_linearity == "relu":
            self.nonlinearity = nn.ReLU()
        else:
            raise ValueError(f"unknown non_linearity: {non_linearity!r}")

        self.use_conv_shortcut = conv_shortcut
        if conv_shortcut:
            self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        else:
            # Only project the residual when the channel count changes.
            self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) if in_channels != out_channels else None

    def forward(self, input_tensor, temb=None):
        """Apply the block.

        Args:
            input_tensor: (batch, in_channels, H, W) feature map.
            temb: optional (batch, temb_channels) time embedding.

        Returns:
            (batch, out_channels, H, W) feature map, residual-added and
            divided by ``output_scale_factor``.
        """
        hidden_states = self.norm1(input_tensor)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            temb = self.nonlinearity(temb)
            # Broadcast over the spatial dimensions.
            temb = self.time_emb_proj(temb).unsqueeze(-1).unsqueeze(-1)

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        return (input_tensor + hidden_states) / self.output_scale_factor
|
||||||
|
|
||||||
|
|
||||||
|
# ===== Transformer Blocks =====
|
||||||
|
|
||||||
|
class GEGLU(nn.Module):
    """Gated GELU activation: a single linear layer projects to 2*dim_out,
    and one half gates the other through GELU."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, hidden_states):
        value, gate = torch.chunk(self.proj(hidden_states), 2, dim=-1)
        return value * F.gelu(gate)
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
    """Transformer MLP: GEGLU expansion (4x) -> dropout -> linear projection.

    Stored as an nn.ModuleList named ``net`` so state_dict keys match the
    diffusers layout (net.0.proj, net.2).
    """

    def __init__(self, dim, dim_out=None, dropout=0.0):
        super().__init__()
        target_dim = dim if dim_out is None else dim_out
        self.net = nn.ModuleList([
            GEGLU(dim, dim * 4),
            nn.Dropout(dropout),
            nn.Linear(dim * 4, target_dim),
        ])

    def forward(self, hidden_states):
        x = hidden_states
        for layer in self.net:
            x = layer(x)
        return x
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
"""Attention block matching diffusers checkpoint key format.
|
||||||
|
Keys: to_q.weight, to_k.weight, to_v.weight, to_out.0.weight, to_out.0.bias
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
query_dim,
|
||||||
|
heads=8,
|
||||||
|
dim_head=64,
|
||||||
|
dropout=0.0,
|
||||||
|
bias=False,
|
||||||
|
upcast_attention=False,
|
||||||
|
cross_attention_dim=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = dim_head * heads
|
||||||
|
self.heads = heads
|
||||||
|
self.inner_dim = inner_dim
|
||||||
|
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
||||||
|
|
||||||
|
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
|
||||||
|
self.to_k = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias)
|
||||||
|
self.to_v = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias)
|
||||||
|
self.to_out = nn.ModuleList([
|
||||||
|
nn.Linear(inner_dim, query_dim, bias=True),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||||
|
# Query
|
||||||
|
query = self.to_q(hidden_states)
|
||||||
|
batch_size, seq_len, _ = query.shape
|
||||||
|
|
||||||
|
# Key/Value
|
||||||
|
if encoder_hidden_states is None:
|
||||||
|
encoder_hidden_states = hidden_states
|
||||||
|
key = self.to_k(encoder_hidden_states)
|
||||||
|
value = self.to_v(encoder_hidden_states)
|
||||||
|
|
||||||
|
# Reshape for multi-head attention
|
||||||
|
head_dim = self.inner_dim // self.heads
|
||||||
|
query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
|
||||||
|
key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
|
||||||
|
value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
# Scaled dot-product attention
|
||||||
|
hidden_states = F.scaled_dot_product_attention(
|
||||||
|
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reshape back
|
||||||
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.inner_dim)
|
||||||
|
hidden_states = hidden_states.to(query.dtype)
|
||||||
|
|
||||||
|
# Output projection
|
||||||
|
hidden_states = self.to_out[0](hidden_states)
|
||||||
|
hidden_states = self.to_out[1](hidden_states)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention, cross-attention, then a
    feed-forward network, each wrapped in a residual connection."""

    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        cross_attention_dim=None,
        upcast_attention=False,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn1 = Attention(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            bias=False,
            upcast_attention=upcast_attention,
        )
        self.norm2 = nn.LayerNorm(dim)
        self.attn2 = Attention(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            bias=False,
            upcast_attention=upcast_attention,
            cross_attention_dim=cross_attention_dim,
        )
        self.norm3 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, dropout=dropout)

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
        # Self-attention with residual.
        hidden_states = hidden_states + self.attn1(self.norm1(hidden_states))
        # Cross-attention against the text context, with residual.
        hidden_states = hidden_states + self.attn2(
            self.norm2(hidden_states), encoder_hidden_states=encoder_hidden_states
        )
        # Feed-forward with residual.
        hidden_states = hidden_states + self.ff(self.norm3(hidden_states))
        return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Transformer2DModel(nn.Module):
    """2D Transformer block wrapper matching diffusers checkpoint structure.

    Keys: norm.weight/bias, proj_in.weight/bias, transformer_blocks.X.*,
    proj_out.weight/bias.

    FIX: forward() previously flattened the projected feature map using the
    *input* channel count; that is only correct when
    ``num_attention_heads * attention_head_dim == in_channels``. It now uses
    the actual post-projection channel count, so non-square configurations
    reshape correctly too.
    """

    def __init__(
        self,
        num_attention_heads=16,
        attention_head_dim=64,
        in_channels=320,
        num_layers=1,
        dropout=0.0,
        norm_num_groups=32,
        cross_attention_dim=768,
        upcast_attention=False,
    ):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim
        self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6)
        self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, bias=True)

        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(
                dim=inner_dim,
                n_heads=num_attention_heads,
                d_head=attention_head_dim,
                dropout=dropout,
                cross_attention_dim=cross_attention_dim,
                upcast_attention=upcast_attention,
            )
            for _ in range(num_layers)
        ])

        self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, bias=True)

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """Run the transformer over a (B, C, H, W) feature map; returns the
        same shape with a residual connection around the whole block."""
        batch, _, height, width = hidden_states.shape
        residual = hidden_states

        # Normalize, project, and flatten spatial dims into a sequence.
        hidden_states = self.norm(hidden_states)
        hidden_states = self.proj_in(hidden_states)
        inner_dim = hidden_states.shape[1]  # post-projection channel count
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)

        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states)

        # Restore the 2D layout and project back to in_channels.
        hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
        hidden_states = self.proj_out(hidden_states)
        return hidden_states + residual
|
||||||
|
|
||||||
|
|
||||||
|
# ===== Down/Up Blocks =====
|
||||||
|
|
||||||
|
class CrossAttnDownBlock2D(nn.Module):
    """Down block: ``num_layers`` pairs of (ResnetBlock2D, Transformer2DModel)
    followed by an optional strided-conv downsampler.

    forward() returns the final hidden states plus a tuple of per-stage skip
    tensors consumed by the matching up block.

    NOTE(review): ``attention_head_dim`` is passed through as the *number of
    heads* and ``out_channels // attention_head_dim`` as the per-head width,
    mirroring how the SD1.x diffusers config is wired — confirm against the
    checkpoint config.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        temb_channels=1280,
        dropout=0.0,
        num_layers=1,
        transformer_layers_per_block=1,
        resnet_eps=1e-6,
        resnet_time_scale_shift="default",
        resnet_act_fn="swish",
        resnet_groups=32,
        resnet_pre_norm=True,
        cross_attention_dim=768,
        attention_head_dim=1,
        downsample=True,
    ):
        super().__init__()
        self.has_cross_attention = True

        resnets = []
        attentions = []

        for i in range(num_layers):
            # Only the first resnet changes the channel count.
            in_channels_i = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels_i,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=1.0,
                    pre_norm=resnet_pre_norm,
                )
            )
            attentions.append(
                Transformer2DModel(
                    num_attention_heads=attention_head_dim,
                    attention_head_dim=out_channels // attention_head_dim,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block,
                    dropout=dropout,
                    norm_num_groups=resnet_groups,
                    cross_attention_dim=cross_attention_dim,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if downsample:
            self.downsamplers = nn.ModuleList([
                Downsample2D(out_channels, out_channels, padding=1)
            ])
        else:
            self.downsamplers = None

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        """Apply resnet+attention pairs, then the downsampler.

        Returns:
            (hidden_states, output_states): the downsampled features and the
            tuple of intermediate tensors (one per resnet/attn pair, plus one
            after downsampling) used as UNet skip connections.
        """
        output_states = []

        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
            output_states.append(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
            # The post-downsample tensor is also a skip connection.
            output_states.append(hidden_states)

        return hidden_states, tuple(output_states)
|
||||||
|
|
||||||
|
|
||||||
|
class DownBlock2D(nn.Module):
    """Plain down block: a stack of ResnetBlock2D layers (no attention)
    followed by an optional strided-conv downsampler.

    forward() returns the final hidden states plus a tuple of per-stage skip
    tensors consumed by the matching up block.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        temb_channels=1280,
        dropout=0.0,
        num_layers=1,
        resnet_eps=1e-6,
        resnet_time_scale_shift="default",
        resnet_act_fn="swish",
        resnet_groups=32,
        resnet_pre_norm=True,
        downsample=True,
    ):
        super().__init__()
        self.has_cross_attention = False

        # Only the first resnet changes the channel count.
        self.resnets = nn.ModuleList([
            ResnetBlock2D(
                in_channels=in_channels if layer_idx == 0 else out_channels,
                out_channels=out_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=1.0,
                pre_norm=resnet_pre_norm,
            )
            for layer_idx in range(num_layers)
        ])

        if downsample:
            self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels, padding=1)])
        else:
            self.downsamplers = None

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        skips = []
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb)
            skips.append(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
            # The post-downsample tensor is also a skip connection.
            skips.append(hidden_states)

        return hidden_states, tuple(skips)
|
||||||
|
|
||||||
|
|
||||||
|
class CrossAttnUpBlock2D(nn.Module):
    """Up block: ``num_layers`` pairs of (ResnetBlock2D, Transformer2DModel),
    each resnet consuming one skip tensor from the down path, followed by an
    optional nearest-neighbor upsampler.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        prev_output_channel,
        temb_channels=1280,
        dropout=0.0,
        num_layers=1,
        transformer_layers_per_block=1,
        resnet_eps=1e-6,
        resnet_time_scale_shift="default",
        resnet_act_fn="swish",
        resnet_groups=32,
        resnet_pre_norm=True,
        cross_attention_dim=768,
        attention_head_dim=1,
        upsample=True,
    ):
        super().__init__()
        self.has_cross_attention = True

        resnets = []
        attentions = []

        for i in range(num_layers):
            # Skip-tensor width: the last layer receives the down-path tensor
            # at `in_channels`, earlier layers at `out_channels`.
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            # The first resnet receives the previous up block's output width.
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock2D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=1.0,
                    pre_norm=resnet_pre_norm,
                )
            )
            attentions.append(
                Transformer2DModel(
                    num_attention_heads=attention_head_dim,
                    attention_head_dim=out_channels // attention_head_dim,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block,
                    dropout=dropout,
                    norm_num_groups=resnet_groups,
                    cross_attention_dim=cross_attention_dim,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if upsample:
            self.upsamplers = nn.ModuleList([
                Upsample2D(out_channels, out_channels)
            ])
        else:
            self.upsamplers = None

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None):
        """Consume skip tensors from the END of ``res_hidden_states_tuple``
        (one per resnet), then optionally upsample to ``upsample_size``."""
        for resnet, attn in zip(self.resnets, self.attentions):
            # Pop res hidden states (LIFO order relative to the down path).
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size=upsample_size)

        return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class UpBlock2D(nn.Module):
    """Plain up block: each ResnetBlock2D consumes one skip tensor from the
    down path (concatenated on the channel dim), followed by an optional
    nearest-neighbor upsampler."""

    def __init__(
        self,
        in_channels,
        out_channels,
        prev_output_channel,
        temb_channels=1280,
        dropout=0.0,
        num_layers=1,
        resnet_eps=1e-6,
        resnet_time_scale_shift="default",
        resnet_act_fn="swish",
        resnet_groups=32,
        resnet_pre_norm=True,
        upsample=True,
    ):
        super().__init__()
        self.has_cross_attention = False

        blocks = []
        for layer_idx in range(num_layers):
            # Skip-tensor width: the last layer receives the down-path tensor
            # at `in_channels`, earlier layers at `out_channels`.
            skip_channels = in_channels if layer_idx == num_layers - 1 else out_channels
            # The first resnet receives the previous up block's output width.
            block_in = prev_output_channel if layer_idx == 0 else out_channels
            blocks.append(
                ResnetBlock2D(
                    in_channels=block_in + skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=1.0,
                    pre_norm=resnet_pre_norm,
                )
            )
        self.resnets = nn.ModuleList(blocks)

        if upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)])
        else:
            self.upsamplers = None

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None):
        for resnet in self.resnets:
            # Consume skip tensors from the end of the tuple (LIFO).
            skip = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = resnet(torch.cat([hidden_states, skip], dim=1), temb)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size=upsample_size)

        return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
# ===== UNet Mid Block =====
|
||||||
|
|
||||||
|
class UNetMidBlock2DCrossAttn(nn.Module):
    """UNet bottleneck: one leading resnet, then ``num_layers`` pairs of
    (cross-attention transformer, resnet). Channel count and spatial
    resolution are unchanged."""

    def __init__(
        self,
        in_channels,
        temb_channels=1280,
        dropout=0.0,
        num_layers=1,
        transformer_layers_per_block=1,
        resnet_eps=1e-6,
        resnet_time_scale_shift="default",
        resnet_act_fn="swish",
        resnet_groups=32,
        resnet_pre_norm=True,
        cross_attention_dim=768,
        attention_head_dim=1,
    ):
        super().__init__()
        # Derive a group count from the channel width when none is given.
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

        # There is always at least one resnet
        resnets = [
            ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=1.0,
                pre_norm=resnet_pre_norm,
            )
        ]
        attentions = []

        for _ in range(num_layers):
            attentions.append(
                Transformer2DModel(
                    num_attention_heads=attention_head_dim,
                    attention_head_dim=in_channels // attention_head_dim,
                    in_channels=in_channels,
                    num_layers=transformer_layers_per_block,
                    dropout=dropout,
                    norm_num_groups=resnet_groups,
                    cross_attention_dim=cross_attention_dim,
                )
            )
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=1.0,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        """resnets[0] first, then alternate attention/resnet; len(resnets)
        is always len(attentions) + 1."""
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
            hidden_states = resnet(hidden_states, temb)
        return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
# ===== Downsample / Upsample =====
|
||||||
|
|
||||||
|
class Downsample2D(nn.Module):
    """2x downsampling via a stride-2 3x3 conv.

    When built with padding=0, the input is zero-padded asymmetrically
    (one column right, one row bottom) before the convolution.
    """

    def __init__(self, in_channels, out_channels, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=padding)
        self.padding = padding

    def forward(self, hidden_states):
        if not self.padding:
            hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(hidden_states)
|
||||||
|
|
||||||
|
|
||||||
|
class Upsample2D(nn.Module):
    """Nearest-neighbor upsampling (2x, or an explicit target size) followed
    by a 3x3 conv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, hidden_states, upsample_size=None):
        if upsample_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        else:
            # Explicit size is used by the UNet to match odd skip resolutions.
            hidden_states = F.interpolate(hidden_states, size=upsample_size, mode="nearest")
        return self.conv(hidden_states)
|
||||||
|
|
||||||
|
|
||||||
|
# ===== UNet2DConditionModel =====
|
||||||
|
|
||||||
|
class UNet2DConditionModel(nn.Module):
|
||||||
|
"""Stable Diffusion UNet with cross-attention conditioning.
|
||||||
|
state_dict keys match the diffusers UNet2DConditionModel checkpoint format.
|
||||||
|
"""
|
||||||
|
    def __init__(
        self,
        sample_size=64,
        in_channels=4,
        out_channels=4,
        down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
        block_out_channels=(320, 640, 1280, 1280),
        layers_per_block=2,
        cross_attention_dim=768,
        attention_head_dim=8,
        norm_num_groups=32,
        norm_eps=1e-5,
        dropout=0.0,
        act_fn="silu",
        time_embedding_type="positional",
        flip_sin_to_cos=True,
        freq_shift=0,
        time_embedding_dim=None,
        resnet_time_scale_shift="default",
        upcast_attention=False,
    ):
        """Build the SD UNet: conv_in -> down blocks -> mid block -> up blocks
        -> conv_out, with sinusoidal time conditioning.

        Defaults correspond to the SD1.x latent UNet (4 latent channels,
        320/640/1280/1280 widths, cross_attention_dim=768).

        NOTE(review): ``time_embedding_type``, ``sample_size`` and
        ``upcast_attention`` are stored/accepted but not otherwise used in
        this constructor.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.sample_size = sample_size

        # Time embedding
        # Sinusoidal projection width defaults to the first block width; the
        # MLP expands it 4x for the resnet conditioning.
        timestep_embedding_dim = time_embedding_dim or block_out_channels[0]
        self.time_proj = Timesteps(timestep_embedding_dim, flip_sin_to_cos=flip_sin_to_cos, freq_shift=freq_shift)
        time_embed_dim = block_out_channels[0] * 4
        self.time_embedding = TimestepEmbedding(timestep_embedding_dim, time_embed_dim)

        # Input
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1)

        # Down blocks
        self.down_blocks = nn.ModuleList()
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            # The deepest down block does not downsample.
            is_final_block = i == len(block_out_channels) - 1

            if "CrossAttn" in down_block_type:
                down_block = CrossAttnDownBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    temb_channels=time_embed_dim,
                    dropout=dropout,
                    num_layers=layers_per_block,
                    transformer_layers_per_block=1,
                    resnet_eps=norm_eps,
                    resnet_time_scale_shift=resnet_time_scale_shift,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    cross_attention_dim=cross_attention_dim,
                    attention_head_dim=attention_head_dim,
                    downsample=not is_final_block,
                )
            else:
                down_block = DownBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    temb_channels=time_embed_dim,
                    dropout=dropout,
                    num_layers=layers_per_block,
                    resnet_eps=norm_eps,
                    resnet_time_scale_shift=resnet_time_scale_shift,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    downsample=not is_final_block,
                )
            self.down_blocks.append(down_block)

        # Mid block
        self.mid_block = UNetMidBlock2DCrossAttn(
            in_channels=block_out_channels[-1],
            temb_channels=time_embed_dim,
            dropout=dropout,
            num_layers=1,
            transformer_layers_per_block=1,
            resnet_eps=norm_eps,
            resnet_time_scale_shift=resnet_time_scale_shift,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
        )

        # Up blocks
        self.up_blocks = nn.ModuleList()
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]

        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            # The outermost up block does not upsample.
            is_final_block = i == len(block_out_channels) - 1

            # in_channels for up blocks: diffusers uses reversed_block_out_channels[min(i+1, len-1)]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            if "CrossAttn" in up_block_type:
                up_block = CrossAttnUpBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    prev_output_channel=prev_output_channel,
                    temb_channels=time_embed_dim,
                    dropout=dropout,
                    num_layers=layers_per_block + 1,
                    transformer_layers_per_block=1,
                    resnet_eps=norm_eps,
                    resnet_time_scale_shift=resnet_time_scale_shift,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    cross_attention_dim=cross_attention_dim,
                    attention_head_dim=attention_head_dim,
                    upsample=not is_final_block,
                )
            else:
                up_block = UpBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    prev_output_channel=prev_output_channel,
                    temb_channels=time_embed_dim,
                    dropout=dropout,
                    num_layers=layers_per_block + 1,
                    resnet_eps=norm_eps,
                    resnet_time_scale_shift=resnet_time_scale_shift,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    upsample=not is_final_block,
                )
            self.up_blocks.append(up_block)

        # Output
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
|
||||||
|
|
||||||
|
def forward(self, sample, timestep, encoder_hidden_states, cross_attention_kwargs=None, timestep_cond=None, added_cond_kwargs=None, return_dict=True):
    """Run the UNet: time embedding, conv_in, down path, mid block, up path, output head.

    Returns the predicted sample tensor, or a 1-tuple when ``return_dict`` is False.
    """
    # 1. Normalize the timestep into a 1-D tensor on the sample's device.
    ts = timestep
    if not torch.is_tensor(ts):
        # NOTE(review): non-tensor timesteps are coerced to long — confirm callers
        # never pass fractional scalar timesteps.
        ts = torch.tensor([ts], dtype=torch.long, device=sample.device)
    elif ts.ndim == 0:
        ts = ts[None].to(sample.device)

    projected = self.time_proj(ts).to(dtype=sample.dtype)
    emb = self.time_embedding(projected)

    # 2. Input projection
    sample = self.conv_in(sample)

    # 3. Down path, accumulating skip activations for the up path.
    skip_states = (sample,)
    for down_block in self.down_blocks:
        sample, residuals = down_block(
            hidden_states=sample,
            temb=emb,
            encoder_hidden_states=encoder_hidden_states,
        )
        skip_states = skip_states + residuals

    # 4. Bottleneck
    sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)

    # 5. Up path, consuming the collected skips in reverse order.
    for up_block in self.up_blocks:
        count = len(up_block.resnets)
        residuals, skip_states = skip_states[-count:], skip_states[:-count]
        # Target spatial size for the upsampler comes from the next skip level.
        target_size = skip_states[-1].shape[2:] if skip_states else None
        sample = up_block(
            hidden_states=sample,
            temb=emb,
            encoder_hidden_states=encoder_hidden_states,
            res_hidden_states_tuple=residuals,
            upsample_size=target_size,
        )

    # 6. Output head
    sample = self.conv_norm_out(sample)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)

    if not return_dict:
        return (sample,)
    return sample
|
||||||
642
diffsynth/models/stable_diffusion_vae.py
Normal file
642
diffsynth/models/stable_diffusion_vae.py
Normal file
@@ -0,0 +1,642 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class DiagonalGaussianDistribution:
    """Diagonal Gaussian parameterized by (mean, logvar) concatenated along dim 1.

    ``deterministic=True`` zeroes the std/var so sampling degenerates to the mean
    and the KL term is zero.
    """

    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
        self.parameters = parameters
        mean, logvar = torch.chunk(parameters, 2, dim=1)
        self.mean = mean
        # Clamp logvar to keep exp() numerically safe.
        self.logvar = torch.clamp(logvar, -30.0, 20.0)
        self.deterministic = deterministic
        if deterministic:
            zeros = torch.zeros_like(self.mean, device=parameters.device, dtype=parameters.dtype)
            self.std = zeros
            self.var = zeros
        else:
            self.std = torch.exp(0.5 * self.logvar)
            self.var = torch.exp(self.logvar)

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
        """Draw a reparameterized sample; torch.randn is used because randn_like
        does not accept a generator on all torch versions."""
        noise = torch.randn(
            self.mean.shape,
            generator=generator,
            device=self.parameters.device,
            dtype=self.parameters.dtype,
        )
        return self.mean + self.std * noise

    def kl(self, other: Optional["DiagonalGaussianDistribution"] = None) -> torch.Tensor:
        """KL divergence to ``other`` (or to the standard normal when omitted),
        summed over channel and spatial dims."""
        if self.deterministic:
            return torch.tensor([0.0])
        if other is None:
            term = torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar
        else:
            term = (
                torch.pow(self.mean - other.mean, 2) / other.var
                + self.var / other.var
                - 1.0
                - self.logvar
                + other.logvar
            )
        return 0.5 * torch.sum(term, dim=[1, 2, 3])

    def mode(self) -> torch.Tensor:
        """Distribution mode — the mean for a Gaussian."""
        return self.mean
|
||||||
|
|
||||||
|
|
||||||
|
class ResnetBlock2D(nn.Module):
    """Residual block: GroupNorm -> act -> conv, twice, with optional timestep
    conditioning ("default" additive or "scale_shift") and a 1x1 shortcut when
    channel counts differ (or ``conv_shortcut=True`` forces one)."""

    def __init__(
        self,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        output_scale_factor=1.0,
        use_in_shortcut=None,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.time_embedding_norm = time_embedding_norm
        self.output_scale_factor = output_scale_factor

        out_ch = out_channels or in_channels
        groups_out = groups if groups_out is None else groups_out

        self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps)
        self.conv1 = nn.Conv2d(in_channels, out_ch, kernel_size=3, stride=1, padding=1)

        # Timestep projection: plain add for "default", 2x width for scale/shift.
        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                self.time_emb_proj = nn.Linear(temb_channels, out_ch)
            elif self.time_embedding_norm == "scale_shift":
                self.time_emb_proj = nn.Linear(temb_channels, 2 * out_ch)

        self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_ch, eps=eps)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1)

        activations = {"swish": nn.SiLU, "silu": nn.SiLU, "gelu": nn.GELU, "relu": nn.ReLU}
        if non_linearity not in activations:
            raise ValueError(f"Unsupported non_linearity: {non_linearity}")
        self.nonlinearity = activations[non_linearity]()

        self.use_conv_shortcut = conv_shortcut
        # Shortcut conv when forced, or when channel counts differ.
        if conv_shortcut or in_channels != out_ch:
            self.conv_shortcut = nn.Conv2d(in_channels, out_ch, kernel_size=1, stride=1, padding=0)
        else:
            self.conv_shortcut = None

    def forward(self, input_tensor, temb=None):
        h = self.norm1(input_tensor)
        h = self.nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            # Activation is applied to the raw time embedding before projection.
            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
            if self.time_embedding_norm == "default":
                h = h + temb

        h = self.norm2(h)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            h = h * (1 + scale) + shift

        h = self.nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        shortcut = input_tensor if self.conv_shortcut is None else self.conv_shortcut(input_tensor)
        return (shortcut + h) / self.output_scale_factor
|
||||||
|
|
||||||
|
|
||||||
|
class DownEncoderBlock2D(nn.Module):
    """Encoder stage: a stack of unconditioned ResnetBlock2D layers, optionally
    followed by a single strided-conv downsampler."""

    def __init__(
        self,
        in_channels,
        out_channels,
        dropout=0.0,
        num_layers=1,
        resnet_eps=1e-6,
        resnet_time_scale_shift="default",
        resnet_act_fn="swish",
        resnet_groups=32,
        resnet_pre_norm=True,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
    ):
        super().__init__()
        self.resnets = nn.ModuleList([
            ResnetBlock2D(
                in_channels=in_channels if layer == 0 else out_channels,
                out_channels=out_channels,
                temb_channels=None,  # encoder resnets are not time-conditioned
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
            for layer in range(num_layers)
        ])

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [Downsample2D(out_channels, out_channels, padding=downsample_padding)]
            )
        else:
            self.downsamplers = None

    def forward(self, hidden_states, *args, **kwargs):
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb=None)
        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
        return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class UpDecoderBlock2D(nn.Module):
    """Decoder stage: a stack of ResnetBlock2D layers, optionally followed by a
    single nearest-neighbor + conv upsampler."""

    def __init__(
        self,
        in_channels,
        out_channels,
        dropout=0.0,
        num_layers=1,
        resnet_eps=1e-6,
        resnet_time_scale_shift="default",
        resnet_act_fn="swish",
        resnet_groups=32,
        resnet_pre_norm=True,
        output_scale_factor=1.0,
        add_upsample=True,
        temb_channels=None,
    ):
        super().__init__()
        self.resnets = nn.ModuleList([
            ResnetBlock2D(
                in_channels=in_channels if layer == 0 else out_channels,
                out_channels=out_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
            for layer in range(num_layers)
        ])

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)])
        else:
            self.upsamplers = None

    def forward(self, hidden_states, temb=None):
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb=temb)
        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)
        return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class UNetMidBlock2D(nn.Module):
    """Bottleneck block: one leading resnet, then ``num_layers`` (attention, resnet)
    pairs. Attention entries may be None when ``add_attention`` is False."""

    def __init__(
        self,
        in_channels,
        temb_channels=None,
        dropout=0.0,
        num_layers=1,
        resnet_eps=1e-6,
        resnet_time_scale_shift="default",
        resnet_act_fn="swish",
        resnet_groups=32,
        resnet_pre_norm=True,
        add_attention=True,
        attention_head_dim=1,
        output_scale_factor=1.0,
    ):
        super().__init__()
        if resnet_groups is None:
            resnet_groups = min(in_channels // 4, 32)
        self.add_attention = add_attention

        if attention_head_dim is None:
            attention_head_dim = in_channels

        def _make_resnet():
            # All mid-block resnets share identical configuration.
            return ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )

        # There is always at least one resnet before the first attention.
        resnets = [_make_resnet()]
        attentions = []
        for _ in range(num_layers):
            if self.add_attention:
                attentions.append(AttentionBlock(in_channels, num_groups=resnet_groups, eps=resnet_eps))
            else:
                attentions.append(None)
            resnets.append(_make_resnet())

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states, temb=None):
        hidden_states = self.resnets[0](hidden_states, temb)
        for attention, resnet in zip(self.attentions, self.resnets[1:]):
            if attention is not None:
                hidden_states = attention(hidden_states)
            hidden_states = resnet(hidden_states, temb)
        return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionBlock(nn.Module):
    """Single-head self-attention over spatial positions for the VAE mid block.

    Mirrors diffusers' Attention class with AttnProcessor2_0 for the VAE use case.
    Uses modern key names (to_q, to_k, to_v, to_out) matching the in-memory
    diffusers structure; checkpoints use deprecated keys (query, key, value,
    proj_attn), mapped via the state-dict converter.
    """

    def __init__(self, channels, num_groups=32, eps=1e-6):
        super().__init__()
        self.channels = channels
        self.eps = eps
        self.heads = 1
        self.rescale_output_factor = 1.0

        self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=eps, affine=True)
        self.to_q = nn.Linear(channels, channels, bias=True)
        self.to_k = nn.Linear(channels, channels, bias=True)
        self.to_v = nn.Linear(channels, channels, bias=True)
        self.to_out = nn.ModuleList([
            nn.Linear(channels, channels, bias=True),
            nn.Dropout(0.0),
        ])

    def forward(self, hidden_states):
        residual = hidden_states
        batch, channels, height, width = hidden_states.shape

        # Normalize, then flatten spatial dims: (B, C, H, W) -> (B, H*W, C).
        tokens = self.group_norm(hidden_states)
        tokens = tokens.view(batch, channels, height * width).transpose(1, 2)

        query = self.to_q(tokens)
        key = self.to_k(tokens)
        value = self.to_v(tokens)

        # (B, seq, dim) -> (B, heads, seq, head_dim)
        head_dim = key.shape[-1] // self.heads
        query = query.view(batch, -1, self.heads, head_dim).transpose(1, 2)
        key = key.view(batch, -1, self.heads, head_dim).transpose(1, 2)
        value = value.view(batch, -1, self.heads, head_dim).transpose(1, 2)

        attended = torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
        )

        # (B, heads, seq, head_dim) -> (B, seq, heads*head_dim)
        attended = attended.transpose(1, 2).reshape(batch, -1, self.heads * head_dim)
        attended = attended.to(query.dtype)

        # Output projection + dropout.
        out = self.to_out[0](attended)
        out = self.to_out[1](out)

        # Back to 4D, residual connection, then the (here: no-op) output rescale.
        out = out.transpose(-1, -2).reshape(batch, channels, height, width)
        return (out + residual) / self.rescale_output_factor
|
||||||
|
|
||||||
|
|
||||||
|
class Downsample2D(nn.Module):
    """Strided 3x3 conv downsampler matching diffusers Downsample2D (use_conv=True).

    Key names: conv.weight/bias (unchanged — checkpoint-compatible).

    When ``padding == 0`` (the VAE encoder configuration) an explicit asymmetric
    F.pad of 1 on the right/bottom is applied before a padding-free conv, as in
    diffusers. For ``padding > 0`` the padding is handled by the conv itself.
    The previous implementation ignored nonzero padding entirely (the conv was
    always built with padding=0), so the default ``padding=1`` path produced
    outputs one pixel smaller than diffusers'.
    """

    def __init__(self, in_channels, out_channels, padding=1):
        super().__init__()
        self.padding = padding
        # padding=0 keeps the conv unpadded; the explicit pad in forward() covers it.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=padding)

    def forward(self, hidden_states):
        if self.padding == 0:
            # Asymmetric (right/bottom) zero pad, matching the diffusers VAE encoder.
            hidden_states = torch.nn.functional.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(hidden_states)
|
||||||
|
|
||||||
|
|
||||||
|
class Upsample2D(nn.Module):
    """2x nearest-neighbor upsample followed by a 3x3 conv.

    Key names match the diffusers checkpoint: conv.weight/bias.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, hidden_states):
        upsampled = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled)
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder(nn.Module):
    """VAE encoder: conv_in -> down blocks -> mid block -> norm/act/conv_out.

    With ``double_z=True`` the output has 2*out_channels channels (mean + logvar
    for DiagonalGaussianDistribution).
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        down_block_types=("DownEncoderBlock2D",),
        block_out_channels=(64,),
        layers_per_block=2,
        norm_num_groups=32,
        act_fn="silu",
        double_z=True,
        mid_block_add_attention=True,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)

        # Down path: channel widths walk block_out_channels; the last stage has
        # no downsampler.
        self.down_blocks = nn.ModuleList()
        last_index = len(block_out_channels) - 1
        channel = block_out_channels[0]
        for index, _ in enumerate(down_block_types):
            in_ch, channel = channel, block_out_channels[index]
            self.down_blocks.append(DownEncoderBlock2D(
                in_channels=in_ch,
                out_channels=channel,
                num_layers=self.layers_per_block,
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                add_downsample=index != last_index,
                downsample_padding=0,
            ))

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
            add_attention=mid_block_add_attention,
        )

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)

    def forward(self, sample):
        sample = self.conv_in(sample)
        for down_block in self.down_blocks:
            sample = down_block(sample)
        sample = self.mid_block(sample)
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        return self.conv_out(sample)
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder(nn.Module):
    """VAE decoder: conv_in -> mid block -> up blocks -> norm/act/conv_out.

    ``norm_type == "spatial"`` routes ``latent_embeds`` into the resnets as a
    conditioning signal; the default "group" mode passes None through.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        up_block_types=("UpDecoderBlock2D",),
        block_out_channels=(64,),
        layers_per_block=2,
        norm_num_groups=32,
        act_fn="silu",
        norm_type="group",
        mid_block_add_attention=True,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)

        self.up_blocks = nn.ModuleList()
        temb_channels = in_channels if norm_type == "spatial" else None

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=temb_channels,
            add_attention=mid_block_add_attention,
        )

        # Up path: channel widths walk the reversed encoder widths; the last
        # stage has no upsampler.
        widths = list(reversed(block_out_channels))
        last_index = len(widths) - 1
        channel = widths[0]
        for index, _ in enumerate(up_block_types):
            prev_channel, channel = channel, widths[index]
            self.up_blocks.append(UpDecoderBlock2D(
                in_channels=prev_channel,
                out_channels=channel,
                num_layers=self.layers_per_block + 1,
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                add_upsample=index != last_index,
                temb_channels=temb_channels,
            ))

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

    def forward(self, sample, latent_embeds=None):
        sample = self.conv_in(sample)
        sample = self.mid_block(sample, latent_embeds)
        for up_block in self.up_blocks:
            sample = up_block(sample, latent_embeds)
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        return self.conv_out(sample)
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusionVAE(nn.Module):
    """AutoencoderKL-style VAE for Stable Diffusion: Encoder + Decoder plus the
    optional 1x1 quant/post-quant convs.

    ``scaling_factor`` maps encoder latents into the diffusion model's working
    space (z_diff = z * scaling_factor); the decoder itself consumes *unscaled*
    latents. ``shift_factor`` / ``latents_mean`` / ``latents_std`` are stored for
    callers but not applied in this class's visible code.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"),
        up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"),
        block_out_channels=(128, 256, 512, 512),
        layers_per_block=2,
        act_fn="silu",
        latent_channels=4,
        norm_num_groups=32,
        sample_size=512,
        scaling_factor=0.18215,
        shift_factor=None,
        latents_mean=None,
        latents_std=None,
        force_upcast=True,
        use_quant_conv=True,
        use_post_quant_conv=True,
        mid_block_add_attention=True,
    ):
        super().__init__()
        self.encoder = Encoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            double_z=True,  # encoder emits mean + logvar
            mid_block_add_attention=mid_block_add_attention,
        )
        self.decoder = Decoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            mid_block_add_attention=mid_block_add_attention,
        )

        # 1x1 convs around the latent bottleneck (present in SD v1.x/SDXL VAEs).
        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None

        self.latents_mean = latents_mean
        self.latents_std = latents_std
        self.scaling_factor = scaling_factor
        self.shift_factor = shift_factor
        self.sample_size = sample_size
        self.force_upcast = force_upcast

    def _encode(self, x):
        """Encoder + optional quant conv; returns concatenated (mean, logvar)."""
        h = self.encoder(x)
        if self.quant_conv is not None:
            h = self.quant_conv(h)
        return h

    def encode(self, x):
        """Encode an image batch into a DiagonalGaussianDistribution posterior."""
        h = self._encode(x)
        posterior = DiagonalGaussianDistribution(h)
        return posterior

    def _decode(self, z):
        """Optional post-quant conv, then the decoder. Expects UNSCALED latents."""
        if self.post_quant_conv is not None:
            z = self.post_quant_conv(z)
        return self.decoder(z)

    def decode(self, z):
        """Decode unscaled latents back to image space."""
        return self._decode(z)

    def forward(self, sample, sample_posterior=True, return_dict=True, generator=None):
        """Encode, (optionally) sample the posterior, and decode.

        Returns a dict with "sample" (reconstruction), "posterior", and
        "latent_sample" (the diffusion-space latent, i.e. z * scaling_factor),
        or the (reconstruction, posterior) tuple when return_dict is False.
        """
        posterior = self.encode(sample)
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        # BUGFIX: the decoder operates on unscaled latents; the previous version
        # decoded z * scaling_factor, shrinking the decoder input by ~0.18x.
        # scaling_factor is applied only to the latent returned for diffusion use.
        latent_scaled = z * self.scaling_factor
        decoded = self.decode(z)
        if return_dict:
            return {"sample": decoded, "posterior": posterior, "latent_sample": latent_scaled}
        return decoded, posterior
|
||||||
62
diffsynth/models/stable_diffusion_xl_text_encoder.py
Normal file
62
diffsynth/models/stable_diffusion_xl_text_encoder.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class SDXLTextEncoder2(torch.nn.Module):
    """Wrapper around transformers' CLIPTextModelWithProjection, configured with
    SDXL's second (OpenCLIP ViT-bigG) text-encoder defaults."""

    def __init__(
        self,
        hidden_size=1280,
        intermediate_size=5120,
        num_hidden_layers=32,
        num_attention_heads=20,
        max_position_embeddings=77,
        vocab_size=49408,
        layer_norm_eps=1e-05,
        hidden_act="gelu",
        initializer_factor=1.0,
        initializer_range=0.02,
        bos_token_id=0,
        eos_token_id=2,
        pad_token_id=1,
        projection_dim=1280,
    ):
        super().__init__()
        # Imported lazily so the module can be loaded without transformers installed.
        from transformers import CLIPTextConfig, CLIPTextModelWithProjection

        clip_config = CLIPTextConfig(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            max_position_embeddings=max_position_embeddings,
            vocab_size=vocab_size,
            layer_norm_eps=layer_norm_eps,
            hidden_act=hidden_act,
            initializer_factor=initializer_factor,
            initializer_range=initializer_range,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            projection_dim=projection_dim,
        )
        self.model = CLIPTextModelWithProjection(clip_config)
        self.config = clip_config

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        output_hidden_states=True,
        **kwargs,
    ):
        """Return (text_embeds, hidden_states) when output_hidden_states is True,
        otherwise just the projected text embeddings."""
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            **kwargs,
        )
        if output_hidden_states:
            return outputs.text_embeds, outputs.hidden_states
        return outputs.text_embeds
|
||||||
922
diffsynth/models/stable_diffusion_xl_unet.py
Normal file
922
diffsynth/models/stable_diffusion_xl_unet.py
Normal file
@@ -0,0 +1,922 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import math
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
# ===== Time Embedding =====
|
||||||
|
|
||||||
|
class Timesteps(nn.Module):
    """Sinusoidal timestep embeddings, matching diffusers' get_timestep_embedding.

    Frequencies are 10000^(-i / (half_dim - freq_shift)); with
    ``flip_sin_to_cos=True`` the cosine half precedes the sine half.
    """

    def __init__(self, num_channels, flip_sin_to_cos=True, freq_shift=0):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.freq_shift = freq_shift

    def forward(self, timesteps):
        half_dim = self.num_channels // 2
        exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
        # BUGFIX: diffusers divides by (half_dim - freq_shift). The previous code
        # divided by half_dim and then *added* freq_shift to the exponent, which
        # is wrong for any nonzero shift (identical at freq_shift == 0, the
        # SD/SDXL default, so existing configs are unaffected).
        exponent = exponent / (half_dim - self.freq_shift)
        emb = torch.exp(exponent)
        emb = timesteps[:, None].float() * emb[None, :]
        sin_emb = torch.sin(emb)
        cos_emb = torch.cos(emb)
        if self.flip_sin_to_cos:
            emb = torch.cat([cos_emb, sin_emb], dim=-1)
        else:
            emb = torch.cat([sin_emb, cos_emb], dim=-1)
        return emb
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEmbedding(nn.Module):
    """Two-layer MLP projecting sinusoidal timestep features to the UNet embed dim."""

    def __init__(self, in_channels, time_embed_dim, act_fn="silu", out_dim=None):
        super().__init__()
        self.linear_1 = nn.Linear(in_channels, time_embed_dim)
        # Any act_fn other than "silu" falls back to GELU, as in the original.
        self.act = nn.SiLU() if act_fn == "silu" else nn.GELU()
        self.linear_2 = nn.Linear(time_embed_dim, out_dim if out_dim is not None else time_embed_dim)

    def forward(self, sample):
        return self.linear_2(self.act(self.linear_1(sample)))
|
||||||
|
|
||||||
|
|
||||||
|
# ===== ResNet Blocks =====
|
||||||
|
|
||||||
|
class ResnetBlock2D(nn.Module):
    """Stable-Diffusion-style residual block: two GroupNorm -> nonlinearity -> 3x3 conv
    stages with timestep-embedding conditioning and a (possibly 1x1-conv) skip path."""

    def __init__(
        self,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        output_scale_factor=1.0,
        use_in_shortcut=None,
    ):
        """
        Args:
            in_channels: Channels of the input feature map.
            out_channels: Channels of the output; defaults to ``in_channels`` when None.
            conv_shortcut: Force a 1x1 conv on the skip path even when channels match.
            dropout: Dropout probability applied before the second conv.
            temb_channels: Width of the incoming time embedding; None disables the projection.
            groups: GroupNorm group count for the first norm.
            groups_out: GroupNorm group count for the second norm; defaults to ``groups``.
            pre_norm: Stored but not branched on here; this block always normalizes first.
            eps: GroupNorm epsilon.
            non_linearity: "swish"/"silu" -> SiLU, "gelu" -> GELU, "relu" -> ReLU.
            time_embedding_norm: "default" adds the projected temb to the activations;
                "scale_shift" applies ``h * (1 + scale) + shift`` after the second norm.
            output_scale_factor: The residual sum is divided by this factor.
            use_in_shortcut: NOTE(review): accepted but unused; the shortcut is derived
                from ``conv_shortcut`` and the channel counts instead — confirm intent.
        """
        super().__init__()
        self.pre_norm = pre_norm
        self.time_embedding_norm = time_embedding_norm
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps)
        self.conv1 = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                self.time_emb_proj = nn.Linear(temb_channels, out_channels or in_channels)
            elif self.time_embedding_norm == "scale_shift":
                # Double width: one half becomes the scale, the other the shift.
                self.time_emb_proj = nn.Linear(temb_channels, 2 * (out_channels or in_channels))

        self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels or in_channels, eps=eps)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels or in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)

        # "swish" and "silu" are synonyms here.
        if non_linearity == "swish":
            self.nonlinearity = nn.SiLU()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()
        elif non_linearity == "gelu":
            self.nonlinearity = nn.GELU()
        elif non_linearity == "relu":
            self.nonlinearity = nn.ReLU()

        self.use_conv_shortcut = conv_shortcut
        self.conv_shortcut = None
        if conv_shortcut:
            self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0)
        else:
            # Only add a 1x1 projection when the residual add would otherwise mismatch.
            self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0) if in_channels != (out_channels or in_channels) else None

    def forward(self, input_tensor, temb=None):
        """Run the block.

        Args:
            input_tensor: Feature map of shape ``(batch, in_channels, H, W)``.
            temb: Optional time embedding of shape ``(batch, temb_channels)``.

        Returns:
            Tensor of shape ``(batch, out_channels or in_channels, H, W)``.
        """
        hidden_states = input_tensor
        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            # Project the time embedding and broadcast it over the spatial dims.
            temb = self.nonlinearity(temb)
            temb = self.time_emb_proj(temb).unsqueeze(-1).unsqueeze(-1)

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            # FiLM-style modulation applied after the second normalization.
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
        return output_tensor
|
||||||
|
|
||||||
|
|
||||||
|
# ===== Transformer Blocks =====
|
||||||
|
|
||||||
|
class GEGLU(nn.Module):
    """Gated-GELU projection: one linear layer emits both the value and the gate,
    and the output is ``value * GELU(gate)``."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # Single projection of width 2*dim_out, split into halves at runtime.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, hidden_states):
        projected = self.proj(hidden_states)
        value, gate = projected.chunk(2, dim=-1)
        return value * F.gelu(gate)
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
    """Transformer MLP: GEGLU expansion (4x), dropout, then projection back down."""

    def __init__(self, dim, dim_out=None, dropout=0.0):
        super().__init__()
        inner_dim = dim * 4
        output_dim = dim if dim_out is None else dim_out
        # Kept as a ModuleList (not Sequential) so state_dict keys stay `net.0/1/2.*`.
        self.net = nn.ModuleList([
            GEGLU(dim, inner_dim),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, output_dim),
        ])

    def forward(self, hidden_states):
        for layer in self.net:
            hidden_states = layer(hidden_states)
        return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
    """Multi-head (self- or cross-) attention backed by torch's fused
    ``scaled_dot_product_attention``.

    When ``encoder_hidden_states`` is None the layer performs self-attention;
    otherwise keys/values come from the encoder states (cross-attention).
    """

    def __init__(
        self,
        query_dim,
        heads=8,
        dim_head=64,
        dropout=0.0,
        bias=False,                # bias on the q/k/v projections (output proj always has bias)
        upcast_attention=False,    # NOTE(review): accepted for config compatibility but unused here
        cross_attention_dim=None,  # key/value input width; defaults to query_dim (self-attention)
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.inner_dim = inner_dim
        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim

        self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
        self.to_k = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias)
        self.to_v = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias)
        self.to_out = nn.ModuleList([
            nn.Linear(inner_dim, query_dim, bias=True),
            nn.Dropout(dropout),
        ])

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """
        Args:
            hidden_states: Query input of shape ``(batch, seq_len, query_dim)``.
            encoder_hidden_states: Optional key/value input for cross-attention.
            attention_mask: Optional mask broadcastable to ``(batch, heads, q_len, kv_len)``,
                forwarded to ``scaled_dot_product_attention`` (bool keep-mask or additive float).

        Returns:
            Tensor of shape ``(batch, seq_len, query_dim)``.
        """
        query = self.to_q(hidden_states)
        batch_size, seq_len, _ = query.shape

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        key = self.to_k(encoder_hidden_states)
        value = self.to_v(encoder_hidden_states)

        # Split heads: (batch, seq, inner) -> (batch, heads, seq, head_dim).
        head_dim = self.inner_dim // self.heads
        query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)

        # Fix: forward the caller-supplied attention_mask instead of hard-coding None,
        # which previously silently ignored masks. Behavior is unchanged when no mask
        # is given (the default).
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # Merge heads back: (batch, heads, seq, head_dim) -> (batch, seq, inner).
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.inner_dim)
        hidden_states = hidden_states.to(query.dtype)

        hidden_states = self.to_out[0](hidden_states)  # output projection
        hidden_states = self.to_out[1](hidden_states)  # dropout

        return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BasicTransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: self-attention, cross-attention, then a
    GEGLU feed-forward, each wrapped in a residual connection."""

    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        cross_attention_dim=None,
        upcast_attention=False,
    ):
        super().__init__()
        # Self-attention branch (norm1/attn1).
        self.norm1 = nn.LayerNorm(dim)
        self.attn1 = Attention(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            bias=False,
            upcast_attention=upcast_attention,
        )
        # Cross-attention branch over the encoder/text states (norm2/attn2).
        self.norm2 = nn.LayerNorm(dim)
        self.attn2 = Attention(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            bias=False,
            upcast_attention=upcast_attention,
            cross_attention_dim=cross_attention_dim,
        )
        # Feed-forward branch (norm3/ff).
        self.norm3 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, dropout=dropout)

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """Apply self-attn, cross-attn, and MLP, each with a residual add.

        ``attention_mask`` is accepted for interface compatibility but not used.
        """
        hidden_states = hidden_states + self.attn1(self.norm1(hidden_states))
        hidden_states = hidden_states + self.attn2(self.norm2(hidden_states), encoder_hidden_states=encoder_hidden_states)
        hidden_states = hidden_states + self.ff(self.norm3(hidden_states))
        return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Transformer2DModel(nn.Module):
    """Spatial transformer: GroupNorm + (linear | 1x1-conv) projection into token space,
    a stack of BasicTransformerBlocks, projection back, and a residual add.

    Input and output both have shape ``(batch, in_channels, height, width)``.
    """

    def __init__(
        self,
        num_attention_heads=16,
        attention_head_dim=64,
        in_channels=320,
        num_layers=1,
        dropout=0.0,
        norm_num_groups=32,
        cross_attention_dim=768,
        upcast_attention=False,
        use_linear_projection=False,  # SDXL uses linear in/out projections; SD1.x uses 1x1 convs
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim
        self.use_linear_projection = use_linear_projection

        self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6)

        if use_linear_projection:
            self.proj_in = nn.Linear(in_channels, inner_dim, bias=True)
        else:
            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, bias=True)

        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(
                dim=inner_dim,
                n_heads=num_attention_heads,
                d_head=attention_head_dim,
                dropout=dropout,
                cross_attention_dim=cross_attention_dim,
                upcast_attention=upcast_attention,
            )
            for _ in range(num_layers)
        ])

        if use_linear_projection:
            self.proj_out = nn.Linear(inner_dim, in_channels, bias=True)
        else:
            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, bias=True)

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """Run the spatial transformer with a residual connection.

        ``attention_mask`` is accepted for interface compatibility but not used.
        """
        batch, channel, height, width = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)

        if self.use_linear_projection:
            # Flatten to tokens first (last dim = in_channels), then project to inner_dim.
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, -1, channel)
            hidden_states = self.proj_in(hidden_states)
        else:
            # Project with the 1x1 conv first, then flatten to tokens.
            hidden_states = self.proj_in(hidden_states)
            # Fix: flatten using the post-projection channel count (inner_dim), not the
            # input `channel` — the two differ whenever heads*head_dim != in_channels,
            # and the old code crashed (or silently mis-shaped) in that case.
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, -1, inner_dim)

        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states)

        if self.use_linear_projection:
            # Project back to in_channels, then restore the spatial layout.
            hidden_states = self.proj_out(hidden_states)
            hidden_states = hidden_states.reshape(batch, height, width, channel).permute(0, 3, 1, 2).contiguous()
        else:
            # Restore the spatial layout at inner_dim width, then 1x1-conv back down.
            inner_dim = hidden_states.shape[-1]
            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
            hidden_states = self.proj_out(hidden_states)

        hidden_states = hidden_states + residual
        return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
# ===== Down/Up Blocks =====
|
||||||
|
|
||||||
|
class CrossAttnDownBlock2D(nn.Module):
    """Encoder stage with attention: alternating ResnetBlock2D / Transformer2DModel
    layers followed by an optional 2x downsampler. Returns the final hidden states plus
    every intermediate output (consumed later as skip connections by the up blocks)."""

    def __init__(
        self,
        in_channels,
        out_channels,
        temb_channels=1280,
        dropout=0.0,
        num_layers=1,
        transformer_layers_per_block=1,
        resnet_eps=1e-6,
        resnet_time_scale_shift="default",
        resnet_act_fn="swish",
        resnet_groups=32,
        resnet_pre_norm=True,
        cross_attention_dim=768,
        attention_head_dim=1,
        downsample=True,
        use_linear_projection=False,
    ):
        super().__init__()
        # Uniform flag used by the parent UNet to decide what to pass to forward().
        self.has_cross_attention = True

        resnets = []
        attentions = []

        for i in range(num_layers):
            # Only the first resnet sees `in_channels`; later layers run at `out_channels`.
            in_channels_i = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels_i,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=1.0,
                    pre_norm=resnet_pre_norm,
                )
            )
            attentions.append(
                Transformer2DModel(
                    # NOTE(review): `attention_head_dim` is used here as the head COUNT
                    # and the per-head width is derived from it — matches the historical
                    # SD config naming quirk; confirm against the checkpoint config.
                    num_attention_heads=attention_head_dim,
                    attention_head_dim=out_channels // attention_head_dim,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block,
                    dropout=dropout,
                    norm_num_groups=resnet_groups,
                    cross_attention_dim=cross_attention_dim,
                    use_linear_projection=use_linear_projection,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if downsample:
            self.downsamplers = nn.ModuleList([
                Downsample2D(out_channels, out_channels, padding=1)
            ])
        else:
            self.downsamplers = None

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        """Return ``(hidden_states, skip_outputs)`` where ``skip_outputs`` holds one
        tensor per resnet/attention pair plus one for the downsampler (if present)."""
        output_states = []

        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
            output_states.append(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
                output_states.append(hidden_states)

        return hidden_states, tuple(output_states)
|
||||||
|
|
||||||
|
|
||||||
|
class DownBlock2D(nn.Module):
    """Encoder stage without attention: a stack of ResnetBlock2D layers followed by an
    optional 2x downsampler. Returns the final hidden states plus every intermediate
    output (used later as skip connections)."""

    def __init__(
        self,
        in_channels,
        out_channels,
        temb_channels=1280,
        dropout=0.0,
        num_layers=1,
        resnet_eps=1e-6,
        resnet_time_scale_shift="default",
        resnet_act_fn="swish",
        resnet_groups=32,
        resnet_pre_norm=True,
        downsample=True,
    ):
        super().__init__()
        # Uniform flag used by the parent UNet to decide what to pass to forward().
        self.has_cross_attention = False

        # Only the first resnet sees `in_channels`; later layers run at `out_channels`.
        self.resnets = nn.ModuleList([
            ResnetBlock2D(
                in_channels=in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=1.0,
                pre_norm=resnet_pre_norm,
            )
            for i in range(num_layers)
        ])

        self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels, padding=1)]) if downsample else None

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        """Return ``(hidden_states, skip_outputs)``. ``encoder_hidden_states`` is
        accepted only to keep a uniform block interface and is ignored."""
        skips = []
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb)
            skips.append(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
                skips.append(hidden_states)

        return hidden_states, tuple(skips)
|
||||||
|
|
||||||
|
|
||||||
|
class CrossAttnUpBlock2D(nn.Module):
    """Decoder stage with attention: each layer concatenates a skip connection onto
    the channel dim, runs a ResnetBlock2D then a Transformer2DModel; an optional
    2x upsampler finishes the block."""

    def __init__(
        self,
        in_channels,
        out_channels,
        prev_output_channel,
        temb_channels=1280,
        dropout=0.0,
        num_layers=1,
        transformer_layers_per_block=1,
        resnet_eps=1e-6,
        resnet_time_scale_shift="default",
        resnet_act_fn="swish",
        resnet_groups=32,
        resnet_pre_norm=True,
        cross_attention_dim=768,
        attention_head_dim=1,
        upsample=True,
        use_linear_projection=False,
    ):
        """
        Args:
            in_channels: Channel count of the skip tensor from before the matching
                encoder stage widened (consumed by the LAST layer here).
            out_channels: Output channel count of this stage.
            prev_output_channel: Channels of the tensor arriving from the previous
                (deeper) decoder stage.
            (Remaining args mirror CrossAttnDownBlock2D.)
        """
        super().__init__()
        # Uniform flag used by the parent UNet to decide what to pass to forward().
        self.has_cross_attention = True

        resnets = []
        attentions = []

        for i in range(num_layers):
            # The last layer's skip predates the encoder stage's widening, so it still
            # carries `in_channels` channels; the first layer's main input comes from
            # the previous decoder stage.
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock2D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=1.0,
                    pre_norm=resnet_pre_norm,
                )
            )
            attentions.append(
                Transformer2DModel(
                    # NOTE(review): `attention_head_dim` is used here as the head COUNT
                    # (historical SD config naming quirk) — confirm against the config.
                    num_attention_heads=attention_head_dim,
                    attention_head_dim=out_channels // attention_head_dim,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block,
                    dropout=dropout,
                    norm_num_groups=resnet_groups,
                    cross_attention_dim=cross_attention_dim,
                    use_linear_projection=use_linear_projection,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if upsample:
            self.upsamplers = nn.ModuleList([
                Upsample2D(out_channels, out_channels)
            ])
        else:
            self.upsamplers = None

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None):
        """Consume skip tensors from the END of ``res_hidden_states_tuple``, one per
        resnet, concatenating each onto the channel dim before the resnet runs."""
        for resnet, attn in zip(self.resnets, self.attentions):
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size=upsample_size)

        return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class UpBlock2D(nn.Module):
    """Decoder stage without attention: resnets that consume skip connections
    (concatenated on the channel dim), followed by an optional 2x upsampler."""

    def __init__(
        self,
        in_channels,
        out_channels,
        prev_output_channel,
        temb_channels=1280,
        dropout=0.0,
        num_layers=1,
        resnet_eps=1e-6,
        resnet_time_scale_shift="default",
        resnet_act_fn="swish",
        resnet_groups=32,
        resnet_pre_norm=True,
        upsample=True,
    ):
        super().__init__()
        # Uniform flag used by the parent UNet to decide what to pass to forward().
        self.has_cross_attention = False

        resnet_list = []
        for i in range(num_layers):
            # The last layer's skip predates the encoder stage's widening (in_channels);
            # the first layer's main input comes from the previous decoder stage.
            skip_channels = in_channels if i == num_layers - 1 else out_channels
            main_channels = prev_output_channel if i == 0 else out_channels

            resnet_list.append(
                ResnetBlock2D(
                    in_channels=main_channels + skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=1.0,
                    pre_norm=resnet_pre_norm,
                )
            )
        self.resnets = nn.ModuleList(resnet_list)

        self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)]) if upsample else None

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None):
        """Consume skip tensors from the END of ``res_hidden_states_tuple``, one per
        resnet. ``encoder_hidden_states`` is accepted for interface uniformity only."""
        for resnet in self.resnets:
            skip = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = resnet(torch.cat([hidden_states, skip], dim=1), temb)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size=upsample_size)

        return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
# ===== UNet Mid Block =====
|
||||||
|
|
||||||
|
class UNetMidBlock2DCrossAttn(nn.Module):
    """Bottleneck of the UNet: a leading ResnetBlock2D followed by `num_layers`
    (Transformer2DModel, ResnetBlock2D) pairs, all at constant channel width."""

    def __init__(
        self,
        in_channels,
        temb_channels=1280,
        dropout=0.0,
        num_layers=1,
        transformer_layers_per_block=1,
        resnet_eps=1e-6,
        resnet_time_scale_shift="default",
        resnet_act_fn="swish",
        resnet_groups=32,
        resnet_pre_norm=True,
        cross_attention_dim=768,
        attention_head_dim=1,
        use_linear_projection=False,
    ):
        super().__init__()
        # Fall back to a channel-derived group count when resnet_groups is None.
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

        # One leading resnet; each loop iteration below appends an attention + resnet
        # pair, so there is always one more resnet than attention.
        resnets = [
            ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=1.0,
                pre_norm=resnet_pre_norm,
            )
        ]
        attentions = []

        for _ in range(num_layers):
            attentions.append(
                Transformer2DModel(
                    # NOTE(review): `attention_head_dim` is used here as the head COUNT
                    # (historical SD config naming quirk) — confirm against the config.
                    num_attention_heads=attention_head_dim,
                    attention_head_dim=in_channels // attention_head_dim,
                    in_channels=in_channels,
                    num_layers=transformer_layers_per_block,
                    dropout=dropout,
                    norm_num_groups=resnet_groups,
                    cross_attention_dim=cross_attention_dim,
                    use_linear_projection=use_linear_projection,
                )
            )
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=1.0,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        """Run resnet[0], then alternate attention/resnet for the remaining pairs."""
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
            hidden_states = resnet(hidden_states, temb)
        return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
# ===== Downsample / Upsample =====
|
||||||
|
|
||||||
|
class Downsample2D(nn.Module):
    """2x spatial downsampling via a stride-2 3x3 convolution.

    With ``padding=0`` the input is first zero-padded asymmetrically (one column on
    the right, one row on the bottom) so the output is still exactly half size."""

    def __init__(self, in_channels, out_channels, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=padding)
        self.padding = padding

    def forward(self, hidden_states):
        # Asymmetric manual pad replaces the conv's own padding when padding == 0.
        padded = hidden_states if self.padding != 0 else F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
|
||||||
|
|
||||||
|
|
||||||
|
class Upsample2D(nn.Module):
    """Nearest-neighbor upsampling (2x, or to an explicit target size) followed by a
    3x3 same-padding convolution."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, hidden_states, upsample_size=None):
        # An explicit target size (used to match odd-sized skip connections) takes
        # precedence over the default 2x scale factor.
        if upsample_size is None:
            resized = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        else:
            resized = F.interpolate(hidden_states, size=upsample_size, mode="nearest")
        return self.conv(resized)
|
||||||
|
|
||||||
|
|
||||||
|
# ===== SDXL UNet2DConditionModel =====
|
||||||
|
|
||||||
|
class SDXLUNet2DConditionModel(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
sample_size=128,
|
||||||
|
in_channels=4,
|
||||||
|
out_channels=4,
|
||||||
|
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"),
|
||||||
|
up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"),
|
||||||
|
block_out_channels=(320, 640, 1280),
|
||||||
|
layers_per_block=2,
|
||||||
|
cross_attention_dim=2048,
|
||||||
|
attention_head_dim=5,
|
||||||
|
transformer_layers_per_block=1,
|
||||||
|
norm_num_groups=32,
|
||||||
|
norm_eps=1e-5,
|
||||||
|
dropout=0.0,
|
||||||
|
act_fn="silu",
|
||||||
|
time_embedding_type="positional",
|
||||||
|
flip_sin_to_cos=True,
|
||||||
|
freq_shift=0,
|
||||||
|
time_embedding_dim=None,
|
||||||
|
resnet_time_scale_shift="default",
|
||||||
|
upcast_attention=False,
|
||||||
|
use_linear_projection=False,
|
||||||
|
addition_embed_type=None,
|
||||||
|
addition_time_embed_dim=None,
|
||||||
|
projection_class_embeddings_input_dim=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.sample_size = sample_size
|
||||||
|
self.addition_embed_type = addition_embed_type
|
||||||
|
|
||||||
|
if isinstance(attention_head_dim, int):
|
||||||
|
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
||||||
|
if isinstance(transformer_layers_per_block, int):
|
||||||
|
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
||||||
|
|
||||||
|
timestep_embedding_dim = time_embedding_dim or block_out_channels[0]
|
||||||
|
self.time_proj = Timesteps(timestep_embedding_dim, flip_sin_to_cos=flip_sin_to_cos, freq_shift=freq_shift)
|
||||||
|
time_embed_dim = block_out_channels[0] * 4
|
||||||
|
self.time_embedding = TimestepEmbedding(timestep_embedding_dim, time_embed_dim)
|
||||||
|
|
||||||
|
if addition_embed_type == "text_time":
|
||||||
|
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=flip_sin_to_cos, freq_shift=freq_shift)
|
||||||
|
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
||||||
|
|
||||||
|
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1)
|
||||||
|
|
||||||
|
self.down_blocks = nn.ModuleList()
|
||||||
|
output_channel = block_out_channels[0]
|
||||||
|
for i, down_block_type in enumerate(down_block_types):
|
||||||
|
input_channel = output_channel
|
||||||
|
output_channel = block_out_channels[i]
|
||||||
|
is_final_block = i == len(block_out_channels) - 1
|
||||||
|
|
||||||
|
if "CrossAttn" in down_block_type:
|
||||||
|
down_block = CrossAttnDownBlock2D(
|
||||||
|
in_channels=input_channel,
|
||||||
|
out_channels=output_channel,
|
||||||
|
temb_channels=time_embed_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
num_layers=layers_per_block,
|
||||||
|
transformer_layers_per_block=transformer_layers_per_block[i],
|
||||||
|
resnet_eps=norm_eps,
|
||||||
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
attention_head_dim=attention_head_dim[i],
|
||||||
|
downsample=not is_final_block,
|
||||||
|
use_linear_projection=use_linear_projection,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
down_block = DownBlock2D(
|
||||||
|
in_channels=input_channel,
|
||||||
|
out_channels=output_channel,
|
||||||
|
temb_channels=time_embed_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
num_layers=layers_per_block,
|
||||||
|
resnet_eps=norm_eps,
|
||||||
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
downsample=not is_final_block,
|
||||||
|
)
|
||||||
|
self.down_blocks.append(down_block)
|
||||||
|
|
||||||
|
self.mid_block = UNetMidBlock2DCrossAttn(
|
||||||
|
in_channels=block_out_channels[-1],
|
||||||
|
temb_channels=time_embed_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
num_layers=1,
|
||||||
|
transformer_layers_per_block=transformer_layers_per_block[-1],
|
||||||
|
resnet_eps=norm_eps,
|
||||||
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
attention_head_dim=attention_head_dim[-1],
|
||||||
|
use_linear_projection=use_linear_projection,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.up_blocks = nn.ModuleList()
|
||||||
|
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||||
|
reversed_attention_head_dim = list(reversed(attention_head_dim))
|
||||||
|
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
|
||||||
|
output_channel = reversed_block_out_channels[0]
|
||||||
|
|
||||||
|
for i, up_block_type in enumerate(up_block_types):
|
||||||
|
prev_output_channel = output_channel
|
||||||
|
output_channel = reversed_block_out_channels[i]
|
||||||
|
is_final_block = i == len(block_out_channels) - 1
|
||||||
|
|
||||||
|
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
||||||
|
|
||||||
|
if "CrossAttn" in up_block_type:
|
||||||
|
up_block = CrossAttnUpBlock2D(
|
||||||
|
in_channels=input_channel,
|
||||||
|
out_channels=output_channel,
|
||||||
|
prev_output_channel=prev_output_channel,
|
||||||
|
temb_channels=time_embed_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
num_layers=layers_per_block + 1,
|
||||||
|
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
||||||
|
resnet_eps=norm_eps,
|
||||||
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
attention_head_dim=reversed_attention_head_dim[i],
|
||||||
|
upsample=not is_final_block,
|
||||||
|
use_linear_projection=use_linear_projection,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
up_block = UpBlock2D(
|
||||||
|
in_channels=input_channel,
|
||||||
|
out_channels=output_channel,
|
||||||
|
prev_output_channel=prev_output_channel,
|
||||||
|
temb_channels=time_embed_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
num_layers=layers_per_block + 1,
|
||||||
|
resnet_eps=norm_eps,
|
||||||
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
upsample=not is_final_block,
|
||||||
|
)
|
||||||
|
self.up_blocks.append(up_block)
|
||||||
|
|
||||||
|
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
|
||||||
|
self.conv_act = nn.SiLU()
|
||||||
|
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
|
||||||
|
|
||||||
|
def forward(self, sample, timestep, encoder_hidden_states, cross_attention_kwargs=None, timestep_cond=None, added_cond_kwargs=None, return_dict=True):
    """Run one UNet denoising pass.

    Args:
        sample: noisy latent tensor (assumed NCHW — TODO confirm against callers).
        timestep: python number, 0-d tensor, or 1-d tensor of diffusion timesteps.
        encoder_hidden_states: text-encoder token embeddings for cross-attention.
        cross_attention_kwargs: accepted for API compatibility; not forwarded to
            the blocks in this implementation.
        timestep_cond: accepted for API compatibility; unused here.
        added_cond_kwargs: dict holding "text_embeds" and "time_ids"; required
            when ``self.addition_embed_type == "text_time"`` (SDXL-style).
        return_dict: if False, return a 1-tuple instead of the bare tensor.

    Returns:
        Predicted noise tensor with the same spatial shape as ``sample``
        (or a 1-tuple of it when ``return_dict`` is False).
    """
    # Normalize `timestep` to a 1-D tensor on the sample's device.
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
    elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(sample.device)

    # Sinusoidal projection followed by the learned time-embedding MLP.
    t_emb = self.time_proj(timesteps)
    t_emb = t_emb.to(dtype=sample.dtype)
    emb = self.time_embedding(t_emb)

    # SDXL-style extra conditioning: pooled text embeds concatenated with
    # projected size/crop "time ids", then added to the time embedding.
    if self.addition_embed_type == "text_time":
        text_embeds = added_cond_kwargs.get("text_embeds")
        time_ids = added_cond_kwargs.get("time_ids")
        time_embeds = self.add_time_proj(time_ids.flatten())
        time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
        add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
        add_embeds = add_embeds.to(emb.dtype)
        aug_emb = self.add_embedding(add_embeds)
        emb = emb + aug_emb

    # Stem convolution into the first feature width.
    sample = self.conv_in(sample)

    # Encoder path: collect every residual for the skip connections.
    down_block_res_samples = (sample,)
    for down_block in self.down_blocks:
        sample, res_samples = down_block(
            hidden_states=sample,
            temb=emb,
            encoder_hidden_states=encoder_hidden_states,
        )
        down_block_res_samples += res_samples

    sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)

    # Decoder path: consume the skip stack from the end, in reverse.
    for up_block in self.up_blocks:
        res_samples = down_block_res_samples[-len(up_block.resnets):]
        down_block_res_samples = down_block_res_samples[:-len(up_block.resnets)]

        # Target spatial size for the upsampler is taken from the next skip
        # tensor, so odd input resolutions round-trip correctly.
        upsample_size = down_block_res_samples[-1].shape[2:] if down_block_res_samples else None
        sample = up_block(
            hidden_states=sample,
            temb=emb,
            encoder_hidden_states=encoder_hidden_states,
            res_hidden_states_tuple=res_samples,
            upsample_size=upsample_size,
        )

    # Output head: GroupNorm -> SiLU -> 3x3 conv back to `out_channels`.
    sample = self.conv_norm_out(sample)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)

    if not return_dict:
        return (sample,)
    return sample
|
||||||
282
diffsynth/pipelines/joyai_image.py
Normal file
282
diffsynth/pipelines/joyai_image.py
Normal file
@@ -0,0 +1,282 @@
|
|||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from typing import Union, Optional
|
||||||
|
from tqdm import tqdm
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
from ..core.device.npu_compatible_device import get_device_type
|
||||||
|
from ..diffusion import FlowMatchScheduler
|
||||||
|
from ..core import ModelConfig
|
||||||
|
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||||
|
from ..models.joyai_image_dit import JoyAIImageDiT
|
||||||
|
from ..models.joyai_image_text_encoder import JoyAIImageTextEncoder
|
||||||
|
from ..models.wan_video_vae import WanVideoVAE
|
||||||
|
|
||||||
|
class JoyAIImagePipeline(BasePipeline):
    """Instruction-guided image-editing pipeline for JoyAI-Image.

    Combines a multimodal text encoder, a DiT denoiser and the Wan video VAE
    (applied to single frames here), sampled with a flow-matching scheduler.
    Models are populated by :meth:`from_pretrained`.
    """

    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
        super().__init__(
            device=device, torch_dtype=torch_dtype,
            height_division_factor=16, width_division_factor=16,
        )
        self.scheduler = FlowMatchScheduler("Wan")
        # Sub-models; assigned in `from_pretrained`.
        self.text_encoder: JoyAIImageTextEncoder = None
        self.dit: JoyAIImageDiT = None
        self.vae: WanVideoVAE = None
        self.processor = None  # HF AutoProcessor for joint (text, image) inputs
        # Only the DiT stays on-device across the denoising loop.
        self.in_iteration_models = ("dit",)

        # Preprocessing units, executed in order by `unit_runner` before denoising.
        self.units = [
            JoyAIImageUnit_ShapeChecker(),
            JoyAIImageUnit_EditImageEmbedder(),
            JoyAIImageUnit_PromptEmbedder(),
            JoyAIImageUnit_NoiseInitializer(),
            JoyAIImageUnit_InputImageEmbedder(),
        ]
        self.model_fn = model_fn_joyai_image
        self.compilable_models = ["dit"]

    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = get_device_type(),
        model_configs: list[ModelConfig] = [],
        # Processor
        processor_config: ModelConfig = None,
        # Optional
        vram_limit: float = None,
    ):
        """Download/load the configured models and return a ready pipeline.

        NOTE(review): `model_configs=[]` is a mutable default argument; it is
        only iterated here, but sharing across calls is a known Python pitfall.
        """
        pipe = JoyAIImagePipeline(device=device, torch_dtype=torch_dtype)
        model_pool = pipe.download_and_load_models(model_configs, vram_limit)

        # Fetch by registered model-pool names.
        pipe.text_encoder = model_pool.fetch_model("joyai_image_text_encoder")
        pipe.dit = model_pool.fetch_model("joyai_image_dit")
        pipe.vae = model_pool.fetch_model("wan_video_vae")

        if processor_config is not None:
            processor_config.download_if_necessary()
            # Deferred import keeps transformers optional unless a processor is used.
            from transformers import AutoProcessor
            pipe.processor = AutoProcessor.from_pretrained(processor_config.path)

        pipe.vram_management_enabled = pipe.check_vram_management_state()
        return pipe

    @torch.no_grad()
    def __call__(
        self,
        # Prompt
        prompt: str,
        negative_prompt: str = "",
        cfg_scale: float = 5.0,
        # Image
        edit_image: Image.Image = None,
        denoising_strength: float = 1.0,
        # Shape
        height: int = 1024,
        width: int = 1024,
        # Randomness
        seed: int = None,
        # Steps
        max_sequence_length: int = 4096,
        num_inference_steps: int = 30,
        # Tiling
        tiled: Optional[bool] = False,
        tile_size: Optional[tuple[int, int]] = (30, 52),
        tile_stride: Optional[tuple[int, int]] = (15, 26),
        # Scheduler
        shift: Optional[float] = 4.0,
        # Progress bar
        progress_bar_cmd=tqdm,
    ):
        """Generate an edited image from `prompt` (and optionally `edit_image`).

        Returns a PIL image decoded from the final latents.
        """
        # Scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=shift)

        # Parameters: split into positive / negative / shared dicts for CFG.
        inputs_posi = {"prompt": prompt}
        inputs_nega = {"negative_prompt": negative_prompt}
        inputs_shared = {
            "cfg_scale": cfg_scale,
            "edit_image": edit_image,
            "denoising_strength": denoising_strength,
            "height": height, "width": width,
            "seed": seed, "max_sequence_length": max_sequence_length,
            "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
        }

        # Unit chain: each unit consumes/produces entries in the three dicts.
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
                unit, self, inputs_shared, inputs_posi, inputs_nega
            )

        # Denoise loop with classifier-free guidance.
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
            noise_pred = self.cfg_guided_model_fn(
                self.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)

        # Decode: collapse the item axis into batch before the (video) VAE.
        self.load_models_to_device(['vae'])
        latents = rearrange(inputs_shared["latents"], "b n c f h w -> (b n) c f h w")
        image = self.vae.decode(latents, device=self.device)[0]
        image = self.vae_output_to_image(image, pattern="C 1 H W")
        self.load_models_to_device([])
        return image
|
||||||
|
|
||||||
|
|
||||||
|
class JoyAIImageUnit_ShapeChecker(PipelineUnit):
    """Snap the requested resolution to the pipeline's division factors."""

    def __init__(self):
        super().__init__(
            input_params=("height", "width"),
            output_params=("height", "width"),
        )

    def process(self, pipe: "JoyAIImagePipeline", height, width):
        """Return the (possibly adjusted) height/width accepted by the pipeline."""
        checked_height, checked_width = pipe.check_resize_height_width(height, width)
        return {"height": checked_height, "width": checked_width}
|
||||||
|
|
||||||
|
|
||||||
|
class JoyAIImageUnit_PromptEmbedder(PipelineUnit):
    """Wrap the prompt in the model's chat template, run the multimodal text
    encoder, and emit token embeddings plus their attention mask.

    Fix: removed the unreachable ``return prompt_embeds, encoder_attention_mask``
    in ``_encode_text_only`` — it followed a ``raise`` and referenced names that
    were never defined.
    """

    # Chat templates wrapped around the user prompt before encoding.
    prompt_template_encode = {
        'image':
            "<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
        'multiple_images':
            "<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n{}<|im_start|>assistant\n",
        'video':
            "<|im_start|>system\n \\nDescribe the video by detailing the following aspects:\n1. The main content and theme of the video.\n2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.\n3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.\n4. background environment, light, style and atmosphere.\n5. camera angles, movements, and transitions used in the video:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
    }
    # Number of leading template tokens dropped from the encoder output
    # (presumably the tokenized length of the system preamble — TODO confirm).
    prompt_template_encode_start_idx = {'image': 34, 'multiple_images': 34, 'video': 91}

    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params_posi={"prompt": "prompt", "positive": "positive"},
            input_params_nega={"prompt": "negative_prompt", "positive": "positive"},
            input_params=("edit_image", "max_sequence_length"),
            output_params=("prompt_embeds", "prompt_embeds_mask"),
            onload_model_names=("joyai_image_text_encoder",),
        )

    def process(self, pipe: "JoyAIImagePipeline", prompt, positive, edit_image, max_sequence_length):
        """Encode `prompt` (with `edit_image` when present) into embeddings."""
        pipe.load_models_to_device(self.onload_model_names)

        has_image = edit_image is not None

        if has_image:
            prompt_embeds, prompt_embeds_mask = self._encode_with_image(pipe, prompt, edit_image, max_sequence_length)
        else:
            prompt_embeds, prompt_embeds_mask = self._encode_text_only(pipe, prompt, max_sequence_length)

        return {"prompt_embeds": prompt_embeds, "prompt_embeds_mask": prompt_embeds_mask}

    def _encode_with_image(self, pipe, prompt, edit_image, max_sequence_length):
        """Encode a prompt that references one edit image.

        Returns (prompt_embeds, prompt_embeds_mask), truncated from the left
        to `max_sequence_length` tokens when necessary.
        """
        template = self.prompt_template_encode['multiple_images']
        drop_idx = self.prompt_template_encode_start_idx['multiple_images']

        # Embed the image placeholder, then expand it into the vision tokens
        # expected by the processor.
        image_tokens = '<image>\n'
        prompt = f"<|im_start|>user\n{image_tokens}{prompt}<|im_end|>\n"
        prompt = prompt.replace('<image>\n', '<|vision_start|><|image_pad|><|vision_end|>')
        prompt = template.format(prompt)
        inputs = pipe.processor(text=[prompt], images=[edit_image], padding=True, return_tensors="pt").to(pipe.device)
        last_hidden_states = pipe.text_encoder(**inputs)

        # Drop the fixed template preamble from both embeddings and mask.
        prompt_embeds = last_hidden_states[:, drop_idx:]
        prompt_embeds_mask = inputs['attention_mask'][:, drop_idx:]

        # Left-truncate so the most recent (user) tokens are preserved.
        if max_sequence_length is not None and prompt_embeds.shape[1] > max_sequence_length:
            prompt_embeds = prompt_embeds[:, -max_sequence_length:, :]
            prompt_embeds_mask = prompt_embeds_mask[:, -max_sequence_length:]

        return prompt_embeds, prompt_embeds_mask

    def _encode_text_only(self, pipe, prompt, max_sequence_length):
        """Text-only prompt encoding — not supported yet.

        Raises:
            NotImplementedError: always; callers must supply `edit_image`.
        """
        # TODO: may support for text-only encoding in the future.
        raise NotImplementedError("Text-only encoding is not implemented yet. Please provide edit_image for now.")
|
||||||
|
|
||||||
|
|
||||||
|
class JoyAIImageUnit_EditImageEmbedder(PipelineUnit):
    """Encode the edit (reference) image into VAE latents for the DiT.

    Produces ``ref_latents`` shaped (b, n, c, 1, h, w) and the resized
    ``edit_image``, so downstream units (e.g. the prompt embedder) see the
    image at the final generation resolution.

    Fix: ``output_params`` now matches what ``process`` actually returns.
    The previous declaration listed "num_items"/"is_multi_item" (never
    produced) and omitted "edit_image" (which is produced).
    """

    def __init__(self):
        super().__init__(
            input_params=("edit_image", "tiled", "tile_size", "tile_stride", "height", "width"),
            output_params=("ref_latents", "edit_image"),
            onload_model_names=("wan_video_vae",),
        )

    def process(self, pipe: "JoyAIImagePipeline", edit_image, tiled, tile_size, tile_stride, height, width):
        """Return reference latents for `edit_image`, or {} when absent."""
        if edit_image is None:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        # Resize edit image to match target dimensions (from ShapeChecker) to ensure ref_latents matches latents
        edit_image = edit_image.resize((width, height), Image.LANCZOS)
        images = [pipe.preprocess_image(edit_image).transpose(0, 1)]
        latents = pipe.vae.encode(images, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        # Re-introduce the item axis n=1 expected by the DiT input layout.
        ref_vae = rearrange(latents, "(b n) c 1 h w -> b n c 1 h w", n=1).to(device=pipe.device, dtype=pipe.torch_dtype)

        return {"ref_latents": ref_vae, "edit_image": edit_image}
|
||||||
|
|
||||||
|
|
||||||
|
class JoyAIImageUnit_NoiseInitializer(PipelineUnit):
    """Draw the initial Gaussian noise latent for sampling.

    Fix: ``output_params`` was ``("noise")`` — a bare string, not a tuple —
    due to a missing trailing comma; the sibling SD pipeline correctly uses
    ``("noise",)``.
    """

    def __init__(self):
        super().__init__(
            input_params=("seed", "height", "width", "rand_device"),
            output_params=("noise",),
        )

    def process(self, pipe: "JoyAIImagePipeline", seed, height, width, rand_device):
        """Return noise shaped (1, n=1, z_dim, f=1, h/8?, w/8?) per the VAE's
        upsampling factor (exact factor comes from the VAE — see `pipe.vae`)."""
        latent_h = height // pipe.vae.upsampling_factor
        latent_w = width // pipe.vae.upsampling_factor
        shape = (1, 1, pipe.vae.z_dim, 1, latent_h, latent_w)
        # NOTE(review): "rand_device" is consumed here but the JoyAI __call__
        # does not place it in inputs_shared — presumably the unit runner
        # supplies a default; verify.
        noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
        return {"noise": noise}
|
||||||
|
|
||||||
|
|
||||||
|
class JoyAIImageUnit_InputImageEmbedder(PipelineUnit):
    """Encode optional input image(s) to latents for image-to-image flows."""

    def __init__(self):
        super().__init__(
            input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"),
            output_params=("latents", "input_latents"),
            onload_model_names=("vae",),
        )

    def process(self, pipe: JoyAIImagePipeline, input_image, noise, tiled, tile_size, tile_stride):
        """Return the starting latents; add `input_latents` when images exist."""
        # Pure text-to-image: start from noise alone.
        if input_image is None:
            return {"latents": noise}

        pipe.load_models_to_device(self.onload_model_names)

        # Accept a single PIL image or a list of them uniformly.
        image_list = [input_image] if isinstance(input_image, Image.Image) else input_image
        tensors = [pipe.preprocess_image(one_image).transpose(0, 1) for one_image in image_list]

        encoded = pipe.vae.encode(tensors, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        # Restore the per-item axis n = number of input images.
        stacked = rearrange(encoded, "(b n) c 1 h w -> b n c 1 h w", n=(len(tensors)))
        return {"latents": noise, "input_latents": stacked}
|
||||||
|
|
||||||
|
def model_fn_joyai_image(
    dit,
    latents,
    timestep,
    prompt_embeds,
    prompt_embeds_mask,
    ref_latents=None,
    use_gradient_checkpointing=False,
    use_gradient_checkpointing_offload=False,
    **kwargs,
):
    """Run the DiT on (optional reference latents + noisy latents).

    Returns only the denoised slice corresponding to `latents`; any
    reference-latent positions prepended along dim 1 are discarded.
    """
    if ref_latents is None:
        dit_input = latents
    else:
        # Reference latents are prepended along the item axis (dim 1).
        dit_input = torch.cat([ref_latents, latents], dim=1)

    dit_output = dit(
        hidden_states=dit_input,
        timestep=timestep,
        encoder_hidden_states=prompt_embeds,
        encoder_hidden_states_mask=prompt_embeds_mask,
        use_gradient_checkpointing=use_gradient_checkpointing,
        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
    )

    # Keep only the trailing positions that correspond to `latents`.
    return dit_output[:, -latents.size(1):]
|
||||||
230
diffsynth/pipelines/stable_diffusion.py
Normal file
230
diffsynth/pipelines/stable_diffusion.py
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from ..core.device.npu_compatible_device import get_device_type
|
||||||
|
from ..diffusion.ddim_scheduler import DDIMScheduler
|
||||||
|
from ..core import ModelConfig
|
||||||
|
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer, CLIPTextModel
|
||||||
|
from ..models.stable_diffusion_text_encoder import SDTextEncoder
|
||||||
|
from ..models.stable_diffusion_unet import UNet2DConditionModel
|
||||||
|
from ..models.stable_diffusion_vae import StableDiffusionVAE
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusionPipeline(BasePipeline):
    """Text-to-image pipeline for Stable Diffusion v1.5 (DDIM sampling).

    Wires together the CLIP text encoder, the conditional 2D UNet and the
    VAE, and drives them through the shared BasePipeline unit/denoise loop.
    Models are populated by :meth:`from_pretrained`.
    """

    def __init__(self, device=get_device_type(), torch_dtype=torch.float16):
        super().__init__(
            device=device, torch_dtype=torch_dtype,
            height_division_factor=8, width_division_factor=8,
        )
        self.scheduler = DDIMScheduler()
        # Sub-models; assigned in `from_pretrained`.
        self.text_encoder: SDTextEncoder = None
        self.unet: UNet2DConditionModel = None
        self.vae: StableDiffusionVAE = None
        self.tokenizer: AutoTokenizer = None

        # Only the UNet stays on-device across the denoising loop.
        self.in_iteration_models = ("unet",)
        self.units = [
            SDUnit_ShapeChecker(),
            SDUnit_PromptEmbedder(),
            SDUnit_NoiseInitializer(),
            SDUnit_InputImageEmbedder(),
        ]
        self.model_fn = model_fn_stable_diffusion
        self.compilable_models = ["unet"]

    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.float16,
        device: Union[str, torch.device] = get_device_type(),
        model_configs: list[ModelConfig] = [],
        tokenizer_config: ModelConfig = None,
        vram_limit: float = None,
    ):
        """Download/load the configured models and return a ready pipeline.

        NOTE(review): `model_configs=[]` is a mutable default argument; it is
        only iterated here, but sharing across calls is a known Python pitfall.
        """
        pipe = StableDiffusionPipeline(device=device, torch_dtype=torch_dtype)
        # Override vram_config to use the specified torch_dtype for all models
        for mc in model_configs:
            mc._vram_config_override = {
                'onload_dtype': torch_dtype,
                'computation_dtype': torch_dtype,
            }
        model_pool = pipe.download_and_load_models(model_configs, vram_limit)
        pipe.text_encoder = model_pool.fetch_model("stable_diffusion_text_encoder")
        pipe.unet = model_pool.fetch_model("stable_diffusion_unet")
        pipe.vae = model_pool.fetch_model("stable_diffusion_vae")
        if tokenizer_config is not None:
            tokenizer_config.download_if_necessary()
            pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
        pipe.vram_management_enabled = pipe.check_vram_management_state()
        return pipe

    @torch.no_grad()
    def __call__(
        self,
        prompt: str,
        negative_prompt: str = "",
        cfg_scale: float = 7.5,
        height: int = 512,
        width: int = 512,
        seed: int = None,
        rand_device: str = "cpu",
        num_inference_steps: int = 50,
        eta: float = 0.0,
        guidance_rescale: float = 0.0,
        progress_bar_cmd=tqdm,
    ):
        """Generate a PIL image from `prompt` via DDIM sampling with CFG."""
        # 1. Scheduler
        self.scheduler.set_timesteps(
            num_inference_steps, eta=eta,
        )

        # 2. Three-dict input preparation (positive / negative / shared for CFG)
        inputs_posi = {"prompt": prompt}
        inputs_nega = {"negative_prompt": negative_prompt}
        inputs_shared = {
            "cfg_scale": cfg_scale,
            "height": height, "width": width,
            "seed": seed, "rand_device": rand_device,
            "guidance_rescale": guidance_rescale,
        }

        # 3. Unit chain execution
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
                unit, self, inputs_shared, inputs_posi, inputs_nega
            )

        # 4. Denoise loop
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
            noise_pred = self.cfg_guided_model_fn(
                self.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            inputs_shared["latents"] = self.step(
                self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared
            )

        # 5. VAE decode (undo the SD latent scaling first)
        self.load_models_to_device(['vae'])
        latents = inputs_shared["latents"] / self.vae.scaling_factor
        image = self.vae.decode(latents)
        image = self.vae_output_to_image(image)
        self.load_models_to_device([])

        return image
|
||||||
|
|
||||||
|
|
||||||
|
class SDUnit_ShapeChecker(PipelineUnit):
    """Snap the requested resolution to the pipeline's division factors."""

    def __init__(self):
        super().__init__(
            input_params=("height", "width"),
            output_params=("height", "width"),
        )

    def process(self, pipe: StableDiffusionPipeline, height, width):
        """Return the (possibly adjusted) height/width accepted by the pipeline."""
        checked_height, checked_width = pipe.check_resize_height_width(height, width)
        return {"height": checked_height, "width": checked_width}
|
||||||
|
|
||||||
|
|
||||||
|
class SDUnit_PromptEmbedder(PipelineUnit):
    """Tokenize the prompt and run the CLIP text encoder (CFG-separated)."""

    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params_posi={"prompt": "prompt"},
            input_params_nega={"prompt": "negative_prompt"},
            output_params=("prompt_embeds",),
            onload_model_names=("text_encoder",)
        )

    def encode_prompt(
        self,
        pipe: StableDiffusionPipeline,
        prompt: str,
        device: torch.device,
    ) -> torch.Tensor:
        """Return per-token CLIP embeddings for `prompt`, padded/truncated to
        the tokenizer's max length."""
        token_ids = pipe.tokenizer(
            prompt,
            padding="max_length",
            max_length=pipe.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids.to(device)
        encoded = pipe.text_encoder(token_ids)
        # TextEncoder returns (last_hidden_state, hidden_states) or just
        # last_hidden_state; keep the post-final-layer-norm output, matching
        # diffusers' encode_prompt.
        return encoded[0] if isinstance(encoded, tuple) else encoded

    def process(self, pipe: StableDiffusionPipeline, prompt):
        pipe.load_models_to_device(self.onload_model_names)
        return {"prompt_embeds": self.encode_prompt(pipe, prompt, pipe.device)}
|
||||||
|
|
||||||
|
|
||||||
|
class SDUnit_NoiseInitializer(PipelineUnit):
    """Draw the initial Gaussian latent for sampling."""

    def __init__(self):
        super().__init__(
            input_params=("height", "width", "seed", "rand_device"),
            output_params=("noise",),
        )

    def process(self, pipe: StableDiffusionPipeline, height, width, seed, rand_device):
        """Return noise shaped (1, unet.in_channels, height/8, width/8)."""
        latent_shape = (1, pipe.unet.in_channels, height // 8, width // 8)
        initial_noise = pipe.generate_noise(
            latent_shape,
            seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype
        )
        return {"noise": initial_noise}
|
||||||
|
|
||||||
|
|
||||||
|
class SDUnit_InputImageEmbedder(PipelineUnit):
    """Encode an optional input image to latents for image-to-image flows."""

    def __init__(self):
        super().__init__(
            input_params=("input_image", "noise"),
            output_params=("latents", "input_latents"),
            onload_model_names=("vae",),
        )

    def process(self, pipe: StableDiffusionPipeline, input_image, noise):
        """Return the starting latents; expose clean `input_latents` in training."""
        # Pure text-to-image: start from noise alone.
        if input_image is None:
            return {"latents": noise}

        pipe.load_models_to_device(self.onload_model_names)
        image_tensor = pipe.preprocess_image(input_image)
        clean_latents = pipe.vae.encode(image_tensor).sample() * pipe.vae.scaling_factor
        # Noise the clean latents up to the first sampling timestep.
        noised_latents = pipe.scheduler.add_noise(clean_latents, noise, pipe.scheduler.timesteps[0])

        if pipe.scheduler.training:
            return {"latents": noised_latents, "input_latents": clean_latents}
        return {"latents": noised_latents}
|
||||||
|
|
||||||
|
|
||||||
|
def model_fn_stable_diffusion(
|
||||||
|
unet: UNet2DConditionModel,
|
||||||
|
latents=None,
|
||||||
|
timestep=None,
|
||||||
|
prompt_embeds=None,
|
||||||
|
cross_attention_kwargs=None,
|
||||||
|
timestep_cond=None,
|
||||||
|
added_cond_kwargs=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# SD timestep is already in 0-999 range, no scaling needed
|
||||||
|
noise_pred = unet(
|
||||||
|
latents,
|
||||||
|
timestep,
|
||||||
|
encoder_hidden_states=prompt_embeds,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
timestep_cond=timestep_cond,
|
||||||
|
added_cond_kwargs=added_cond_kwargs,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
return noise_pred
|
||||||
331
diffsynth/pipelines/stable_diffusion_xl.py
Normal file
331
diffsynth/pipelines/stable_diffusion_xl.py
Normal file
@@ -0,0 +1,331 @@
|
|||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from ..core.device.npu_compatible_device import get_device_type
|
||||||
|
from ..diffusion.ddim_scheduler import DDIMScheduler
|
||||||
|
from ..core import ModelConfig
|
||||||
|
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer, CLIPTextModel
|
||||||
|
from ..models.stable_diffusion_text_encoder import SDTextEncoder
|
||||||
|
from ..models.stable_diffusion_xl_unet import SDXLUNet2DConditionModel
|
||||||
|
from ..models.stable_diffusion_xl_text_encoder import SDXLTextEncoder2
|
||||||
|
from ..models.stable_diffusion_vae import StableDiffusionVAE
|
||||||
|
|
||||||
|
|
||||||
|
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """Rescale the CFG-combined noise prediction to curb overexposure.

    Implements Section 3.4 of "Common Diffusion Noise Schedules and Sample
    Steps are Flawed" (https://huggingface.co/papers/2305.08891).
    """
    # Per-sample std over every non-batch dimension.
    text_dims = list(range(1, noise_pred_text.ndim))
    cfg_dims = list(range(1, noise_cfg.ndim))
    ratio = noise_pred_text.std(dim=text_dims, keepdim=True) / noise_cfg.std(dim=cfg_dims, keepdim=True)
    rescaled = noise_cfg * ratio
    # Blend between the std-matched prediction and the raw CFG prediction.
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusionXLPipeline(BasePipeline):
    """Text-to-image pipeline for SDXL (DDIM sampling, dual text encoders).

    Uses CLIP-L (`text_encoder`) plus CLIP-bigG (`text_encoder_2`), the SDXL
    UNet and a VAE; driven through the shared BasePipeline unit/denoise loop.
    Models are populated by :meth:`from_pretrained`.
    """

    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
        super().__init__(
            device=device, torch_dtype=torch_dtype,
            height_division_factor=8, width_division_factor=8,
        )
        self.scheduler = DDIMScheduler()
        # Sub-models; assigned in `from_pretrained`.
        self.text_encoder: SDTextEncoder = None
        self.text_encoder_2: SDXLTextEncoder2 = None
        self.unet: SDXLUNet2DConditionModel = None
        self.vae: StableDiffusionVAE = None
        self.tokenizer: AutoTokenizer = None
        self.tokenizer_2: AutoTokenizer = None

        # Only the UNet stays on-device across the denoising loop.
        self.in_iteration_models = ("unet",)
        self.units = [
            SDXLUnit_ShapeChecker(),
            SDXLUnit_PromptEmbedder(),
            SDXLUnit_NoiseInitializer(),
            SDXLUnit_InputImageEmbedder(),
            SDXLUnit_AddTimeIdsComputer(),
        ]
        self.model_fn = model_fn_stable_diffusion_xl
        self.compilable_models = ["unet"]

    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = get_device_type(),
        model_configs: list[ModelConfig] = [],
        tokenizer_config: ModelConfig = None,
        tokenizer_2_config: ModelConfig = None,
        vram_limit: float = None,
    ):
        """Download/load the configured models and return a ready pipeline.

        NOTE(review): `model_configs=[]` is a mutable default argument; it is
        only iterated here, but sharing across calls is a known Python pitfall.
        """
        pipe = StableDiffusionXLPipeline(device=device, torch_dtype=torch_dtype)
        # Override vram_config to use the specified torch_dtype for all models
        for mc in model_configs:
            mc._vram_config_override = {
                'onload_dtype': torch_dtype,
                'computation_dtype': torch_dtype,
            }
        model_pool = pipe.download_and_load_models(model_configs, vram_limit)
        pipe.text_encoder = model_pool.fetch_model("stable_diffusion_text_encoder")
        pipe.text_encoder_2 = model_pool.fetch_model("stable_diffusion_xl_text_encoder")
        pipe.unet = model_pool.fetch_model("stable_diffusion_xl_unet")
        pipe.vae = model_pool.fetch_model("stable_diffusion_xl_vae")
        if tokenizer_config is not None:
            tokenizer_config.download_if_necessary()
            pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
        if tokenizer_2_config is not None:
            tokenizer_2_config.download_if_necessary()
            pipe.tokenizer_2 = AutoTokenizer.from_pretrained(tokenizer_2_config.path)
        pipe.vram_management_enabled = pipe.check_vram_management_state()
        return pipe

    @torch.no_grad()
    def __call__(
        self,
        prompt: str,
        negative_prompt: str = "",
        cfg_scale: float = 5.0,
        height: int = 1024,
        width: int = 1024,
        seed: int = None,
        rand_device: str = "cpu",
        num_inference_steps: int = 50,
        guidance_rescale: float = 0.0,
        progress_bar_cmd=tqdm,
    ):
        """Generate a PIL image from `prompt` via DDIM sampling with CFG."""
        # 1. Scheduler
        self.scheduler.set_timesteps(num_inference_steps)

        # 2. Three-dict input preparation. NOTE: both dicts use the key
        # "prompt" — the negative prompt is carried under the same key.
        inputs_posi = {
            "prompt": prompt,
        }
        inputs_nega = {
            "prompt": negative_prompt,
        }
        inputs_shared = {
            "cfg_scale": cfg_scale,
            "height": height, "width": width,
            "seed": seed, "rand_device": rand_device,
            "guidance_rescale": guidance_rescale,
            "crops_coords_top_left": (0, 0),
        }

        # 3. Unit chain execution
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
                unit, self, inputs_shared, inputs_posi, inputs_nega
            )

        # 4. Denoise loop
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
            noise_pred = self.cfg_guided_model_fn(
                self.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )

            # Apply guidance_rescale
            if guidance_rescale > 0.0:
                # cfg_guided_model_fn already applied CFG, now apply rescale
                # We need the text-only prediction for rescale.
                # NOTE(review): this is an extra UNet forward per step, and it
                # assumes the prompt-embedder/AddTimeIds units stored
                # "prompt_embeds", "pooled_prompt_embeds" and "add_time_ids"
                # in inputs_posi — verify against those units.
                noise_pred_text = self.model_fn(
                    self.unet,
                    inputs_shared["latents"],
                    timestep,
                    inputs_posi["prompt_embeds"],
                    pooled_prompt_embeds=inputs_posi["pooled_prompt_embeds"],
                    add_time_ids=inputs_posi["add_time_ids"],
                )
                noise_pred = rescale_noise_cfg(
                    noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
                )

            inputs_shared["latents"] = self.step(
                self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared
            )

        # 6. VAE decode (undo the latent scaling first)
        self.load_models_to_device(['vae'])
        latents = inputs_shared["latents"] / self.vae.scaling_factor
        image = self.vae.decode(latents)
        image = self.vae_output_to_image(image)
        self.load_models_to_device([])

        return image
|
||||||
|
|
||||||
|
|
||||||
|
class SDXLUnit_ShapeChecker(PipelineUnit):
    """Pipeline unit that validates the requested resolution and passes it through."""

    def __init__(self):
        super().__init__(input_params=("height", "width"), output_params=("height", "width"))

    def process(self, pipe: StableDiffusionXLPipeline, height, width):
        # Delegate to the pipeline's resolution check, which may adjust the values
        # (e.g. rounding to a multiple the UNet/VAE can handle).
        checked_height, checked_width = pipe.check_resize_height_width(height, width)
        return {"height": checked_height, "width": checked_width}
|
||||||
|
|
||||||
|
|
||||||
|
class SDXLUnit_PromptEmbedder(PipelineUnit):
    """Encodes the prompt with both SDXL text encoders (CLIP-L and CLIP-bigG)."""

    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params_posi={"prompt": "prompt"},
            input_params_nega={"prompt": "prompt"},
            output_params=("prompt_embeds", "pooled_prompt_embeds"),
            onload_model_names=("text_encoder", "text_encoder_2")
        )

    def _tokenize(self, tokenizer, prompt, device):
        """Tokenize `prompt` padded/truncated to the tokenizer's max length; return ids on `device`."""
        token_output = tokenizer(
            prompt,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        return token_output.input_ids.to(device)

    def encode_prompt(
        self,
        pipe: StableDiffusionXLPipeline,
        prompt: str,
        device: torch.device,
    ) -> tuple:
        """Encode prompt using both text encoders (same prompt for both).

        Returns (prompt_embeds, pooled_prompt_embeds):
        - prompt_embeds: concat(encoder1_output, encoder2_output) -> (B, 77, 2048)
        - pooled_prompt_embeds: encoder2 pooled output -> (B, 1280)
        """
        # Text Encoder 1 (CLIP-L, 768-dim): take the main hidden-state output.
        ids_1 = self._tokenize(pipe.tokenizer, prompt, device)
        embeds_1 = pipe.text_encoder(ids_1)
        if isinstance(embeds_1, tuple):
            embeds_1 = embeds_1[0]

        # Text Encoder 2 (CLIP-bigG, 1280-dim): SDXLTextEncoder2 forward returns
        # (pooled text_embeds, hidden_states tuple).
        ids_2 = self._tokenize(pipe.tokenizer_2, prompt, device)
        pooled_prompt_embeds, hidden_states = pipe.text_encoder_2(ids_2, output_hidden_states=True)
        # Penultimate hidden state, matching diffusers' hidden_states[-2] convention.
        embeds_2 = hidden_states[-2]

        # Feature-dim concat of both encoders' sequence embeddings.
        prompt_embeds = torch.cat([embeds_1, embeds_2], dim=-1)
        return prompt_embeds, pooled_prompt_embeds

    def process(self, pipe: StableDiffusionXLPipeline, prompt):
        pipe.load_models_to_device(self.onload_model_names)
        prompt_embeds, pooled_prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device)
        return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds}
|
||||||
|
|
||||||
|
|
||||||
|
class SDXLUnit_NoiseInitializer(PipelineUnit):
    """Draws the initial Gaussian noise tensor in latent space."""

    def __init__(self):
        super().__init__(
            input_params=("height", "width", "seed", "rand_device"),
            output_params=("noise",),
        )

    def process(self, pipe: StableDiffusionXLPipeline, height, width, seed, rand_device):
        # Latent space is 8x spatially downsampled relative to pixel space.
        latent_shape = (1, pipe.unet.in_channels, height // 8, width // 8)
        noise = pipe.generate_noise(
            latent_shape,
            seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype
        )
        return {"noise": noise}
|
||||||
|
|
||||||
|
|
||||||
|
class SDXLUnit_InputImageEmbedder(PipelineUnit):
    """Encodes an optional input image into latents and noises it (img2img path)."""

    def __init__(self):
        super().__init__(
            input_params=("input_image", "noise"),
            output_params=("latents", "input_latents"),
            onload_model_names=("vae",),
        )

    def process(self, pipe: StableDiffusionXLPipeline, input_image, noise):
        # Pure text-to-image: start denoising directly from the sampled noise.
        if input_image is None:
            return {"latents": noise}
        pipe.load_models_to_device(self.onload_model_names)
        image_tensor = pipe.preprocess_image(input_image)
        input_latents = pipe.vae.encode(image_tensor).sample() * pipe.vae.scaling_factor
        # Noise the clean latents to the first timestep of the schedule.
        latents = pipe.scheduler.add_noise(input_latents, noise, pipe.scheduler.timesteps[0])
        if pipe.scheduler.training:
            # Training also needs the clean latents as the regression target.
            return {"latents": latents, "input_latents": input_latents}
        return {"latents": latents}
|
||||||
|
|
||||||
|
|
||||||
|
class SDXLUnit_AddTimeIdsComputer(PipelineUnit):
    """Builds the SDXL micro-conditioning time ids (original size, crop offset, target size)."""

    def __init__(self):
        super().__init__(
            input_params=("height", "width"),
            output_params=("add_time_ids",),
        )

    def _get_add_time_ids(self, pipe, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim):
        """Assemble the six conditioning scalars and validate their embedded width
        against the UNet's add_embedding input size. Raises ValueError on mismatch."""
        time_ids = list(original_size + crops_coords_top_left + target_size)
        expected_add_embed_dim = pipe.unet.add_embedding.linear_1.in_features
        per_id_embed_dim = pipe.unet.add_time_proj.num_channels
        produced_add_embed_dim = per_id_embed_dim * len(time_ids) + text_encoder_projection_dim
        if expected_add_embed_dim != produced_add_embed_dim:
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, "
                f"but a vector of {produced_add_embed_dim} was created."
            )
        return torch.tensor([time_ids], dtype=dtype, device=pipe.device)

    def process(self, pipe: StableDiffusionXLPipeline, height, width):
        # Default SDXL conditioning: original size == target size, no crop offset.
        size = (height, width)
        add_time_ids = self._get_add_time_ids(
            pipe, size, (0, 0), size,
            dtype=pipe.torch_dtype,
            text_encoder_projection_dim=pipe.text_encoder_2.config.projection_dim,
        )
        return {"add_time_ids": add_time_ids}
|
||||||
|
|
||||||
|
|
||||||
|
def model_fn_stable_diffusion_xl(
    unet: SDXLUNet2DConditionModel,
    latents=None,
    timestep=None,
    prompt_embeds=None,
    pooled_prompt_embeds=None,
    add_time_ids=None,
    cross_attention_kwargs=None,
    timestep_cond=None,
    **kwargs,
):
    """Run the SDXL UNet once and return the predicted noise.

    SDXL micro-conditioning (pooled text embeds + size/crop time ids) is
    passed through added_cond_kwargs.
    """
    conditioning = {
        "text_embeds": pooled_prompt_embeds,
        "time_ids": add_time_ids,
    }
    outputs = unet(
        latents,
        timestep,
        encoder_hidden_states=prompt_embeds,
        added_cond_kwargs=conditioning,
        cross_attention_kwargs=cross_attention_kwargs,
        timestep_cond=timestep_cond,
        return_dict=False,
    )
    # return_dict=False yields a tuple; the noise prediction is the first element.
    return outputs[0]
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
def JoyAIImageTextEncoderStateDictConverter(state_dict):
    """Convert HuggingFace Qwen3VL checkpoint keys to DiffSynth wrapper keys.

    Mapping (checkpoint -> wrapper):
    - lm_head.weight -> model.lm_head.weight
    - model.language_model.* -> model.model.language_model.*
    - model.visual.* -> model.model.visual.*
    Any other key is kept unchanged.
    """
    converted = {}
    for key, value in state_dict.items():
        if key == "lm_head.weight":
            converted["model.lm_head.weight"] = value
        elif key.startswith(("model.language_model.", "model.visual.")):
            # Insert one extra "model." level to match the wrapper hierarchy.
            converted["model." + key] = value
        else:
            converted[key] = value
    return converted
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
def SDTextEncoderStateDictConverter(state_dict):
    """Prefix CLIP "text_model." keys with "model." and drop everything else.

    Non-"text_model." keys and position_ids buffers are filtered out.
    """
    return {
        "model." + key: state_dict[key]
        for key in state_dict
        if key.startswith("text_model.") and "position_ids" not in key
    }
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
def SDVAEStateDictConverter(state_dict):
    """Rename legacy VAE attention keys to the to_q/to_k/to_v/to_out.0 naming.

    Keys without a legacy attention segment pass through unchanged. Only the
    first matching segment (checked in order) is replaced per key.
    """
    segment_renames = (
        (".query.", ".to_q."),
        (".key.", ".to_k."),
        (".value.", ".to_v."),
        (".proj_attn.", ".to_out.0."),
    )
    converted = {}
    for key, value in state_dict.items():
        for legacy, modern in segment_renames:
            if legacy in key:
                key = key.replace(legacy, modern)
                break
        converted[key] = value
    return converted
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
def SDXLTextEncoder2StateDictConverter(state_dict):
    """Prefix CLIP-bigG keys with "model.", dropping position_ids buffers.

    fp16 tensors are upcast to fp32; other dtypes pass through unchanged.
    Keys other than text_projection.weight / text_model.* are discarded.
    """
    def _upcast(tensor):
        # Promote fp16 checkpoints to fp32; leave other dtypes as-is.
        return tensor.float() if tensor.dtype == torch.float16 else tensor

    converted = {}
    for key, tensor in state_dict.items():
        if key == "text_projection.weight":
            converted["model.text_projection.weight"] = _upcast(tensor)
        elif key.startswith("text_model.") and "position_ids" not in key:
            converted["model." + key] = _upcast(tensor)
    return converted
|
||||||
154
docs/en/Model_Details/JoyAI-Image.md
Normal file
154
docs/en/Model_Details/JoyAI-Image.md
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
# JoyAI-Image
|
||||||
|
|
||||||
|
JoyAI-Image is a unified multi-modal foundation model open-sourced by JD.com, supporting image understanding, text-to-image generation, and instruction-guided image editing.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
Before performing model inference and training, please install DiffSynth-Studio first.
|
||||||
|
|
||||||
|
```shell
|
||||||
|
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||||
|
cd DiffSynth-Studio
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
For more information on installation, please refer to [Setup Dependencies](../Pipeline_Usage/Setup.md).
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
Running the following code will load the [jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit) model for inference. With VRAM management enabled, the framework automatically controls parameter loading based on available VRAM, requiring a minimum of 4GB of VRAM.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
|
|
||||||
|
# Download dataset
|
||||||
|
dataset_snapshot_download(
|
||||||
|
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
|
||||||
|
local_dir="data/diffsynth_example_dataset",
|
||||||
|
allow_file_pattern="joyai_image/JoyAI-Image-Edit/*"
|
||||||
|
)
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.bfloat16,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.bfloat16,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.bfloat16,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
|
||||||
|
pipe = JoyAIImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config),
|
||||||
|
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config),
|
||||||
|
],
|
||||||
|
processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use first sample from dataset
|
||||||
|
dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit"
|
||||||
|
prompt = "将裙子改为粉色"
|
||||||
|
edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB")
|
||||||
|
|
||||||
|
output = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
edit_image=edit_image,
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
seed=0,
|
||||||
|
num_inference_steps=30,
|
||||||
|
cfg_scale=5.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
output.save("output_joyai_edit_low_vram.png")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Model Overview
|
||||||
|
|
||||||
|
|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|
||||||
|
|-|-|-|-|-|-|-|
|
||||||
|
|[jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_inference/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py)|
|
||||||
|
|
||||||
|
## Model Inference
|
||||||
|
|
||||||
|
The model is loaded via `JoyAIImagePipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models) for details.
|
||||||
|
|
||||||
|
The input parameters for `JoyAIImagePipeline` inference include:
|
||||||
|
|
||||||
|
* `prompt`: Text prompt describing the desired image editing effect.
|
||||||
|
* `negative_prompt`: Negative prompt specifying what should not appear in the result, defaults to empty string.
|
||||||
|
* `cfg_scale`: Classifier-free guidance scale factor, defaults to 5.0. Higher values make the output more closely follow the prompt.
|
||||||
|
* `edit_image`: Image to be edited.
|
||||||
|
* `denoising_strength`: Denoising strength controlling how much the input image is repainted, defaults to 1.0.
|
||||||
|
* `height`: Height of the output image, defaults to 1024. Must be divisible by 16.
|
||||||
|
* `width`: Width of the output image, defaults to 1024. Must be divisible by 16.
|
||||||
|
* `seed`: Random seed for reproducibility. Set to `None` for random seed.
|
||||||
|
* `max_sequence_length`: Maximum sequence length for the text encoder, defaults to 4096.
|
||||||
|
* `num_inference_steps`: Number of inference steps, defaults to 30. More steps typically yield better quality.
|
||||||
|
* `tiled`: Whether to enable tiling for reduced VRAM usage, defaults to False.
|
||||||
|
* `tile_size`: Tile size, defaults to (30, 52).
|
||||||
|
* `tile_stride`: Tile stride, defaults to (15, 26).
|
||||||
|
* `shift`: Shift parameter for the scheduler, controlling the Flow Match scheduling curve, defaults to 4.0.
|
||||||
|
* `progress_bar_cmd`: Progress bar display mode, defaults to tqdm.
|
||||||
|
|
||||||
|
## Model Training
|
||||||
|
|
||||||
|
Models in the joyai_image series are trained uniformly via `examples/joyai_image/model_training/train.py`. The script parameters include:
|
||||||
|
|
||||||
|
* General Training Parameters
|
||||||
|
* Dataset Configuration
|
||||||
|
* `--dataset_base_path`: Root directory of the dataset.
|
||||||
|
* `--dataset_metadata_path`: Path to the dataset metadata file.
|
||||||
|
* `--dataset_repeat`: Number of dataset repeats per epoch.
|
||||||
|
* `--dataset_num_workers`: Number of processes per DataLoader.
|
||||||
|
* `--data_file_keys`: Field names to load from metadata, typically paths to image or video files, separated by `,`.
|
||||||
|
* Model Loading Configuration
|
||||||
|
* `--model_paths`: Paths to load models from, in JSON format.
|
||||||
|
* `--model_id_with_origin_paths`: Model IDs with original paths, separated by commas.
|
||||||
|
* `--extra_inputs`: Additional input parameters required by the model Pipeline, separated by `,`.
|
||||||
|
* `--fp8_models`: Models to load in FP8 format, currently only supported for models whose parameters are not updated by gradients.
|
||||||
|
* Basic Training Configuration
|
||||||
|
* `--learning_rate`: Learning rate.
|
||||||
|
* `--num_epochs`: Number of epochs.
|
||||||
|
* `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.
|
||||||
|
* `--find_unused_parameters`: Whether unused parameters exist in DDP training.
|
||||||
|
* `--weight_decay`: Weight decay magnitude.
|
||||||
|
* `--task`: Training task, defaults to `sft`.
|
||||||
|
* Output Configuration
|
||||||
|
* `--output_path`: Path to save the model.
|
||||||
|
* `--remove_prefix_in_ckpt`: Remove prefix in the model's state dict.
|
||||||
|
* `--save_steps`: Interval in training steps to save the model.
|
||||||
|
* LoRA Configuration
|
||||||
|
* `--lora_base_model`: Which model to add LoRA to.
|
||||||
|
* `--lora_target_modules`: Which layers to add LoRA to.
|
||||||
|
* `--lora_rank`: Rank of LoRA.
|
||||||
|
* `--lora_checkpoint`: Path to LoRA checkpoint.
|
||||||
|
* `--preset_lora_path`: Path to preset LoRA checkpoint for LoRA differential training.
|
||||||
|
* `--preset_lora_model`: Which model to integrate preset LoRA into, e.g., `dit`.
|
||||||
|
* Gradient Configuration
|
||||||
|
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
|
||||||
|
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
|
||||||
|
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
|
||||||
|
* Resolution Configuration
|
||||||
|
* `--height`: Height of the image/video. Leave empty to enable dynamic resolution.
|
||||||
|
* `--width`: Width of the image/video. Leave empty to enable dynamic resolution.
|
||||||
|
* `--max_pixels`: Maximum pixel area, images larger than this will be scaled down during dynamic resolution.
|
||||||
|
* `--num_frames`: Number of frames for video (video generation models only).
|
||||||
|
* JoyAI-Image Specific Parameters
|
||||||
|
* `--processor_path`: Path to the processor for processing text and image encoder inputs.
|
||||||
|
* `--initialize_model_on_cpu`: Whether to initialize models on CPU. By default, models are initialized on the accelerator device.
|
||||||
|
|
||||||
|
```shell
|
||||||
|
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
|
||||||
|
```
|
||||||
|
|
||||||
|
We provide recommended training scripts for each model, please refer to the table in "Model Overview" above. For guidance on writing model training scripts, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see [Training Framework Overview](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).
|
||||||
141
docs/en/Model_Details/Stable-Diffusion-XL.md
Normal file
141
docs/en/Model_Details/Stable-Diffusion-XL.md
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
# Stable Diffusion XL
|
||||||
|
|
||||||
|
Stable Diffusion XL (SDXL) is an open-source diffusion-based text-to-image generation model developed by Stability AI, supporting 1024x1024 resolution high-quality text-to-image generation with a dual text encoder (CLIP-L + CLIP-bigG) architecture.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
Before performing model inference and training, please install DiffSynth-Studio first.
|
||||||
|
|
||||||
|
```shell
|
||||||
|
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||||
|
cd DiffSynth-Studio
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
For more information on installation, please refer to [Setup Dependencies](../Pipeline_Usage/Setup.md).
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
Running the following code will quickly load the [stabilityai/stable-diffusion-xl-base-1.0](https://www.modelscope.cn/models/stabilityai/stable-diffusion-xl-base-1.0) model for inference. With VRAM management enabled, the framework automatically controls parameter loading based on available VRAM, requiring a minimum of 6GB of VRAM.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from diffsynth.core import ModelConfig
|
||||||
|
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.float32,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.float32,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float32,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.float32,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
|
||||||
|
tokenizer_2_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt="a photo of an astronaut riding a horse on mars",
|
||||||
|
negative_prompt="",
|
||||||
|
cfg_scale=5.0,
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
seed=42,
|
||||||
|
num_inference_steps=50,
|
||||||
|
)
|
||||||
|
image.save("image.jpg")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Model Overview
|
||||||
|
|
||||||
|
|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|
||||||
|
|-|-|-|-|-|-|-|
|
||||||
|
|[stabilityai/stable-diffusion-xl-base-1.0](https://www.modelscope.cn/models/stabilityai/stable-diffusion-xl-base-1.0)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_inference/stable-diffusion-xl-base-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_inference_low_vram/stable-diffusion-xl-base-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/full/stable-diffusion-xl-base-1.0.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/validate_full/stable-diffusion-xl-base-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/lora/stable-diffusion-xl-base-1.0.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/validate_lora/stable-diffusion-xl-base-1.0.py)|
|
||||||
|
|
||||||
|
## Model Inference
|
||||||
|
|
||||||
|
The model is loaded via `StableDiffusionXLPipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models) for details.
|
||||||
|
|
||||||
|
The input parameters for `StableDiffusionXLPipeline` inference include:
|
||||||
|
|
||||||
|
* `prompt`: Text prompt.
|
||||||
|
* `negative_prompt`: Negative prompt, defaults to an empty string.
|
||||||
|
* `cfg_scale`: Classifier-Free Guidance scale factor, default 5.0.
|
||||||
|
* `height`: Output image height, default 1024.
|
||||||
|
* `width`: Output image width, default 1024.
|
||||||
|
* `seed`: Random seed, defaults to a random value if not set.
|
||||||
|
* `rand_device`: Noise generation device, defaults to "cpu".
|
||||||
|
* `num_inference_steps`: Number of inference steps, default 50.
|
||||||
|
* `guidance_rescale`: Guidance rescale factor, default 0.0.
|
||||||
|
* `progress_bar_cmd`: Progress bar callback function.
|
||||||
|
|
||||||
|
> `StableDiffusionXLPipeline` requires dual tokenizer configurations (`tokenizer_config` and `tokenizer_2_config`), corresponding to the CLIP-L and CLIP-bigG text encoders.
|
||||||
|
|
||||||
|
## Model Training
|
||||||
|
|
||||||
|
Models in the stable_diffusion_xl series are trained via `examples/stable_diffusion_xl/model_training/train.py`. The script parameters include:
|
||||||
|
|
||||||
|
* General Training Parameters
|
||||||
|
* Dataset Configuration
|
||||||
|
* `--dataset_base_path`: Root directory of the dataset.
|
||||||
|
* `--dataset_metadata_path`: Path to the dataset metadata file.
|
||||||
|
* `--dataset_repeat`: Number of dataset repeats per epoch.
|
||||||
|
* `--dataset_num_workers`: Number of processes per DataLoader.
|
||||||
|
* `--data_file_keys`: Field names to load from metadata, typically paths to image or video files, separated by `,`.
|
||||||
|
* Model Loading Configuration
|
||||||
|
* `--model_paths`: Paths to load models from, in JSON format.
|
||||||
|
* `--model_id_with_origin_paths`: Model IDs with original paths, separated by commas.
|
||||||
|
* `--extra_inputs`: Additional input parameters required by the model Pipeline, separated by `,`.
|
||||||
|
* `--fp8_models`: Models to load in FP8 format, currently only supported for models whose parameters are not updated by gradients.
|
||||||
|
* Basic Training Configuration
|
||||||
|
* `--learning_rate`: Learning rate.
|
||||||
|
* `--num_epochs`: Number of epochs.
|
||||||
|
* `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.
|
||||||
|
* `--find_unused_parameters`: Whether unused parameters exist in DDP training.
|
||||||
|
* `--weight_decay`: Weight decay magnitude.
|
||||||
|
* `--task`: Training task, defaults to `sft`.
|
||||||
|
* Output Configuration
|
||||||
|
* `--output_path`: Path to save the model.
|
||||||
|
* `--remove_prefix_in_ckpt`: Remove prefix in the model's state dict.
|
||||||
|
* `--save_steps`: Interval in training steps to save the model.
|
||||||
|
* LoRA Configuration
|
||||||
|
* `--lora_base_model`: Which model to add LoRA to.
|
||||||
|
* `--lora_target_modules`: Which layers to add LoRA to.
|
||||||
|
* `--lora_rank`: Rank of LoRA.
|
||||||
|
* `--lora_checkpoint`: Path to LoRA checkpoint.
|
||||||
|
* `--preset_lora_path`: Path to preset LoRA checkpoint for LoRA differential training.
|
||||||
|
* `--preset_lora_model`: Which model to integrate preset LoRA into, e.g., `dit`.
|
||||||
|
* Gradient Configuration
|
||||||
|
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
|
||||||
|
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
|
||||||
|
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
|
||||||
|
* Resolution Configuration
|
||||||
|
* `--height`: Height of the image/video. Leave empty to enable dynamic resolution.
|
||||||
|
* `--width`: Width of the image/video. Leave empty to enable dynamic resolution.
|
||||||
|
* `--max_pixels`: Maximum pixel area, images larger than this will be scaled down during dynamic resolution.
|
||||||
|
* `--num_frames`: Number of frames for video (video generation models only).
|
||||||
|
* Stable Diffusion XL Specific Parameters
|
||||||
|
* `--tokenizer_path`: Path to the first tokenizer.
|
||||||
|
* `--tokenizer_2_path`: Path to the second tokenizer, defaults to `stabilityai/stable-diffusion-xl-base-1.0:tokenizer_2/`.
|
||||||
|
|
||||||
|
Example dataset download:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion_xl/*" --local_dir ./data/diffsynth_example_dataset
|
||||||
|
```
|
||||||
|
|
||||||
|
[stable-diffusion-xl-base-1.0 training scripts](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/lora/stable-diffusion-xl-base-1.0.sh)
|
||||||
|
|
||||||
|
We provide recommended training scripts for each model, please refer to the table in "Model Overview" above. For guidance on writing model training scripts, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see [Training Framework Overview](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).
|
||||||
138
docs/en/Model_Details/Stable-Diffusion.md
Normal file
138
docs/en/Model_Details/Stable-Diffusion.md
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
# Stable Diffusion
|
||||||
|
|
||||||
|
Stable Diffusion is an open-source diffusion-based text-to-image generation model developed by Stability AI, supporting 512x512 resolution text-to-image generation.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
Before performing model inference and training, please install DiffSynth-Studio first.
|
||||||
|
|
||||||
|
```shell
|
||||||
|
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||||
|
cd DiffSynth-Studio
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
For more information on installation, please refer to [Setup Dependencies](../Pipeline_Usage/Setup.md).
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
Running the following code will quickly load the [AI-ModelScope/stable-diffusion-v1-5](https://www.modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5) model for inference. With VRAM management enabled, the framework automatically controls parameter loading based on available VRAM, requiring a minimum of 2GB of VRAM.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from diffsynth.core import ModelConfig
|
||||||
|
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.float32,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.float32,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float32,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.float32,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt="a photo of an astronaut riding a horse on mars, high quality, detailed",
|
||||||
|
negative_prompt="blurry, low quality, deformed",
|
||||||
|
cfg_scale=7.5,
|
||||||
|
height=512,
|
||||||
|
width=512,
|
||||||
|
seed=42,
|
||||||
|
rand_device="cuda",
|
||||||
|
num_inference_steps=50,
|
||||||
|
)
|
||||||
|
image.save("image.jpg")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Model Overview
|
||||||
|
|
||||||
|
|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|
||||||
|
|-|-|-|-|-|-|-|
|
||||||
|
|[AI-ModelScope/stable-diffusion-v1-5](https://www.modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_inference/stable-diffusion-v1-5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_inference_low_vram/stable-diffusion-v1-5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/full/stable-diffusion-v1-5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/validate_full/stable-diffusion-v1-5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/lora/stable-diffusion-v1-5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/validate_lora/stable-diffusion-v1-5.py)|
|
||||||
|
|
||||||
|
## Model Inference
|
||||||
|
|
||||||
|
The model is loaded via `StableDiffusionPipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models) for details.
|
||||||
|
|
||||||
|
The input parameters for `StableDiffusionPipeline` inference include:
|
||||||
|
|
||||||
|
* `prompt`: Text prompt.
|
||||||
|
* `negative_prompt`: Negative prompt, defaults to an empty string.
|
||||||
|
* `cfg_scale`: Classifier-Free Guidance scale factor, default 7.5.
|
||||||
|
* `height`: Output image height, default 512.
|
||||||
|
* `width`: Output image width, default 512.
|
||||||
|
* `seed`: Random seed, defaults to a random value if not set.
|
||||||
|
* `rand_device`: Noise generation device, defaults to "cpu".
|
||||||
|
* `num_inference_steps`: Number of inference steps, default 50.
|
||||||
|
* `eta`: DDIM scheduler eta parameter, default 0.0.
|
||||||
|
* `guidance_rescale`: Guidance rescale factor, default 0.0.
|
||||||
|
* `progress_bar_cmd`: Progress bar callback function.
|
||||||
|
|
||||||
|
## Model Training
|
||||||
|
|
||||||
|
Models in the stable_diffusion series are trained via `examples/stable_diffusion/model_training/train.py`. The script parameters include:
|
||||||
|
|
||||||
|
* General Training Parameters
|
||||||
|
* Dataset Configuration
|
||||||
|
* `--dataset_base_path`: Root directory of the dataset.
|
||||||
|
* `--dataset_metadata_path`: Path to the dataset metadata file.
|
||||||
|
* `--dataset_repeat`: Number of dataset repeats per epoch.
|
||||||
|
* `--dataset_num_workers`: Number of processes per DataLoader.
|
||||||
|
* `--data_file_keys`: Field names to load from metadata, typically paths to image or video files, separated by `,`.
|
||||||
|
* Model Loading Configuration
|
||||||
|
* `--model_paths`: Paths to load models from, in JSON format.
|
||||||
|
* `--model_id_with_origin_paths`: Model IDs with original paths, separated by commas.
|
||||||
|
* `--extra_inputs`: Additional input parameters required by the model Pipeline, separated by `,`.
|
||||||
|
* `--fp8_models`: Models to load in FP8 format, currently only supported for models whose parameters are not updated by gradients.
|
||||||
|
* Basic Training Configuration
|
||||||
|
* `--learning_rate`: Learning rate.
|
||||||
|
* `--num_epochs`: Number of epochs.
|
||||||
|
* `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.
|
||||||
|
* `--find_unused_parameters`: Whether unused parameters exist in DDP training.
|
||||||
|
* `--weight_decay`: Weight decay magnitude.
|
||||||
|
* `--task`: Training task, defaults to `sft`.
|
||||||
|
* Output Configuration
|
||||||
|
* `--output_path`: Path to save the model.
|
||||||
|
* `--remove_prefix_in_ckpt`: Remove prefix in the model's state dict.
|
||||||
|
* `--save_steps`: Interval in training steps to save the model.
|
||||||
|
* LoRA Configuration
|
||||||
|
* `--lora_base_model`: Which model to add LoRA to.
|
||||||
|
* `--lora_target_modules`: Which layers to add LoRA to.
|
||||||
|
* `--lora_rank`: Rank of LoRA.
|
||||||
|
* `--lora_checkpoint`: Path to LoRA checkpoint.
|
||||||
|
* `--preset_lora_path`: Path to preset LoRA checkpoint for LoRA differential training.
|
||||||
|
* `--preset_lora_model`: Which model to integrate preset LoRA into, e.g., `dit`.
|
||||||
|
* Gradient Configuration
|
||||||
|
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
|
||||||
|
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
|
||||||
|
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
|
||||||
|
* Resolution Configuration
|
||||||
|
* `--height`: Height of the image/video. Leave empty to enable dynamic resolution.
|
||||||
|
* `--width`: Width of the image/video. Leave empty to enable dynamic resolution.
|
||||||
|
* `--max_pixels`: Maximum pixel area, images larger than this will be scaled down during dynamic resolution.
|
||||||
|
* `--num_frames`: Number of frames for video (video generation models only).
|
||||||
|
* Stable Diffusion Specific Parameters
|
||||||
|
* `--tokenizer_path`: Tokenizer path, defaults to `AI-ModelScope/stable-diffusion-v1-5:tokenizer/`.
|
||||||
|
|
||||||
|
Example dataset download:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion/*" --local_dir ./data/diffsynth_example_dataset
|
||||||
|
```
|
||||||
|
|
||||||
|
[stable-diffusion-v1-5 training scripts](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/lora/stable-diffusion-v1-5.sh)
|
||||||
|
|
||||||
|
We provide recommended training scripts for each model; please refer to the table in "Model Overview" above. For guidance on writing model training scripts, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see [Training Framework Overview](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).
|
||||||
@@ -31,6 +31,9 @@ Welcome to DiffSynth-Studio's Documentation
|
|||||||
Model_Details/Anima
|
Model_Details/Anima
|
||||||
Model_Details/LTX-2
|
Model_Details/LTX-2
|
||||||
Model_Details/ERNIE-Image
|
Model_Details/ERNIE-Image
|
||||||
|
Model_Details/JoyAI-Image
|
||||||
|
Model_Details/Stable-Diffusion
|
||||||
|
Model_Details/Stable-Diffusion-XL
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:maxdepth: 2
|
:maxdepth: 2
|
||||||
|
|||||||
154
docs/zh/Model_Details/JoyAI-Image.md
Normal file
154
docs/zh/Model_Details/JoyAI-Image.md
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
# JoyAI-Image
|
||||||
|
|
||||||
|
JoyAI-Image 是京东开源的统一多模态基础模型,支持图像理解、文生图生成和指令引导的图像编辑。
|
||||||
|
|
||||||
|
## 安装
|
||||||
|
|
||||||
|
在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。
|
||||||
|
|
||||||
|
```shell
|
||||||
|
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||||
|
cd DiffSynth-Studio
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
更多关于安装的信息,请参考[安装依赖](../Pipeline_Usage/Setup.md)。
|
||||||
|
|
||||||
|
## 快速开始
|
||||||
|
|
||||||
|
运行以下代码可以快速加载 [jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 4G 显存即可运行。
|
||||||
|
|
||||||
|
```python
|
||||||
|
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
|
|
||||||
|
# Download dataset
|
||||||
|
dataset_snapshot_download(
|
||||||
|
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
|
||||||
|
local_dir="data/diffsynth_example_dataset",
|
||||||
|
allow_file_pattern="joyai_image/JoyAI-Image-Edit/*"
|
||||||
|
)
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.bfloat16,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.bfloat16,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.bfloat16,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
|
||||||
|
pipe = JoyAIImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config),
|
||||||
|
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config),
|
||||||
|
],
|
||||||
|
processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use first sample from dataset
|
||||||
|
dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit"
|
||||||
|
prompt = "将裙子改为粉色"
|
||||||
|
edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB")
|
||||||
|
|
||||||
|
output = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
edit_image=edit_image,
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
seed=0,
|
||||||
|
num_inference_steps=30,
|
||||||
|
cfg_scale=5.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
output.save("output_joyai_edit_low_vram.png")
|
||||||
|
```
|
||||||
|
|
||||||
|
## 模型总览
|
||||||
|
|
||||||
|
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||||
|
|-|-|-|-|-|-|-|
|
||||||
|
|[jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_inference/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py)|
|
||||||
|
|
||||||
|
## 模型推理
|
||||||
|
|
||||||
|
模型通过 `JoyAIImagePipeline.from_pretrained` 加载,详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。
|
||||||
|
|
||||||
|
`JoyAIImagePipeline` 推理的输入参数包括:
|
||||||
|
|
||||||
|
* `prompt`: 文本提示词,用于描述期望的图像编辑效果。
|
||||||
|
* `negative_prompt`: 负向提示词,指定不希望出现在结果中的内容,默认为空字符串。
|
||||||
|
* `cfg_scale`: 分类器自由引导的缩放系数,默认为 5.0。值越大,生成结果越贴近 prompt 描述。
|
||||||
|
* `edit_image`: 待编辑的单张图像。
|
||||||
|
* `denoising_strength`: 降噪强度,控制输入图像被重绘的程度,默认为 1.0。
|
||||||
|
* `height`: 输出图像的高度,默认为 1024。需能被 16 整除。
|
||||||
|
* `width`: 输出图像的宽度,默认为 1024。需能被 16 整除。
|
||||||
|
* `seed`: 随机种子,用于控制生成的可复现性。设为 `None` 时使用随机种子。
|
||||||
|
* `max_sequence_length`: 文本编码器处理的最大序列长度,默认为 4096。
|
||||||
|
* `num_inference_steps`: 推理步数,默认为 30。步数越多,生成质量通常越好。
|
||||||
|
* `tiled`: 是否启用分块处理,用于降低显存占用,默认为 False。
|
||||||
|
* `tile_size`: 分块大小,默认为 (30, 52)。
|
||||||
|
* `tile_stride`: 分块步幅,默认为 (15, 26)。
|
||||||
|
* `shift`: 调度器的 shift 参数,用于控制 Flow Match 的调度曲线,默认为 4.0。
|
||||||
|
* `progress_bar_cmd`: 进度条显示方式,默认为 tqdm。
|
||||||
|
|
||||||
|
## 模型训练
|
||||||
|
|
||||||
|
joyai_image 系列模型统一通过 `examples/joyai_image/model_training/train.py` 进行训练,脚本的参数包括:
|
||||||
|
|
||||||
|
* 通用训练参数
|
||||||
|
* 数据集基础配置
|
||||||
|
* `--dataset_base_path`: 数据集的根目录。
|
||||||
|
* `--dataset_metadata_path`: 数据集的元数据文件路径。
|
||||||
|
* `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
|
||||||
|
* `--dataset_num_workers`: 每个 Dataloader 的进程数量。
|
||||||
|
* `--data_file_keys`: 元数据中需要加载的字段名称,通常是图像或视频文件的路径,以 `,` 分隔。
|
||||||
|
* 模型加载配置
|
||||||
|
* `--model_paths`: 要加载的模型路径。JSON 格式。
|
||||||
|
* `--model_id_with_origin_paths`: 带原始路径的模型 ID。用逗号分隔。
|
||||||
|
* `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,以 `,` 分隔。
|
||||||
|
* `--fp8_models`: 以 FP8 格式加载的模型,目前仅支持参数不被梯度更新的模型。
|
||||||
|
* 训练基础配置
|
||||||
|
* `--learning_rate`: 学习率。
|
||||||
|
* `--num_epochs`: 轮数(Epoch)。
|
||||||
|
* `--trainable_models`: 可训练的模型,例如 `dit`、`vae`、`text_encoder`。
|
||||||
|
* `--find_unused_parameters`: DDP 训练中是否存在未使用的参数。
|
||||||
|
* `--weight_decay`: 权重衰减大小。
|
||||||
|
* `--task`: 训练任务,默认为 `sft`。
|
||||||
|
* 输出配置
|
||||||
|
* `--output_path`: 模型保存路径。
|
||||||
|
* `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。
|
||||||
|
* `--save_steps`: 保存模型的训练步数间隔。
|
||||||
|
* LoRA 配置
|
||||||
|
* `--lora_base_model`: LoRA 添加到哪个模型上。
|
||||||
|
* `--lora_target_modules`: LoRA 添加到哪些层上。
|
||||||
|
* `--lora_rank`: LoRA 的秩(Rank)。
|
||||||
|
* `--lora_checkpoint`: LoRA 检查点的路径。
|
||||||
|
* `--preset_lora_path`: 预置 LoRA 检查点路径,用于 LoRA 差分训练。
|
||||||
|
* `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`。
|
||||||
|
* 梯度配置
|
||||||
|
* `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。
|
||||||
|
* `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。
|
||||||
|
* `--gradient_accumulation_steps`: 梯度累积步数。
|
||||||
|
* 分辨率配置
|
||||||
|
* `--height`: 图像/视频的高度。留空启用动态分辨率。
|
||||||
|
* `--width`: 图像/视频的宽度。留空启用动态分辨率。
|
||||||
|
* `--max_pixels`: 最大像素面积,动态分辨率时大于此值的图片会被缩小。
|
||||||
|
* `--num_frames`: 视频的帧数(仅视频生成模型)。
|
||||||
|
* JoyAI-Image 专有参数
|
||||||
|
* `--processor_path`: Processor 路径,用于处理文本和图像的编码器输入。
|
||||||
|
* `--initialize_model_on_cpu`: 是否在 CPU 上初始化模型,默认在加速设备上初始化。
|
||||||
|
|
||||||
|
```shell
|
||||||
|
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "joyai_image/*" --local_dir ./data/diffsynth_example_dataset
|
||||||
|
```
|
||||||
|
|
||||||
|
关于如何编写模型训练脚本,请参考[模型训练](../Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。
|
||||||
141
docs/zh/Model_Details/Stable-Diffusion-XL.md
Normal file
141
docs/zh/Model_Details/Stable-Diffusion-XL.md
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
# Stable Diffusion XL
|
||||||
|
|
||||||
|
Stable Diffusion XL (SDXL) 是由 Stability AI 开发的开源扩散式文本到图像生成模型,支持 1024x1024 分辨率的高质量文本到图像生成,采用双文本编码器(CLIP-L + CLIP-bigG)架构。
|
||||||
|
|
||||||
|
## 安装
|
||||||
|
|
||||||
|
在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。
|
||||||
|
|
||||||
|
```shell
|
||||||
|
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||||
|
cd DiffSynth-Studio
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
更多关于安装的信息,请参考[安装依赖](../Pipeline_Usage/Setup.md)。
|
||||||
|
|
||||||
|
## 快速开始
|
||||||
|
|
||||||
|
运行以下代码可以快速加载 [stabilityai/stable-diffusion-xl-base-1.0](https://www.modelscope.cn/models/stabilityai/stable-diffusion-xl-base-1.0) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 6GB 显存即可运行。
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from diffsynth.core import ModelConfig
|
||||||
|
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.float32,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.float32,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float32,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.float32,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
|
||||||
|
tokenizer_2_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt="a photo of an astronaut riding a horse on mars",
|
||||||
|
negative_prompt="",
|
||||||
|
cfg_scale=5.0,
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
seed=42,
|
||||||
|
num_inference_steps=50,
|
||||||
|
)
|
||||||
|
image.save("image.jpg")
|
||||||
|
```
|
||||||
|
|
||||||
|
## 模型总览
|
||||||
|
|
||||||
|
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||||
|
|-|-|-|-|-|-|-|
|
||||||
|
|[stabilityai/stable-diffusion-xl-base-1.0](https://www.modelscope.cn/models/stabilityai/stable-diffusion-xl-base-1.0)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_inference/stable-diffusion-xl-base-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_inference_low_vram/stable-diffusion-xl-base-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/full/stable-diffusion-xl-base-1.0.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/validate_full/stable-diffusion-xl-base-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/lora/stable-diffusion-xl-base-1.0.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/validate_lora/stable-diffusion-xl-base-1.0.py)|
|
||||||
|
|
||||||
|
## 模型推理
|
||||||
|
|
||||||
|
模型通过 `StableDiffusionXLPipeline.from_pretrained` 加载,详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。
|
||||||
|
|
||||||
|
`StableDiffusionXLPipeline` 的推理输入参数包括:
|
||||||
|
|
||||||
|
* `prompt`: 文本提示词。
|
||||||
|
* `negative_prompt`: 负面提示词,默认为空字符串。
|
||||||
|
* `cfg_scale`: Classifier-Free Guidance 缩放系数,默认 5.0。
|
||||||
|
* `height`: 输出图像高度,默认 1024。
|
||||||
|
* `width`: 输出图像宽度,默认 1024。
|
||||||
|
* `seed`: 随机种子,默认不设置时使用随机种子。
|
||||||
|
* `rand_device`: 噪声生成设备,默认 "cpu"。
|
||||||
|
* `num_inference_steps`: 推理步数,默认 50。
|
||||||
|
* `guidance_rescale`: Guidance rescale 系数,默认 0.0。
|
||||||
|
* `progress_bar_cmd`: 进度条回调函数。
|
||||||
|
|
||||||
|
> `StableDiffusionXLPipeline` 需要双 tokenizer 配置(`tokenizer_config` 和 `tokenizer_2_config`),分别对应 CLIP-L 和 CLIP-bigG 文本编码器。
|
||||||
|
|
||||||
|
## 模型训练
|
||||||
|
|
||||||
|
stable_diffusion_xl 系列模型通过 `examples/stable_diffusion_xl/model_training/train.py` 进行训练,脚本的参数包括:
|
||||||
|
|
||||||
|
* 通用训练参数
|
||||||
|
* 数据集基础配置
|
||||||
|
* `--dataset_base_path`: 数据集的根目录。
|
||||||
|
* `--dataset_metadata_path`: 数据集的元数据文件路径。
|
||||||
|
* `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
|
||||||
|
* `--dataset_num_workers`: 每个 Dataloader 的进程数量。
|
||||||
|
* `--data_file_keys`: 元数据中需要加载的字段名称,通常是图像或视频文件的路径,以 `,` 分隔。
|
||||||
|
* 模型加载配置
|
||||||
|
* `--model_paths`: 要加载的模型路径。JSON 格式。
|
||||||
|
* `--model_id_with_origin_paths`: 带原始路径的模型 ID。用逗号分隔。
|
||||||
|
* `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,以 `,` 分隔。
|
||||||
|
* `--fp8_models`: 以 FP8 格式加载的模型,目前仅支持参数不被梯度更新的模型。
|
||||||
|
* 训练基础配置
|
||||||
|
* `--learning_rate`: 学习率。
|
||||||
|
* `--num_epochs`: 轮数(Epoch)。
|
||||||
|
* `--trainable_models`: 可训练的模型,例如 `dit`、`vae`、`text_encoder`。
|
||||||
|
* `--find_unused_parameters`: DDP 训练中是否存在未使用的参数。
|
||||||
|
* `--weight_decay`: 权重衰减大小。
|
||||||
|
* `--task`: 训练任务,默认为 `sft`。
|
||||||
|
* 输出配置
|
||||||
|
* `--output_path`: 模型保存路径。
|
||||||
|
* `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。
|
||||||
|
* `--save_steps`: 保存模型的训练步数间隔。
|
||||||
|
* LoRA 配置
|
||||||
|
* `--lora_base_model`: LoRA 添加到哪个模型上。
|
||||||
|
* `--lora_target_modules`: LoRA 添加到哪些层上。
|
||||||
|
* `--lora_rank`: LoRA 的秩(Rank)。
|
||||||
|
* `--lora_checkpoint`: LoRA 检查点的路径。
|
||||||
|
* `--preset_lora_path`: 预置 LoRA 检查点路径,用于 LoRA 差分训练。
|
||||||
|
* `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`。
|
||||||
|
* 梯度配置
|
||||||
|
* `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。
|
||||||
|
* `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。
|
||||||
|
* `--gradient_accumulation_steps`: 梯度累积步数。
|
||||||
|
* 分辨率配置
|
||||||
|
* `--height`: 图像/视频的高度。留空启用动态分辨率。
|
||||||
|
* `--width`: 图像/视频的宽度。留空启用动态分辨率。
|
||||||
|
* `--max_pixels`: 最大像素面积,动态分辨率时大于此值的图片会被缩小。
|
||||||
|
* `--num_frames`: 视频的帧数(仅视频生成模型)。
|
||||||
|
* Stable Diffusion XL 专有参数
|
||||||
|
* `--tokenizer_path`: 第一个 Tokenizer 路径。
|
||||||
|
* `--tokenizer_2_path`: 第二个 Tokenizer 路径,默认为 `stabilityai/stable-diffusion-xl-base-1.0:tokenizer_2/`。
|
||||||
|
|
||||||
|
样例数据集下载:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion_xl/*" --local_dir ./data/diffsynth_example_dataset
|
||||||
|
```
|
||||||
|
|
||||||
|
[stable-diffusion-xl-base-1.0 训练脚本](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/lora/stable-diffusion-xl-base-1.0.sh)
|
||||||
|
|
||||||
|
我们为每个模型编写了推荐的训练脚本,请参考前文“模型总览”中的表格。关于如何编写模型训练脚本,请参考[模型训练](../Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。
|
||||||
138
docs/zh/Model_Details/Stable-Diffusion.md
Normal file
138
docs/zh/Model_Details/Stable-Diffusion.md
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
# Stable Diffusion
|
||||||
|
|
||||||
|
Stable Diffusion 是由 Stability AI 开发的开源扩散式文本到图像生成模型,支持 512x512 分辨率的文本到图像生成。
|
||||||
|
|
||||||
|
## 安装
|
||||||
|
|
||||||
|
在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。
|
||||||
|
|
||||||
|
```shell
|
||||||
|
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||||
|
cd DiffSynth-Studio
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
更多关于安装的信息,请参考[安装依赖](../Pipeline_Usage/Setup.md)。
|
||||||
|
|
||||||
|
## 快速开始
|
||||||
|
|
||||||
|
运行以下代码可以快速加载 [AI-ModelScope/stable-diffusion-v1-5](https://www.modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 2GB 显存即可运行。
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from diffsynth.core import ModelConfig
|
||||||
|
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.float32,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.float32,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float32,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.float32,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt="a photo of an astronaut riding a horse on mars, high quality, detailed",
|
||||||
|
negative_prompt="blurry, low quality, deformed",
|
||||||
|
cfg_scale=7.5,
|
||||||
|
height=512,
|
||||||
|
width=512,
|
||||||
|
seed=42,
|
||||||
|
rand_device="cuda",
|
||||||
|
num_inference_steps=50,
|
||||||
|
)
|
||||||
|
image.save("image.jpg")
|
||||||
|
```
|
||||||
|
|
||||||
|
## 模型总览
|
||||||
|
|
||||||
|
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||||
|
|-|-|-|-|-|-|-|
|
||||||
|
|[AI-ModelScope/stable-diffusion-v1-5](https://www.modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_inference/stable-diffusion-v1-5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_inference_low_vram/stable-diffusion-v1-5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/full/stable-diffusion-v1-5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/validate_full/stable-diffusion-v1-5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/lora/stable-diffusion-v1-5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/validate_lora/stable-diffusion-v1-5.py)|
|
||||||
|
|
||||||
|
## 模型推理
|
||||||
|
|
||||||
|
模型通过 `StableDiffusionPipeline.from_pretrained` 加载,详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。
|
||||||
|
|
||||||
|
`StableDiffusionPipeline` 的推理输入参数包括:
|
||||||
|
|
||||||
|
* `prompt`: 文本提示词。
|
||||||
|
* `negative_prompt`: 负面提示词,默认为空字符串。
|
||||||
|
* `cfg_scale`: Classifier-Free Guidance 缩放系数,默认 7.5。
|
||||||
|
* `height`: 输出图像高度,默认 512。
|
||||||
|
* `width`: 输出图像宽度,默认 512。
|
||||||
|
* `seed`: 随机种子,默认不设置时使用随机种子。
|
||||||
|
* `rand_device`: 噪声生成设备,默认 "cpu"。
|
||||||
|
* `num_inference_steps`: 推理步数,默认 50。
|
||||||
|
* `eta`: DDIM 调度器的 eta 参数,默认 0.0。
|
||||||
|
* `guidance_rescale`: Guidance rescale 系数,默认 0.0。
|
||||||
|
* `progress_bar_cmd`: 进度条回调函数。
|
||||||
|
|
||||||
|
## 模型训练
|
||||||
|
|
||||||
|
stable_diffusion 系列模型通过 `examples/stable_diffusion/model_training/train.py` 进行训练,脚本的参数包括:
|
||||||
|
|
||||||
|
* 通用训练参数
|
||||||
|
* 数据集基础配置
|
||||||
|
* `--dataset_base_path`: 数据集的根目录。
|
||||||
|
* `--dataset_metadata_path`: 数据集的元数据文件路径。
|
||||||
|
* `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
|
||||||
|
* `--dataset_num_workers`: 每个 Dataloader 的进程数量。
|
||||||
|
* `--data_file_keys`: 元数据中需要加载的字段名称,通常是图像或视频文件的路径,以 `,` 分隔。
|
||||||
|
* 模型加载配置
|
||||||
|
* `--model_paths`: 要加载的模型路径。JSON 格式。
|
||||||
|
* `--model_id_with_origin_paths`: 带原始路径的模型 ID。用逗号分隔。
|
||||||
|
* `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,以 `,` 分隔。
|
||||||
|
* `--fp8_models`: 以 FP8 格式加载的模型,目前仅支持参数不被梯度更新的模型。
|
||||||
|
* 训练基础配置
|
||||||
|
* `--learning_rate`: 学习率。
|
||||||
|
* `--num_epochs`: 轮数(Epoch)。
|
||||||
|
* `--trainable_models`: 可训练的模型,例如 `dit`、`vae`、`text_encoder`。
|
||||||
|
* `--find_unused_parameters`: DDP 训练中是否存在未使用的参数。
|
||||||
|
* `--weight_decay`: 权重衰减大小。
|
||||||
|
* `--task`: 训练任务,默认为 `sft`。
|
||||||
|
* 输出配置
|
||||||
|
* `--output_path`: 模型保存路径。
|
||||||
|
* `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。
|
||||||
|
* `--save_steps`: 保存模型的训练步数间隔。
|
||||||
|
* LoRA 配置
|
||||||
|
* `--lora_base_model`: LoRA 添加到哪个模型上。
|
||||||
|
* `--lora_target_modules`: LoRA 添加到哪些层上。
|
||||||
|
* `--lora_rank`: LoRA 的秩(Rank)。
|
||||||
|
* `--lora_checkpoint`: LoRA 检查点的路径。
|
||||||
|
* `--preset_lora_path`: 预置 LoRA 检查点路径,用于 LoRA 差分训练。
|
||||||
|
* `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`。
|
||||||
|
* 梯度配置
|
||||||
|
* `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。
|
||||||
|
* `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。
|
||||||
|
* `--gradient_accumulation_steps`: 梯度累积步数。
|
||||||
|
* 分辨率配置
|
||||||
|
* `--height`: 图像/视频的高度。留空启用动态分辨率。
|
||||||
|
* `--width`: 图像/视频的宽度。留空启用动态分辨率。
|
||||||
|
* `--max_pixels`: 最大像素面积,动态分辨率时大于此值的图片会被缩小。
|
||||||
|
* `--num_frames`: 视频的帧数(仅视频生成模型)。
|
||||||
|
* Stable Diffusion 专有参数
|
||||||
|
* `--tokenizer_path`: Tokenizer 路径,默认为 `AI-ModelScope/stable-diffusion-v1-5:tokenizer/`。
|
||||||
|
|
||||||
|
样例数据集下载:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion/*" --local_dir ./data/diffsynth_example_dataset
|
||||||
|
```
|
||||||
|
|
||||||
|
[stable-diffusion-v1-5 训练脚本](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/lora/stable-diffusion-v1-5.sh)
|
||||||
|
|
||||||
|
我们为每个模型编写了推荐的训练脚本,请参考前文“模型总览”中的表格。关于如何编写模型训练脚本,请参考[模型训练](../Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。
|
||||||
@@ -31,6 +31,9 @@
|
|||||||
Model_Details/Anima
|
Model_Details/Anima
|
||||||
Model_Details/LTX-2
|
Model_Details/LTX-2
|
||||||
Model_Details/ERNIE-Image
|
Model_Details/ERNIE-Image
|
||||||
|
Model_Details/JoyAI-Image
|
||||||
|
Model_Details/Stable-Diffusion
|
||||||
|
Model_Details/Stable-Diffusion-XL
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:maxdepth: 2
|
:maxdepth: 2
|
||||||
|
|||||||
39
examples/joyai_image/model_inference/JoyAI-Image-Edit.py
Normal file
39
examples/joyai_image/model_inference/JoyAI-Image-Edit.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
|
|
||||||
|
# Download dataset
|
||||||
|
dataset_snapshot_download(
|
||||||
|
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
|
||||||
|
local_dir="data/diffsynth_example_dataset",
|
||||||
|
allow_file_pattern="joyai_image/JoyAI-Image-Edit/*"
|
||||||
|
)
|
||||||
|
|
||||||
|
pipe = JoyAIImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth"),
|
||||||
|
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors"),
|
||||||
|
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth"),
|
||||||
|
],
|
||||||
|
processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use first sample from dataset
|
||||||
|
dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit"
|
||||||
|
prompt = "将裙子改为粉色"
|
||||||
|
edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB")
|
||||||
|
|
||||||
|
output = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
edit_image=edit_image,
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
seed=1,
|
||||||
|
num_inference_steps=30,
|
||||||
|
cfg_scale=5.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
output.save("output_joyai_edit.png")
|
||||||
@@ -0,0 +1,51 @@
|
|||||||
|
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
|
|
||||||
|
# Download dataset
|
||||||
|
dataset_snapshot_download(
|
||||||
|
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
|
||||||
|
local_dir="data/diffsynth_example_dataset",
|
||||||
|
allow_file_pattern="joyai_image/JoyAI-Image-Edit/*"
|
||||||
|
)
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.bfloat16,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.bfloat16,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.bfloat16,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
|
||||||
|
pipe = JoyAIImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config),
|
||||||
|
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config),
|
||||||
|
],
|
||||||
|
processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use first sample from dataset
|
||||||
|
dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit"
|
||||||
|
prompt = "将裙子改为粉色"
|
||||||
|
edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB")
|
||||||
|
|
||||||
|
output = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
edit_image=edit_image,
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
seed=0,
|
||||||
|
num_inference_steps=30,
|
||||||
|
cfg_scale=5.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
output.save("output_joyai_edit_low_vram.png")
|
||||||
35
examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh
Normal file
35
examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
# Dataset: data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/
|
||||||
|
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "joyai_image/JoyAI-Image-Edit/*" --local_dir ./data/diffsynth_example_dataset
|
||||||
|
|
||||||
|
accelerate launch examples/joyai_image/model_training/train.py \
|
||||||
|
--dataset_base_path "./data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit" \
|
||||||
|
--dataset_metadata_path "./data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/metadata.csv" \
|
||||||
|
--max_pixels 1048576 \
|
||||||
|
--dataset_repeat 1 \
|
||||||
|
--model_id_with_origin_paths "jd-opensource/JoyAI-Image-Edit:JoyAI-Image-Und/model*.safetensors,jd-opensource/JoyAI-Image-Edit:vae/Wan2.1_VAE.pth" \
|
||||||
|
--learning_rate 1e-5 \
|
||||||
|
--num_epochs 2 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
--output_path "./models/train/JoyAI-Image-Edit-full-cache" \
|
||||||
|
--use_gradient_checkpointing \
|
||||||
|
--find_unused_parameters \
|
||||||
|
--data_file_keys "image,edit_image" \
|
||||||
|
--extra_inputs "edit_image" \
|
||||||
|
--task "sft:data_process"
|
||||||
|
|
||||||
|
accelerate launch --config_file examples/joyai_image/model_training/full/accelerate_config_zero3.yaml \
|
||||||
|
examples/joyai_image/model_training/train.py \
|
||||||
|
--dataset_base_path "./models/train/JoyAI-Image-Edit-full-cache" \
|
||||||
|
--max_pixels 1048576 \
|
||||||
|
--dataset_repeat 50 \
|
||||||
|
--model_id_with_origin_paths "jd-opensource/JoyAI-Image-Edit:transformer/transformer.pth" \
|
||||||
|
--learning_rate 1e-5 \
|
||||||
|
--num_epochs 2 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
--output_path "./models/train/JoyAI-Image-Edit-full" \
|
||||||
|
--trainable_models "dit" \
|
||||||
|
--use_gradient_checkpointing \
|
||||||
|
--find_unused_parameters \
|
||||||
|
--data_file_keys "image,edit_image" \
|
||||||
|
--extra_inputs "edit_image" \
|
||||||
|
--task "sft:train"
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
compute_environment: LOCAL_MACHINE
|
||||||
|
debug: false
|
||||||
|
deepspeed_config:
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
offload_optimizer_device: none
|
||||||
|
offload_param_device: none
|
||||||
|
zero3_init_flag: true
|
||||||
|
zero3_save_16bit_model: true
|
||||||
|
zero_stage: 3
|
||||||
|
distributed_type: DEEPSPEED
|
||||||
|
downcast_bf16: 'no'
|
||||||
|
enable_cpu_affinity: false
|
||||||
|
machine_rank: 0
|
||||||
|
main_training_function: main
|
||||||
|
mixed_precision: bf16
|
||||||
|
num_machines: 1
|
||||||
|
num_processes: 8
|
||||||
|
rdzv_backend: static
|
||||||
|
same_network: true
|
||||||
|
tpu_env: []
|
||||||
|
tpu_use_cluster: false
|
||||||
|
tpu_use_sudo: false
|
||||||
|
use_cpu: false
|
||||||
39
examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh
Normal file
39
examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
# Dataset: data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/
|
||||||
|
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "joyai_image/JoyAI-Image-Edit/*" --local_dir ./data/diffsynth_example_dataset
|
||||||
|
|
||||||
|
accelerate launch examples/joyai_image/model_training/train.py \
|
||||||
|
--dataset_base_path "./data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit" \
|
||||||
|
--dataset_metadata_path "./data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/metadata.csv" \
|
||||||
|
--max_pixels 1048576 \
|
||||||
|
--dataset_repeat 1 \
|
||||||
|
--model_id_with_origin_paths "jd-opensource/JoyAI-Image-Edit:JoyAI-Image-Und/model*.safetensors,jd-opensource/JoyAI-Image-Edit:vae/Wan2.1_VAE.pth" \
|
||||||
|
--learning_rate 1e-4 \
|
||||||
|
--num_epochs 5 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
--output_path "./models/train/JoyAI-Image-Edit-split-cache" \
|
||||||
|
--lora_base_model "dit" \
|
||||||
|
--lora_target_modules "img_attn_qkv,txt_attn_qkv,img_attn_proj,txt_attn_proj" \
|
||||||
|
--lora_rank 32 \
|
||||||
|
--use_gradient_checkpointing \
|
||||||
|
--find_unused_parameters \
|
||||||
|
--data_file_keys "image,edit_image" \
|
||||||
|
--extra_inputs "edit_image" \
|
||||||
|
--task "sft:data_process"
|
||||||
|
|
||||||
|
accelerate launch examples/joyai_image/model_training/train.py \
|
||||||
|
--dataset_base_path "./models/train/JoyAI-Image-Edit-split-cache" \
|
||||||
|
--max_pixels 1048576 \
|
||||||
|
--dataset_repeat 50 \
|
||||||
|
--model_id_with_origin_paths "jd-opensource/JoyAI-Image-Edit:transformer/transformer.pth" \
|
||||||
|
--learning_rate 1e-4 \
|
||||||
|
--num_epochs 5 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
--output_path "./models/train/JoyAI-Image-Edit-lora" \
|
||||||
|
--lora_base_model "dit" \
|
||||||
|
--lora_target_modules "img_attn_qkv,txt_attn_qkv,img_attn_proj,txt_attn_proj" \
|
||||||
|
--lora_rank 32 \
|
||||||
|
--use_gradient_checkpointing \
|
||||||
|
--find_unused_parameters \
|
||||||
|
--data_file_keys "image,edit_image" \
|
||||||
|
--extra_inputs "edit_image" \
|
||||||
|
--task "sft:train"
|
||||||
138
examples/joyai_image/model_training/train.py
Normal file
138
examples/joyai_image/model_training/train.py
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
import torch, os, argparse, accelerate
|
||||||
|
from diffsynth.core import UnifiedDataset
|
||||||
|
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
|
||||||
|
from diffsynth.diffusion import *
|
||||||
|
from diffsynth.core.data.operators import *
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
|
||||||
|
class JoyAIImageTrainingModule(DiffusionTrainingModule):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_paths=None, model_id_with_origin_paths=None,
|
||||||
|
processor_path=None,
|
||||||
|
trainable_models=None,
|
||||||
|
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
|
||||||
|
preset_lora_path=None, preset_lora_model=None,
|
||||||
|
use_gradient_checkpointing=True,
|
||||||
|
use_gradient_checkpointing_offload=False,
|
||||||
|
extra_inputs=None,
|
||||||
|
fp8_models=None,
|
||||||
|
offload_models=None,
|
||||||
|
device="cpu",
|
||||||
|
task="sft",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
# Load models
|
||||||
|
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
|
||||||
|
processor_config = ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/") if processor_path is None else ModelConfig(processor_path)
|
||||||
|
self.pipe = JoyAIImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, processor_config=processor_config)
|
||||||
|
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
|
||||||
|
|
||||||
|
# Training mode
|
||||||
|
self.switch_pipe_to_training_mode(
|
||||||
|
self.pipe, trainable_models,
|
||||||
|
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
|
||||||
|
preset_lora_path, preset_lora_model,
|
||||||
|
task=task,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Other configs
|
||||||
|
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||||
|
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||||
|
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
||||||
|
self.fp8_models = fp8_models
|
||||||
|
self.task = task
|
||||||
|
self.task_to_loss = {
|
||||||
|
"sft:data_process": lambda pipe, *args: args,
|
||||||
|
"sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
|
||||||
|
"sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_pipeline_inputs(self, data):
|
||||||
|
inputs_posi = {"prompt": data["prompt"]}
|
||||||
|
inputs_nega = {"negative_prompt": ""}
|
||||||
|
inputs_shared = {
|
||||||
|
# Assume you are using this pipeline for inference,
|
||||||
|
# please fill in the input parameters.
|
||||||
|
"input_image": data["image"],
|
||||||
|
"height": data["image"].size[1],
|
||||||
|
"width": data["image"].size[0],
|
||||||
|
# Please do not modify the following parameters
|
||||||
|
# unless you clearly know what this will cause.
|
||||||
|
"cfg_scale": 1,
|
||||||
|
"rand_device": self.pipe.device,
|
||||||
|
"use_gradient_checkpointing": self.use_gradient_checkpointing,
|
||||||
|
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
|
||||||
|
}
|
||||||
|
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
|
||||||
|
return inputs_shared, inputs_posi, inputs_nega
|
||||||
|
|
||||||
|
def forward(self, data, inputs=None):
|
||||||
|
if inputs is None: inputs = self.get_pipeline_inputs(data)
|
||||||
|
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
|
||||||
|
for unit in self.pipe.units:
|
||||||
|
inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)
|
||||||
|
loss = self.task_to_loss[self.task](self.pipe, *inputs)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def joyai_image_parser():
|
||||||
|
parser = argparse.ArgumentParser(description="JoyAI-Image training.")
|
||||||
|
parser = add_general_config(parser)
|
||||||
|
parser = add_image_size_config(parser)
|
||||||
|
parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor.")
|
||||||
|
parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = joyai_image_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
accelerator = accelerate.Accelerator(
|
||||||
|
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||||
|
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
|
||||||
|
)
|
||||||
|
dataset = UnifiedDataset(
|
||||||
|
base_path=args.dataset_base_path,
|
||||||
|
metadata_path=args.dataset_metadata_path,
|
||||||
|
repeat=args.dataset_repeat,
|
||||||
|
data_file_keys=args.data_file_keys.split(","),
|
||||||
|
main_data_operator=UnifiedDataset.default_image_operator(
|
||||||
|
base_path=args.dataset_base_path,
|
||||||
|
max_pixels=args.max_pixels,
|
||||||
|
height=args.height,
|
||||||
|
width=args.width,
|
||||||
|
height_division_factor=16,
|
||||||
|
width_division_factor=16,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
model = JoyAIImageTrainingModule(
|
||||||
|
model_paths=args.model_paths,
|
||||||
|
model_id_with_origin_paths=args.model_id_with_origin_paths,
|
||||||
|
processor_path=args.processor_path,
|
||||||
|
trainable_models=args.trainable_models,
|
||||||
|
lora_base_model=args.lora_base_model,
|
||||||
|
lora_target_modules=args.lora_target_modules,
|
||||||
|
lora_rank=args.lora_rank,
|
||||||
|
lora_checkpoint=args.lora_checkpoint,
|
||||||
|
preset_lora_path=args.preset_lora_path,
|
||||||
|
preset_lora_model=args.preset_lora_model,
|
||||||
|
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||||
|
extra_inputs=args.extra_inputs,
|
||||||
|
fp8_models=args.fp8_models,
|
||||||
|
offload_models=args.offload_models,
|
||||||
|
task=args.task,
|
||||||
|
device="cpu" if args.initialize_model_on_cpu else accelerator.device,
|
||||||
|
)
|
||||||
|
model_logger = ModelLogger(
|
||||||
|
args.output_path,
|
||||||
|
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
|
||||||
|
)
|
||||||
|
launcher_map = {
|
||||||
|
"sft:data_process": launch_data_process_task,
|
||||||
|
"sft": launch_training_task,
|
||||||
|
"sft:train": launch_training_task,
|
||||||
|
}
|
||||||
|
launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)
|
||||||
@@ -0,0 +1,32 @@
|
|||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
|
||||||
|
from diffsynth import load_state_dict
|
||||||
|
|
||||||
|
pipe = JoyAIImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth"),
|
||||||
|
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors"),
|
||||||
|
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth"),
|
||||||
|
],
|
||||||
|
processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
|
||||||
|
)
|
||||||
|
|
||||||
|
state_dict = load_state_dict("models/train/JoyAI-Image-Edit_full/epoch-1.safetensors")
|
||||||
|
pipe.dit.load_state_dict(state_dict)
|
||||||
|
|
||||||
|
prompt = "将裙子改为粉色"
|
||||||
|
edit_image = Image.open("data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/edit/image1.jpg").convert("RGB")
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
edit_image=edit_image,
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
seed=0,
|
||||||
|
num_inference_steps=50,
|
||||||
|
cfg_scale=5.0,
|
||||||
|
)
|
||||||
|
image.save("image_full.jpg")
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
|
||||||
|
|
||||||
|
pipe = JoyAIImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth"),
|
||||||
|
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors"),
|
||||||
|
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth"),
|
||||||
|
],
|
||||||
|
processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
|
||||||
|
)
|
||||||
|
|
||||||
|
pipe.load_lora(pipe.dit, "models/train/JoyAI-Image-Edit-lora/epoch-4.safetensors")
|
||||||
|
|
||||||
|
prompt = "将裙子改为粉色"
|
||||||
|
edit_image = Image.open("data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/edit/image1.jpg").convert("RGB")
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
edit_image=edit_image,
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
seed=0,
|
||||||
|
num_inference_steps=30,
|
||||||
|
cfg_scale=5.0,
|
||||||
|
)
|
||||||
|
image.save("image_lora.jpg")
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth.core import ModelConfig
|
||||||
|
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline
|
||||||
|
|
||||||
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors"),
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors"),
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
|
||||||
|
)
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt="a photo of an astronaut riding a horse on mars, high quality, detailed",
|
||||||
|
negative_prompt="blurry, low quality, deformed",
|
||||||
|
cfg_scale=7.5,
|
||||||
|
height=512,
|
||||||
|
width=512,
|
||||||
|
seed=42,
|
||||||
|
rand_device="cuda",
|
||||||
|
num_inference_steps=50,
|
||||||
|
)
|
||||||
|
image.save("image.jpg")
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth.core import ModelConfig
|
||||||
|
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.float32,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.float32,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float32,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.float32,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt="a photo of an astronaut riding a horse on mars, high quality, detailed",
|
||||||
|
negative_prompt="blurry, low quality, deformed",
|
||||||
|
cfg_scale=7.5,
|
||||||
|
height=512,
|
||||||
|
width=512,
|
||||||
|
seed=42,
|
||||||
|
rand_device="cuda",
|
||||||
|
num_inference_steps=50,
|
||||||
|
)
|
||||||
|
image.save("image.jpg")
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion/stable-diffusion-v1-5/*" --local_dir ./data/diffsynth_example_dataset
|
||||||
|
|
||||||
|
accelerate launch examples/stable_diffusion/model_training/train.py \
|
||||||
|
--dataset_base_path data/diffsynth_example_dataset/stable_diffusion/stable-diffusion-v1-5 \
|
||||||
|
--dataset_metadata_path data/diffsynth_example_dataset/stable_diffusion/stable-diffusion-v1-5/metadata.csv \
|
||||||
|
--height 512 \
|
||||||
|
--width 512 \
|
||||||
|
--dataset_repeat 50 \
|
||||||
|
--model_id_with_origin_paths "AI-ModelScope/stable-diffusion-v1-5:text_encoder/model.safetensors,AI-ModelScope/stable-diffusion-v1-5:unet/diffusion_pytorch_model.safetensors,AI-ModelScope/stable-diffusion-v1-5:vae/diffusion_pytorch_model.safetensors" \
|
||||||
|
--learning_rate 1e-5 \
|
||||||
|
--num_epochs 2 \
|
||||||
|
--trainable_models "unet" \
|
||||||
|
--remove_prefix_in_ckpt "pipe.unet." \
|
||||||
|
--output_path "./models/train/stable-diffusion-v1-5_full" \
|
||||||
|
--use_gradient_checkpointing
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion/stable-diffusion-v1-5/*" --local_dir ./data/diffsynth_example_dataset
|
||||||
|
|
||||||
|
accelerate launch examples/stable_diffusion/model_training/train.py \
|
||||||
|
--dataset_base_path data/diffsynth_example_dataset/stable_diffusion/stable-diffusion-v1-5 \
|
||||||
|
--dataset_metadata_path data/diffsynth_example_dataset/stable_diffusion/stable-diffusion-v1-5/metadata.csv \
|
||||||
|
--height 512 \
|
||||||
|
--width 512 \
|
||||||
|
--dataset_repeat 50 \
|
||||||
|
--model_id_with_origin_paths "AI-ModelScope/stable-diffusion-v1-5:text_encoder/model.safetensors,AI-ModelScope/stable-diffusion-v1-5:unet/diffusion_pytorch_model.safetensors,AI-ModelScope/stable-diffusion-v1-5:vae/diffusion_pytorch_model.safetensors" \
|
||||||
|
--learning_rate 1e-4 \
|
||||||
|
--num_epochs 5 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.unet." \
|
||||||
|
--output_path "./models/train/stable-diffusion-v1-5_lora" \
|
||||||
|
--lora_base_model "unet" \
|
||||||
|
--lora_target_modules "" \
|
||||||
|
--lora_rank 32 \
|
||||||
|
--use_gradient_checkpointing
|
||||||
142
examples/stable_diffusion/model_training/train.py
Normal file
142
examples/stable_diffusion/model_training/train.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
import torch, os, argparse, accelerate
|
||||||
|
from diffsynth.core import UnifiedDataset
|
||||||
|
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline, ModelConfig
|
||||||
|
from diffsynth.diffusion import *
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusionTrainingModule(DiffusionTrainingModule):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_paths=None, model_id_with_origin_paths=None,
|
||||||
|
tokenizer_path=None,
|
||||||
|
trainable_models=None,
|
||||||
|
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
|
||||||
|
preset_lora_path=None, preset_lora_model=None,
|
||||||
|
use_gradient_checkpointing=True,
|
||||||
|
use_gradient_checkpointing_offload=False,
|
||||||
|
extra_inputs=None,
|
||||||
|
fp8_models=None,
|
||||||
|
offload_models=None,
|
||||||
|
device="cpu",
|
||||||
|
task="sft",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
# Load models
|
||||||
|
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
|
||||||
|
tokenizer_config = self.parse_path_or_model_id(tokenizer_path, ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"))
|
||||||
|
self.pipe = StableDiffusionPipeline.from_pretrained(torch_dtype=torch.float32, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
|
||||||
|
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
|
||||||
|
|
||||||
|
# Training mode
|
||||||
|
self.switch_pipe_to_training_mode(
|
||||||
|
self.pipe, trainable_models,
|
||||||
|
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
|
||||||
|
preset_lora_path, preset_lora_model,
|
||||||
|
task=task,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Other configs
|
||||||
|
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||||
|
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||||
|
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
||||||
|
self.fp8_models = fp8_models
|
||||||
|
self.task = task
|
||||||
|
self.task_to_loss = {
|
||||||
|
"sft:data_process": lambda pipe, *args: args,
|
||||||
|
"direct_distill:data_process": lambda pipe, *args: args,
|
||||||
|
"sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
|
||||||
|
"sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
|
||||||
|
"direct_distill": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
|
||||||
|
"direct_distill:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_pipeline_inputs(self, data):
|
||||||
|
inputs_posi = {"prompt": data["prompt"]}
|
||||||
|
inputs_nega = {"negative_prompt": ""}
|
||||||
|
inputs_shared = {
|
||||||
|
# Assume you are using this pipeline for inference,
|
||||||
|
# please fill in the input parameters.
|
||||||
|
"input_image": data["image"],
|
||||||
|
"height": data["image"].size[1],
|
||||||
|
"width": data["image"].size[0],
|
||||||
|
# Please do not modify the following parameters
|
||||||
|
# unless you clearly know what this will cause.
|
||||||
|
"cfg_scale": 1,
|
||||||
|
"rand_device": self.pipe.device,
|
||||||
|
"use_gradient_checkpointing": self.use_gradient_checkpointing,
|
||||||
|
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
|
||||||
|
}
|
||||||
|
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
|
||||||
|
return inputs_shared, inputs_posi, inputs_nega
|
||||||
|
|
||||||
|
def forward(self, data, inputs=None):
|
||||||
|
if inputs is None: inputs = self.get_pipeline_inputs(data)
|
||||||
|
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
|
||||||
|
for unit in self.pipe.units:
|
||||||
|
inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)
|
||||||
|
loss = self.task_to_loss[self.task](self.pipe, *inputs)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def parser():
|
||||||
|
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||||
|
parser = add_general_config(parser)
|
||||||
|
parser = add_image_size_config(parser)
|
||||||
|
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
accelerator = accelerate.Accelerator(
|
||||||
|
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||||
|
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
|
||||||
|
)
|
||||||
|
dataset = UnifiedDataset(
|
||||||
|
base_path=args.dataset_base_path,
|
||||||
|
metadata_path=args.dataset_metadata_path,
|
||||||
|
repeat=args.dataset_repeat,
|
||||||
|
data_file_keys=args.data_file_keys.split(","),
|
||||||
|
main_data_operator=UnifiedDataset.default_image_operator(
|
||||||
|
base_path=args.dataset_base_path,
|
||||||
|
max_pixels=args.max_pixels,
|
||||||
|
height=args.height,
|
||||||
|
width=args.width,
|
||||||
|
height_division_factor=32,
|
||||||
|
width_division_factor=32,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
model = StableDiffusionTrainingModule(
|
||||||
|
model_paths=args.model_paths,
|
||||||
|
model_id_with_origin_paths=args.model_id_with_origin_paths,
|
||||||
|
tokenizer_path=args.tokenizer_path,
|
||||||
|
trainable_models=args.trainable_models,
|
||||||
|
lora_base_model=args.lora_base_model,
|
||||||
|
lora_target_modules=args.lora_target_modules,
|
||||||
|
lora_rank=args.lora_rank,
|
||||||
|
lora_checkpoint=args.lora_checkpoint,
|
||||||
|
preset_lora_path=args.preset_lora_path,
|
||||||
|
preset_lora_model=args.preset_lora_model,
|
||||||
|
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||||
|
extra_inputs=args.extra_inputs,
|
||||||
|
fp8_models=args.fp8_models,
|
||||||
|
offload_models=args.offload_models,
|
||||||
|
task=args.task,
|
||||||
|
device=accelerator.device,
|
||||||
|
)
|
||||||
|
model_logger = ModelLogger(
|
||||||
|
args.output_path,
|
||||||
|
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
|
||||||
|
)
|
||||||
|
launcher_map = {
|
||||||
|
"sft:data_process": launch_data_process_task,
|
||||||
|
"direct_distill:data_process": launch_data_process_task,
|
||||||
|
"sft": launch_training_task,
|
||||||
|
"sft:train": launch_training_task,
|
||||||
|
"direct_distill": launch_training_task,
|
||||||
|
"direct_distill:train": launch_training_task,
|
||||||
|
}
|
||||||
|
launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)
|
||||||
@@ -0,0 +1,27 @@
|
|||||||
|
"""Validate a fully fine-tuned Stable Diffusion v1.5 checkpoint by generating one image."""
import torch

from diffsynth.core import load_state_dict
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline, ModelConfig

MODEL_ID = "AI-ModelScope/stable-diffusion-v1-5"

# Build the base pipeline from the original SD v1.5 weights.
pipe = StableDiffusionPipeline.from_pretrained(
    torch_dtype=torch.float32,
    model_configs=[
        ModelConfig(model_id=MODEL_ID, origin_file_pattern=pattern)
        for pattern in (
            "text_encoder/model.safetensors",
            "unet/diffusion_pytorch_model.safetensors",
            "vae/diffusion_pytorch_model.safetensors",
        )
    ],
    tokenizer_config=ModelConfig(model_id=MODEL_ID, origin_file_pattern="tokenizer/"),
)

# Swap in the fine-tuned UNet weights produced by full training.
checkpoint = load_state_dict(
    "./models/train/stable-diffusion-v1-5_full/epoch-1.safetensors",
    torch_dtype=torch.float32,
)
pipe.unet.load_state_dict(checkpoint)

image = pipe(
    prompt="a dog",
    negative_prompt="blurry, low quality, deformed",
    cfg_scale=7.5,
    height=512,
    width=512,
    seed=42,
    rand_device="cuda",
    num_inference_steps=50,
)
image.save("image_stable-diffusion-v1-5_full.jpg")
|
||||||
@@ -0,0 +1,26 @@
|
|||||||
|
"""Validate a Stable Diffusion v1.5 LoRA checkpoint by generating one image."""
import torch

from diffsynth.core import ModelConfig
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline

MODEL_ID = "AI-ModelScope/stable-diffusion-v1-5"

# Build the base pipeline from the original SD v1.5 weights.
pipe = StableDiffusionPipeline.from_pretrained(
    torch_dtype=torch.float32,
    model_configs=[
        ModelConfig(model_id=MODEL_ID, origin_file_pattern=pattern)
        for pattern in (
            "text_encoder/model.safetensors",
            "unet/diffusion_pytorch_model.safetensors",
            "vae/diffusion_pytorch_model.safetensors",
        )
    ],
    tokenizer_config=ModelConfig(model_id=MODEL_ID, origin_file_pattern="tokenizer/"),
)

# Merge the trained LoRA into the UNet.
pipe.load_lora(pipe.unet, "models/train/stable-diffusion-v1-5_lora/epoch-4.safetensors")

image = pipe(
    prompt="a dog",
    negative_prompt="blurry, low quality, deformed",
    cfg_scale=7.5,
    height=512,
    width=512,
    seed=42,
    rand_device="cuda",
    num_inference_steps=50,
)
image.save("image_stable-diffusion-v1-5.jpg")
|
||||||
@@ -0,0 +1,26 @@
|
|||||||
|
"""Text-to-image inference with Stable Diffusion XL base 1.0."""
import torch

from diffsynth.core import ModelConfig
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline

MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"

# SDXL uses two text encoders and two tokenizers.
pipe = StableDiffusionXLPipeline.from_pretrained(
    torch_dtype=torch.float32,
    model_configs=[
        ModelConfig(model_id=MODEL_ID, origin_file_pattern=pattern)
        for pattern in (
            "text_encoder/model.safetensors",
            "text_encoder_2/model.safetensors",
            "unet/diffusion_pytorch_model.safetensors",
            "vae/diffusion_pytorch_model.safetensors",
        )
    ],
    tokenizer_config=ModelConfig(model_id=MODEL_ID, origin_file_pattern="tokenizer/"),
    tokenizer_2_config=ModelConfig(model_id=MODEL_ID, origin_file_pattern="tokenizer_2/"),
)

image = pipe(
    prompt="a photo of an astronaut riding a horse on mars",
    negative_prompt="",
    cfg_scale=5.0,
    height=1024,
    width=1024,
    seed=42,
    num_inference_steps=50,
)
image.save("image.jpg")
|
||||||
@@ -0,0 +1,37 @@
|
|||||||
|
"""Stable Diffusion XL inference with offload-based VRAM management."""
import torch

from diffsynth.core import ModelConfig
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline

MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"

# Park weights on CPU and move them to the GPU only for computation.
vram_config = dict(
    offload_dtype=torch.float32,
    offload_device="cpu",
    onload_dtype=torch.float32,
    onload_device="cpu",
    preparing_dtype=torch.float32,
    preparing_device="cuda",
    computation_dtype=torch.float32,
    computation_device="cuda",
)

pipe = StableDiffusionXLPipeline.from_pretrained(
    torch_dtype=torch.float32,
    model_configs=[
        ModelConfig(model_id=MODEL_ID, origin_file_pattern=pattern, **vram_config)
        for pattern in (
            "text_encoder/model.safetensors",
            "text_encoder_2/model.safetensors",
            "unet/diffusion_pytorch_model.safetensors",
            "vae/diffusion_pytorch_model.safetensors",
        )
    ],
    tokenizer_config=ModelConfig(model_id=MODEL_ID, origin_file_pattern="tokenizer/"),
    tokenizer_2_config=ModelConfig(model_id=MODEL_ID, origin_file_pattern="tokenizer_2/"),
    # Total device memory in GiB minus 0.5 GiB headroom.
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)

image = pipe(
    prompt="a photo of an astronaut riding a horse on mars",
    negative_prompt="",
    cfg_scale=5.0,
    height=1024,
    width=1024,
    seed=42,
    num_inference_steps=50,
)
image.save("image.jpg")
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
# Download the SDXL example dataset (images + metadata.csv) from ModelScope.
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion_xl/stable-diffusion-xl-base-1.0/*" --local_dir ./data/diffsynth_example_dataset

# Full fine-tuning: only the UNet is trainable ("--trainable_models unet"); the
# "pipe.unet." prefix is stripped from checkpoint keys so the saved file can be
# loaded directly into the UNet at inference time.
accelerate launch examples/stable_diffusion_xl/model_training/train.py \
--dataset_base_path data/diffsynth_example_dataset/stable_diffusion_xl/stable-diffusion-xl-base-1.0 \
--dataset_metadata_path data/diffsynth_example_dataset/stable_diffusion_xl/stable-diffusion-xl-base-1.0/metadata.csv \
--height 1024 \
--width 1024 \
--dataset_repeat 10 \
--model_id_with_origin_paths "stabilityai/stable-diffusion-xl-base-1.0:text_encoder/model.safetensors,stabilityai/stable-diffusion-xl-base-1.0:text_encoder_2/model.safetensors,stabilityai/stable-diffusion-xl-base-1.0:unet/diffusion_pytorch_model.safetensors,stabilityai/stable-diffusion-xl-base-1.0:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-5 \
--num_epochs 2 \
--trainable_models "unet" \
--remove_prefix_in_ckpt "pipe.unet." \
--output_path "./models/train/stable-diffusion-xl-base-1.0_full" \
--use_gradient_checkpointing
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
# Download the SDXL example dataset (images + metadata.csv) from ModelScope.
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion_xl/stable-diffusion-xl-base-1.0/*" --local_dir ./data/diffsynth_example_dataset

# LoRA training on the UNet (rank 32).
# NOTE(review): --lora_target_modules is empty here — presumably the trainer
# falls back to its default target modules; confirm against train.py.
accelerate launch examples/stable_diffusion_xl/model_training/train.py \
--dataset_base_path data/diffsynth_example_dataset/stable_diffusion_xl/stable-diffusion-xl-base-1.0 \
--dataset_metadata_path data/diffsynth_example_dataset/stable_diffusion_xl/stable-diffusion-xl-base-1.0/metadata.csv \
--height 1024 \
--width 1024 \
--dataset_repeat 10 \
--model_id_with_origin_paths "stabilityai/stable-diffusion-xl-base-1.0:text_encoder/model.safetensors,stabilityai/stable-diffusion-xl-base-1.0:text_encoder_2/model.safetensors,stabilityai/stable-diffusion-xl-base-1.0:unet/diffusion_pytorch_model.safetensors,stabilityai/stable-diffusion-xl-base-1.0:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.unet." \
--output_path "./models/train/stable-diffusion-xl-base-1.0_lora" \
--lora_base_model "unet" \
--lora_target_modules "" \
--lora_rank 32 \
--use_gradient_checkpointing
|
||||||
144
examples/stable_diffusion_xl/model_training/train.py
Normal file
144
examples/stable_diffusion_xl/model_training/train.py
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
import torch, os, argparse, accelerate
|
||||||
|
from diffsynth.core import UnifiedDataset
|
||||||
|
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline, ModelConfig
|
||||||
|
from diffsynth.diffusion import *
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusionXLTrainingModule(DiffusionTrainingModule):
    """Training wrapper around ``StableDiffusionXLPipeline`` for SFT and direct-distill tasks.

    Loads the pipeline from the given model configs, switches it to training
    mode (full fine-tuning or LoRA), and maps each task name to its loss.
    """

    def __init__(
        self,
        model_paths=None, model_id_with_origin_paths=None,
        tokenizer_path=None,
        tokenizer_2_path=None,
        trainable_models=None,
        lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
        preset_lora_path=None, preset_lora_model=None,
        use_gradient_checkpointing=True,
        use_gradient_checkpointing_offload=False,
        extra_inputs=None,
        fp8_models=None,
        offload_models=None,
        device="cpu",
        task="sft",
    ):
        super().__init__()
        # Load models
        model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
        tokenizer_config = self.parse_path_or_model_id(tokenizer_path, ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"))
        # Bug fix: the second tokenizer must be resolved from ``tokenizer_2_path``.
        # Previously ``tokenizer_path`` was reused here, so a custom tokenizer path
        # was wrongly applied to both of SDXL's tokenizers.
        tokenizer_2_config = self.parse_path_or_model_id(tokenizer_2_path, ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"))
        self.pipe = StableDiffusionXLPipeline.from_pretrained(torch_dtype=torch.float32, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config, tokenizer_2_config=tokenizer_2_config)
        # Keep only the pipeline units relevant to this task / trainable model set.
        self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)

        # Training mode (enables gradients, injects LoRA layers if requested)
        self.switch_pipe_to_training_mode(
            self.pipe, trainable_models,
            lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
            preset_lora_path, preset_lora_model,
            task=task,
        )

        # Other configs
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
        self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
        self.fp8_models = fp8_models
        self.task = task
        # Map each task name to the callable computing its loss; the
        # ``data_process`` variants just pass the prepared inputs through.
        self.task_to_loss = {
            "sft:data_process": lambda pipe, *args: args,
            "direct_distill:data_process": lambda pipe, *args: args,
            "sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
            "sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
            "direct_distill": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
            "direct_distill:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
        }

    def get_pipeline_inputs(self, data):
        """Convert one dataset sample into (shared, positive, negative) pipeline inputs."""
        inputs_posi = {"prompt": data["prompt"]}
        inputs_nega = {"negative_prompt": ""}
        inputs_shared = {
            # Assume you are using this pipeline for inference,
            # please fill in the input parameters.
            "input_image": data["image"],
            "height": data["image"].size[1],
            "width": data["image"].size[0],
            # Please do not modify the following parameters
            # unless you clearly know what this will cause.
            "cfg_scale": 1,
            "rand_device": self.pipe.device,
            "use_gradient_checkpointing": self.use_gradient_checkpointing,
            "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
        }
        inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
        return inputs_shared, inputs_posi, inputs_nega

    def forward(self, data, inputs=None):
        """Run the pipeline units on one sample and return the task loss."""
        if inputs is None: inputs = self.get_pipeline_inputs(data)
        inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
        for unit in self.pipe.units:
            inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)
        loss = self.task_to_loss[self.task](self.pipe, *inputs)
        return loss
|
||||||
|
|
||||||
|
|
||||||
|
def parser():
    """Build the argument parser for Stable Diffusion XL training.

    Starts from the shared general/image-size options and adds the two
    SDXL-specific tokenizer path overrides.
    """
    arg_parser = argparse.ArgumentParser(description="Simple example of a training script.")
    arg_parser = add_general_config(arg_parser)
    arg_parser = add_image_size_config(arg_parser)
    arg_parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
    arg_parser.add_argument("--tokenizer_2_path", type=str, default=None, help="Path to tokenizer 2.")
    return arg_parser
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    parser = parser()
    args = parser.parse_args()
    # NOTE(review): args.tokenizer_2_path is parsed but never forwarded to the
    # training module below — confirm whether it should be passed through.
    accelerator = accelerate.Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
    )
    # Image dataset; sizes are snapped to multiples of 32.
    dataset = UnifiedDataset(
        base_path=args.dataset_base_path,
        metadata_path=args.dataset_metadata_path,
        repeat=args.dataset_repeat,
        data_file_keys=args.data_file_keys.split(","),
        main_data_operator=UnifiedDataset.default_image_operator(
            base_path=args.dataset_base_path,
            max_pixels=args.max_pixels,
            height=args.height,
            width=args.width,
            height_division_factor=32,
            width_division_factor=32,
        )
    )
    model = StableDiffusionXLTrainingModule(
        model_paths=args.model_paths,
        model_id_with_origin_paths=args.model_id_with_origin_paths,
        tokenizer_path=args.tokenizer_path,
        trainable_models=args.trainable_models,
        lora_base_model=args.lora_base_model,
        lora_target_modules=args.lora_target_modules,
        lora_rank=args.lora_rank,
        lora_checkpoint=args.lora_checkpoint,
        preset_lora_path=args.preset_lora_path,
        preset_lora_model=args.preset_lora_model,
        use_gradient_checkpointing=args.use_gradient_checkpointing,
        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
        extra_inputs=args.extra_inputs,
        fp8_models=args.fp8_models,
        offload_models=args.offload_models,
        task=args.task,
        device=accelerator.device,
    )
    # Checkpoint writer; strips the given prefix from saved parameter names.
    model_logger = ModelLogger(
        args.output_path,
        remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
    )
    # Dispatch by task name: data_process tasks only prepare inputs,
    # the others run the training loop.
    launcher_map = {
        "sft:data_process": launch_data_process_task,
        "direct_distill:data_process": launch_data_process_task,
        "sft": launch_training_task,
        "sft:train": launch_training_task,
        "direct_distill": launch_training_task,
        "direct_distill:train": launch_training_task,
    }
    launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
"""Validate a fully fine-tuned SDXL checkpoint by generating one image."""
import torch

from diffsynth.core import load_state_dict
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline, ModelConfig

MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"

# Build the base pipeline from the original SDXL weights.
pipe = StableDiffusionXLPipeline.from_pretrained(
    torch_dtype=torch.float32,
    model_configs=[
        ModelConfig(model_id=MODEL_ID, origin_file_pattern=pattern)
        for pattern in (
            "text_encoder/model.safetensors",
            "text_encoder_2/model.safetensors",
            "unet/diffusion_pytorch_model.safetensors",
            "vae/diffusion_pytorch_model.safetensors",
        )
    ],
    tokenizer_config=ModelConfig(model_id=MODEL_ID, origin_file_pattern="tokenizer/"),
    tokenizer_2_config=ModelConfig(model_id=MODEL_ID, origin_file_pattern="tokenizer_2/"),
)

# Swap in the fine-tuned UNet weights produced by full training.
checkpoint = load_state_dict(
    "./models/train/stable-diffusion-xl-base-1.0_full/epoch-1.safetensors",
    torch_dtype=torch.float32,
)
pipe.unet.load_state_dict(checkpoint)

image = pipe(
    prompt="a dog",
    negative_prompt="",
    cfg_scale=7.0,
    height=1024,
    width=1024,
    seed=42,
    num_inference_steps=50,
)
image.save("image_stable-diffusion-xl-base-1.0_full.jpg")
|
||||||
@@ -0,0 +1,27 @@
|
|||||||
|
"""Validate an SDXL LoRA checkpoint by generating one image."""
import torch

from diffsynth.core import ModelConfig
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline

MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"

# Build the base pipeline from the original SDXL weights.
pipe = StableDiffusionXLPipeline.from_pretrained(
    torch_dtype=torch.float32,
    model_configs=[
        ModelConfig(model_id=MODEL_ID, origin_file_pattern=pattern)
        for pattern in (
            "text_encoder/model.safetensors",
            "text_encoder_2/model.safetensors",
            "unet/diffusion_pytorch_model.safetensors",
            "vae/diffusion_pytorch_model.safetensors",
        )
    ],
    tokenizer_config=ModelConfig(model_id=MODEL_ID, origin_file_pattern="tokenizer/"),
    tokenizer_2_config=ModelConfig(model_id=MODEL_ID, origin_file_pattern="tokenizer_2/"),
)

# Merge the trained LoRA into the UNet.
pipe.load_lora(pipe.unet, "models/train/stable-diffusion-xl-base-1.0_lora/epoch-4.safetensors")

image = pipe(
    prompt="a dog",
    negative_prompt="",
    cfg_scale=7.0,
    height=1024,
    width=1024,
    seed=42,
    num_inference_steps=50,
)
image.save("image_stable-diffusion-xl-base-1.0.jpg")
|
||||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "diffsynth"
|
name = "diffsynth"
|
||||||
version = "2.0.7"
|
version = "2.0.9"
|
||||||
description = "Enjoy the magic of Diffusion models!"
|
description = "Enjoy the magic of Diffusion models!"
|
||||||
authors = [{name = "ModelScope Team"}]
|
authors = [{name = "ModelScope Team"}]
|
||||||
license = {text = "Apache-2.0"}
|
license = {text = "Apache-2.0"}
|
||||||
|
|||||||
Reference in New Issue
Block a user