Support JoyAI-Image-Edit (#1393)

* auto intergrate joyimage model

* joyimage pipeline

* train

* ready

* styling

* joyai-image docs

* update readme

* pr review
This commit is contained in:
Hong Zhang
2026-04-15 16:57:11 +08:00
committed by GitHub
parent 8f18e24597
commit 079e51c9f3
21 changed files with 2040 additions and 123 deletions

200
README.md
View File

@@ -34,6 +34,8 @@ We believe that a well-developed open-source code framework can lower the thresh
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher) and [mi804](https://github.com/mi804). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
- **April 14, 2026** JoyAI-Image open-sourced, welcome a new member to the image editing model family! Support includes instruction-guided image editing, low VRAM inference, and training capabilities. For details, please refer to the [documentation](/docs/en/Model_Details/JoyAI-Image.md) and [example code](/examples/joyai_image/).
- **March 19, 2026**: Added support for [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) and [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/Wan.md) and [example code](/examples/mova/) are now available.
- **March 12, 2026**: We have added support for the [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) audio-video generation model. The features includes text-to-audio/video, image-to-audio/video, IC-LoRA control, audio-to-video, and audio-video inpainting. We have supported the complete inference and training functionalities. For details, please refer to the [documentation](/docs/en/Model_Details/LTX-2.md) and [code](/examples/ltx2/).
@@ -598,6 +600,143 @@ Example code for FLUX.1 is available at: [/examples/flux/](/examples/flux/)
</details>
#### ERNIE-Image: [/docs/en/Model_Details/ERNIE-Image.md](/docs/en/Model_Details/ERNIE-Image.md)
<details>
<summary>Quick Start</summary>
Running the following code will quickly load the [PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 3GB VRAM.
```python
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = ErnieImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device='cuda',
model_configs=[
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
prompt="一只黑白相间的中华田园犬",
negative_prompt="",
height=1024,
width=1024,
seed=42,
num_inference_steps=50,
cfg_scale=4.0,
)
image.save("output.jpg")
```
</details>
<details>
<summary>Examples</summary>
Example code for ERNIE-Image is available at: [/examples/ernie_image/](/examples/ernie_image/)
| Model ID | Inference | Low VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image)|[code](/examples/ernie_image/model_inference/ERNIE-Image.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/full/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_full/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/lora/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_lora/ERNIE-Image.py)|
|[PaddlePaddle/ERNIE-Image-Turbo](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image-Turbo)|[code](/examples/ernie_image/model_inference/ERNIE-Image-Turbo.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image-Turbo.py)|—|—|—|—|
</details>
#### JoyAI-Image: [/docs/en/Model_Details/JoyAI-Image.md](/docs/en/Model_Details/JoyAI-Image.md)
<details>
<summary>Quick Start</summary>
Running the following code will quickly load the [jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 4GB VRAM.
```python
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
import torch
from PIL import Image
from modelscope import dataset_snapshot_download
# Download dataset
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="joyai_image/JoyAI-Image-Edit/*"
)
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = JoyAIImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config),
],
processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
# Use first sample from dataset
dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit"
prompt = "将裙子改为粉色"
edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB")
output = pipe(
prompt=prompt,
edit_image=edit_image,
height=1024,
width=1024,
seed=0,
num_inference_steps=30,
cfg_scale=5.0,
)
output.save("output_joyai_edit_low_vram.png")
```
</details>
<details>
<summary>Examples</summary>
Example code for JoyAI-Image is available at: [/examples/joyai_image/](/examples/joyai_image/)
| Model ID | Inference | Low VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit)|[code](/examples/joyai_image/model_inference/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py)|
</details>
### Video Synthesis
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
@@ -877,67 +1016,6 @@ Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/)
</details>
#### ERNIE-Image: [/docs/en/Model_Details/ERNIE-Image.md](/docs/en/Model_Details/ERNIE-Image.md)
<details>
<summary>Quick Start</summary>
Running the following code will quickly load the [PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 3GB VRAM.
```python
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = ErnieImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device='cuda',
model_configs=[
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
prompt="一只黑白相间的中华田园犬",
negative_prompt="",
height=1024,
width=1024,
seed=42,
num_inference_steps=50,
cfg_scale=4.0,
)
image.save("output.jpg")
```
</details>
<details>
<summary>Examples</summary>
Example code for ERNIE-Image is available at: [/examples/ernie_image/](/examples/ernie_image/)
| Model ID | Inference | Low VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image)|[code](/examples/ernie_image/model_inference/ERNIE-Image.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/full/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_full/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/lora/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_lora/ERNIE-Image.py)|
|[PaddlePaddle/ERNIE-Image-Turbo](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image-Turbo)|[code](/examples/ernie_image/model_inference/ERNIE-Image-Turbo.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image-Turbo.py)|—|—|—|—|
</details>
## Innovative Achievements
DiffSynth-Studio is not just an engineered model framework, but also an incubator for innovative achievements.

View File

@@ -34,6 +34,8 @@ DiffSynth 目前包括两个开源项目:
> 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 和 [mi804](https://github.com/mi804) 负责因此新功能的开发进展会比较缓慢issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。
- **2026年4月14日** JoyAI-Image 开源,欢迎加入图像编辑模型家族!支持指令引导的图像编辑推理、低显存推理和训练能力。详情请参考[文档](/docs/zh/Model_Details/JoyAI-Image.md)和[示例代码](/examples/joyai_image/)。
- **2026年3月19日** 新增对 [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) 和 [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) 模型的支持,包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/Wan.md)和[示例代码](/examples/mova/)现已可用。
- **2026年3月12日** 我们新增了 [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) 音视频生成模型的支持模型支持的功能包括文生音视频、图生音视频、IC-LoRA控制、音频生视频、音视频局部Inpainting框架支持完整的推理和训练功能。详细信息请参考 [文档](/docs/zh/Model_Details/LTX-2.md) 和 [示例代码](/examples/ltx2/)。
@@ -598,6 +600,143 @@ FLUX.1 的示例代码位于:[/examples/flux/](/examples/flux/)
</details>
#### ERNIE-Image: [/docs/zh/Model_Details/ERNIE-Image.md](/docs/zh/Model_Details/ERNIE-Image.md)
<details>
<summary>快速开始</summary>
运行以下代码可以快速加载 [PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 3G 显存即可运行。
```python
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = ErnieImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device='cuda',
model_configs=[
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
prompt="一只黑白相间的中华田园犬",
negative_prompt="",
height=1024,
width=1024,
seed=42,
num_inference_steps=50,
cfg_scale=4.0,
)
image.save("output.jpg")
```
</details>
<details>
<summary>示例代码</summary>
ERNIE-Image 的示例代码位于:[/examples/ernie_image/](/examples/ernie_image/)
| 模型 ID | 推理 | 低显存推理 | 全量训练 | 全量训练后验证 | LoRA 训练 | LoRA 训练后验证 |
|-|-|-|-|-|-|-|
|[PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image)|[code](/examples/ernie_image/model_inference/ERNIE-Image.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/full/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_full/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/lora/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_lora/ERNIE-Image.py)|
|[PaddlePaddle/ERNIE-Image-Turbo](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image-Turbo)|[code](/examples/ernie_image/model_inference/ERNIE-Image-Turbo.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image-Turbo.py)|—|—|—|—|
</details>
#### JoyAI-Image: [/docs/zh/Model_Details/JoyAI-Image.md](/docs/zh/Model_Details/JoyAI-Image.md)
<details>
<summary>快速开始</summary>
运行以下代码可以快速加载 [jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 4G 显存即可运行。
```python
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
import torch
from PIL import Image
from modelscope import dataset_snapshot_download
# Download dataset
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="joyai_image/JoyAI-Image-Edit/*"
)
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = JoyAIImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config),
],
processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
# Use first sample from dataset
dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit"
prompt = "将裙子改为粉色"
edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB")
output = pipe(
prompt=prompt,
edit_image=edit_image,
height=1024,
width=1024,
seed=0,
num_inference_steps=30,
cfg_scale=5.0,
)
output.save("output_joyai_edit_low_vram.png")
```
</details>
<details>
<summary>示例代码</summary>
JoyAI-Image 的示例代码位于:[/examples/joyai_image/](/examples/joyai_image/)
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
|[jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit)|[code](/examples/joyai_image/model_inference/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py)|
</details>
### 视频生成模型
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
@@ -877,67 +1016,6 @@ Wan 的示例代码位于:[/examples/wanvideo/](/examples/wanvideo/)
</details>
#### ERNIE-Image: [/docs/zh/Model_Details/ERNIE-Image.md](/docs/zh/Model_Details/ERNIE-Image.md)
<details>
<summary>快速开始</summary>
运行以下代码可以快速加载 [PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 3G 显存即可运行。
```python
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = ErnieImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device='cuda',
model_configs=[
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
prompt="一只黑白相间的中华田园犬",
negative_prompt="",
height=1024,
width=1024,
seed=42,
num_inference_steps=50,
cfg_scale=4.0,
)
image.save("output.jpg")
```
</details>
<details>
<summary>示例代码</summary>
ERNIE-Image 的示例代码位于:[/examples/ernie_image/](/examples/ernie_image/)
| 模型 ID | 推理 | 低显存推理 | 全量训练 | 全量训练后验证 | LoRA 训练 | LoRA 训练后验证 |
|-|-|-|-|-|-|-|
|[PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image)|[code](/examples/ernie_image/model_inference/ERNIE-Image.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/full/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_full/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/lora/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_lora/ERNIE-Image.py)|
|[PaddlePaddle/ERNIE-Image-Turbo](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image-Turbo)|[code](/examples/ernie_image/model_inference/ERNIE-Image-Turbo.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image-Turbo.py)|—|—|—|—|
</details>
## 创新成果
DiffSynth-Studio 不仅仅是一个工程化的模型框架,更是创新成果的孵化器。

View File

@@ -900,4 +900,20 @@ mova_series = [
"model_class": "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge",
},
]
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series + z_image_series + ltx2_series + anima_series + mova_series
joyai_image_series = [
{
# Example: ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth")
"model_hash": "56592ddfd7d0249d3aa527d24161a863",
"model_name": "joyai_image_dit",
"model_class": "diffsynth.models.joyai_image_dit.JoyAIImageDiT",
},
{
# Example: ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model-*.safetensors")
"model_hash": "2d11bf14bba8b4e87477c8199a895403",
"model_name": "joyai_image_text_encoder",
"model_class": "diffsynth.models.joyai_image_text_encoder.JoyAIImageTextEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.joyai_image_text_encoder.JoyAIImageTextEncoderStateDictConverter",
},
]
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series + z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series

View File

@@ -279,6 +279,22 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.ministral3.modeling_ministral3.Ministral3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.joyai_image_dit.Transformer3DModel": {
"diffsynth.models.joyai_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.joyai_image_dit.ModulateWan": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.joyai_image_text_encoder.JoyAIImageTextEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLVisionModel": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
}
def QwenImageTextEncoder_Module_Map_Updater():

View File

@@ -159,6 +159,18 @@ class FlowMatchScheduler():
timesteps[timestep_id] = timestep
return sigmas, timesteps
@staticmethod
def set_timesteps_joyai_image(num_inference_steps=100, denoising_strength=1.0, shift=None):
sigma_min = 0.0
sigma_max = 1.0
shift = 4.0 if shift is None else shift
num_train_timesteps = 1000
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@staticmethod
def set_timesteps_ltx2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None, terminal=0.1, special_case=None):
num_train_timesteps = 1000

View File

@@ -0,0 +1,636 @@
import math
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from ..core.attention import attention_forward
from ..core.gradient import gradient_checkpoint_forward
def get_timestep_embedding(
timesteps: torch.Tensor,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
) -> torch.Tensor:
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
emb = scale * emb
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
if embedding_dim % 2 == 1:
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
class Timesteps(nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
self.scale = scale
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
return get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
scale=self.scale,
)
class TimestepEmbedding(nn.Module):
def __init__(
self,
in_channels: int,
time_embed_dim: int,
act_fn: str = "silu",
out_dim: int = None,
post_act_fn: Optional[str] = None,
cond_proj_dim=None,
sample_proj_bias=True,
):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
if cond_proj_dim is not None:
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
else:
self.cond_proj = None
self.act = nn.SiLU()
time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
self.post_act = nn.SiLU() if post_act_fn == "silu" else None
def forward(self, sample, condition=None):
if condition is not None:
sample = sample + self.cond_proj(condition)
sample = self.linear_1(sample)
if self.act is not None:
sample = self.act(sample)
sample = self.linear_2(sample)
if self.post_act is not None:
sample = self.post_act(sample)
return sample
class PixArtAlphaTextProjection(nn.Module):
def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
super().__init__()
if out_features is None:
out_features = hidden_size
self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
if act_fn == "gelu_tanh":
self.act_1 = nn.GELU(approximate="tanh")
elif act_fn == "silu":
self.act_1 = nn.SiLU()
else:
self.act_1 = nn.GELU(approximate="tanh")
self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)
def forward(self, caption):
hidden_states = self.linear_1(caption)
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class GELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out, bias=bias)
self.approximate = approximate
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.proj(hidden_states)
hidden_states = F.gelu(hidden_states, approximate=self.approximate)
return hidden_states
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
mult: int = 4,
dropout: float = 0.0,
activation_fn: str = "geglu",
final_dropout: bool = False,
inner_dim=None,
bias: bool = True,
):
super().__init__()
if inner_dim is None:
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
# Build activation + projection matching diffusers pattern
if activation_fn == "gelu":
act_fn = GELU(dim, inner_dim, bias=bias)
elif activation_fn == "gelu-approximate":
act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
else:
act_fn = GELU(dim, inner_dim, bias=bias)
self.net = nn.ModuleList([])
self.net.append(act_fn)
self.net.append(nn.Dropout(dropout))
self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
if final_dropout:
self.net.append(nn.Dropout(dropout))
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
def _to_tuple(x, dim=2):
if isinstance(x, int):
return (x,) * dim
elif len(x) == dim:
return x
else:
raise ValueError(f"Expected length {dim} or int, but got {x}")
def get_meshgrid_nd(start, *args, dim=2):
if len(args) == 0:
num = _to_tuple(start, dim=dim)
start = (0,) * dim
stop = num
elif len(args) == 1:
start = _to_tuple(start, dim=dim)
stop = _to_tuple(args[0], dim=dim)
num = [stop[i] - start[i] for i in range(dim)]
elif len(args) == 2:
start = _to_tuple(start, dim=dim)
stop = _to_tuple(args[0], dim=dim)
num = _to_tuple(args[1], dim=dim)
else:
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
axis_grid = []
for i in range(dim):
a, b, n = start[i], stop[i], num[i]
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
axis_grid.append(g)
grid = torch.meshgrid(*axis_grid, indexing="ij")
grid = torch.stack(grid, dim=0)
return grid
def reshape_for_broadcast(freqs_cis, x, head_first=False):
ndim = x.ndim
assert 0 <= 1 < ndim
if isinstance(freqs_cis, tuple):
if head_first:
assert freqs_cis[0].shape == (x.shape[-2], x.shape[-1])
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
else:
assert freqs_cis[0].shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
else:
if head_first:
assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
else:
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def rotate_half(x):
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
def apply_rotary_emb(xq, xk, freqs_cis, head_first=False):
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)
cos, sin = cos.to(xq.device), sin.to(xq.device)
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
return xq_out, xk_out
def get_1d_rotary_pos_embed(dim, pos, theta=10000.0, use_real=False, theta_rescale_factor=1.0, interpolation_factor=1.0):
if isinstance(pos, int):
pos = torch.arange(pos).float()
if theta_rescale_factor != 1.0:
theta *= theta_rescale_factor ** (dim / (dim - 2))
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
freqs = torch.outer(pos * interpolation_factor, freqs)
if use_real:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1)
freqs_sin = freqs.sin().repeat_interleave(2, dim=1)
return freqs_cos, freqs_sin
else:
return torch.polar(torch.ones_like(freqs), freqs)
def get_nd_rotary_pos_embed(rope_dim_list, start, *args, theta=10000.0, use_real=False,
txt_rope_size=None, theta_rescale_factor=1.0, interpolation_factor=1.0):
grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list))
if isinstance(theta_rescale_factor, (int, float)):
theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
if isinstance(interpolation_factor, (int, float)):
interpolation_factor = [interpolation_factor] * len(rope_dim_list)
elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
embs = []
for i in range(len(rope_dim_list)):
emb = get_1d_rotary_pos_embed(
rope_dim_list[i], grid[i].reshape(-1), theta,
use_real=use_real, theta_rescale_factor=theta_rescale_factor[i],
interpolation_factor=interpolation_factor[i],
)
embs.append(emb)
if use_real:
vis_emb = (torch.cat([emb[0] for emb in embs], dim=1), torch.cat([emb[1] for emb in embs], dim=1))
else:
vis_emb = torch.cat(embs, dim=1)
if txt_rope_size is not None:
embs_txt = []
vis_max_ids = grid.view(-1).max().item()
grid_txt = torch.arange(txt_rope_size) + vis_max_ids + 1
for i in range(len(rope_dim_list)):
emb = get_1d_rotary_pos_embed(
rope_dim_list[i], grid_txt, theta,
use_real=use_real, theta_rescale_factor=theta_rescale_factor[i],
interpolation_factor=interpolation_factor[i],
)
embs_txt.append(emb)
if use_real:
txt_emb = (torch.cat([emb[0] for emb in embs_txt], dim=1), torch.cat([emb[1] for emb in embs_txt], dim=1))
else:
txt_emb = torch.cat(embs_txt, dim=1)
else:
txt_emb = None
return vis_emb, txt_emb
class ModulateWan(nn.Module):
def __init__(self, hidden_size: int, factor: int, dtype=None, device=None):
super().__init__()
self.factor = factor
factory_kwargs = {"dtype": dtype, "device": device}
self.modulate_table = nn.Parameter(
torch.zeros(1, factor, hidden_size, **factory_kwargs) / hidden_size**0.5,
requires_grad=True
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if len(x.shape) != 3:
x = x.unsqueeze(1)
return [o.squeeze(1) for o in (self.modulate_table + x).chunk(self.factor, dim=1)]
def modulate(x, shift=None, scale=None):
if scale is None and shift is None:
return x
elif shift is None:
return x * (1 + scale.unsqueeze(1))
elif scale is None:
return x + shift.unsqueeze(1)
else:
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def apply_gate(x, gate=None, tanh=False):
if gate is None:
return x
if tanh:
return x * gate.unsqueeze(1).tanh()
else:
return x * gate.unsqueeze(1)
def load_modulation(modulate_type: str, hidden_size: int, factor: int, dtype=None, device=None):
factory_kwargs = {"dtype": dtype, "device": device}
if modulate_type == 'wanx':
return ModulateWan(hidden_size, factor, **factory_kwargs)
raise ValueError(f"Unknown modulation type: {modulate_type}. Only 'wanx' is supported.")
class RMSNorm(nn.Module):
def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6, device=None, dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
if hasattr(self, "weight"):
output = output * self.weight
return output
class MMDoubleStreamBlock(nn.Module):
"""
A multimodal dit block with separate modulation for
text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206
(Flux.1): https://github.com/black-forest-labs/flux
"""
def __init__(
self,
hidden_size: int,
heads_num: int,
mlp_width_ratio: float,
mlp_act_type: str = "gelu_tanh",
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
dit_modulation_type: Optional[str] = "wanx",
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.dit_modulation_type = dit_modulation_type
self.heads_num = heads_num
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
self.img_mod = load_modulation(
modulate_type=self.dit_modulation_type,
hidden_size=hidden_size, factor=6, **factory_kwargs,
)
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True, **factory_kwargs)
self.img_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
self.img_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=True, **factory_kwargs)
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.img_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate")
self.txt_mod = load_modulation(
modulate_type=self.dit_modulation_type,
hidden_size=hidden_size, factor=6, **factory_kwargs,
)
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True, **factory_kwargs)
self.txt_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
self.txt_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=True, **factory_kwargs)
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.txt_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate")
def forward(
self,
img: torch.Tensor,
txt: torch.Tensor,
vec: torch.Tensor,
vis_freqs_cis: tuple = None,
txt_freqs_cis: tuple = None,
attn_kwargs: Optional[dict] = {},
) -> Tuple[torch.Tensor, torch.Tensor]:
(
img_mod1_shift, img_mod1_scale, img_mod1_gate,
img_mod2_shift, img_mod2_scale, img_mod2_gate,
) = self.img_mod(vec)
(
txt_mod1_shift, txt_mod1_scale, txt_mod1_gate,
txt_mod2_shift, txt_mod2_scale, txt_mod2_gate,
) = self.txt_mod(vec)
img_modulated = self.img_norm1(img)
img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale)
img_qkv = self.img_attn_qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
img_q = self.img_attn_q_norm(img_q).to(img_v)
img_k = self.img_attn_k_norm(img_k).to(img_v)
if vis_freqs_cis is not None:
img_qq, img_kk = apply_rotary_emb(img_q, img_k, vis_freqs_cis, head_first=False)
img_q, img_k = img_qq, img_kk
txt_modulated = self.txt_norm1(txt)
txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale)
txt_qkv = self.txt_attn_qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
if txt_freqs_cis is not None:
raise NotImplementedError("RoPE text is not supported for inference")
q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
v = torch.cat((img_v, txt_v), dim=1)
# Use DiffSynth unified attention
attn_out = attention_forward(
q, k, v,
q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
)
attn_out = attn_out.flatten(2, 3)
img_attn, txt_attn = attn_out[:, : img.shape[1]], attn_out[:, img.shape[1]:]
img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
img = img + apply_gate(
self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)),
gate=img_mod2_gate,
)
txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
txt = txt + apply_gate(
self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)),
gate=txt_mod2_gate,
)
return img, txt
class WanTimeTextImageEmbedding(nn.Module):
def __init__(
self,
dim: int,
time_freq_dim: int,
time_proj_dim: int,
text_embed_dim: int,
image_embed_dim: Optional[int] = None,
pos_embed_seq_len: Optional[int] = None,
):
super().__init__()
self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
self.act_fn = nn.SiLU()
self.time_proj = nn.Linear(dim, time_proj_dim)
self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
def forward(self, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor):
timestep = self.timesteps_proj(timestep)
time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
timestep = timestep.to(time_embedder_dtype)
temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
timestep_proj = self.time_proj(self.act_fn(temb))
encoder_hidden_states = self.text_embedder(encoder_hidden_states)
return temb, timestep_proj, encoder_hidden_states
class JoyAIImageDiT(nn.Module):
_supports_gradient_checkpointing = True
def __init__(
self,
patch_size: list = [1, 2, 2],
in_channels: int = 16,
out_channels: int = 16,
hidden_size: int = 4096,
heads_num: int = 32,
text_states_dim: int = 4096,
mlp_width_ratio: float = 4.0,
mm_double_blocks_depth: int = 40,
rope_dim_list: List[int] = [16, 56, 56],
rope_type: str = 'rope',
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
dit_modulation_type: str = "wanx",
theta: int = 10000,
):
super().__init__()
self.out_channels = out_channels or in_channels
self.patch_size = patch_size
self.hidden_size = hidden_size
self.heads_num = heads_num
self.rope_dim_list = rope_dim_list
self.dit_modulation_type = dit_modulation_type
self.mm_double_blocks_depth = mm_double_blocks_depth
self.rope_type = rope_type
self.theta = theta
factory_kwargs = {"device": device, "dtype": dtype}
if hidden_size % heads_num != 0:
raise ValueError(f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}")
self.img_in = nn.Conv3d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
self.condition_embedder = WanTimeTextImageEmbedding(
dim=hidden_size,
time_freq_dim=256,
time_proj_dim=hidden_size * 6,
text_embed_dim=text_states_dim,
)
self.double_blocks = nn.ModuleList([
MMDoubleStreamBlock(
self.hidden_size, self.heads_num,
mlp_width_ratio=mlp_width_ratio,
dit_modulation_type=self.dit_modulation_type,
**factory_kwargs,
)
for _ in range(mm_double_blocks_depth)
])
self.norm_out = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(hidden_size, self.out_channels * math.prod(patch_size), **factory_kwargs)
def get_rotary_pos_embed(self, vis_rope_size, txt_rope_size=None):
target_ndim = 3
if len(vis_rope_size) != target_ndim:
vis_rope_size = [1] * (target_ndim - len(vis_rope_size)) + vis_rope_size
head_dim = self.hidden_size // self.heads_num
rope_dim_list = self.rope_dim_list
if rope_dim_list is None:
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
assert sum(rope_dim_list) == head_dim
vis_freqs, txt_freqs = get_nd_rotary_pos_embed(
rope_dim_list, vis_rope_size,
txt_rope_size=txt_rope_size if self.rope_type == 'mrope' else None,
theta=self.theta, use_real=True, theta_rescale_factor=1,
)
return vis_freqs, txt_freqs
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
encoder_hidden_states_mask: torch.Tensor = None,
return_dict: bool = True,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
is_multi_item = (len(hidden_states.shape) == 6)
num_items = 0
if is_multi_item:
num_items = hidden_states.shape[1]
if num_items > 1:
assert self.patch_size[0] == 1, "For multi-item input, patch_size[0] must be 1"
hidden_states = torch.cat([hidden_states[:, -1:], hidden_states[:, :-1]], dim=1)
hidden_states = rearrange(hidden_states, 'b n c t h w -> b c (n t) h w')
batch_size, _, ot, oh, ow = hidden_states.shape
tt, th, tw = ot // self.patch_size[0], oh // self.patch_size[1], ow // self.patch_size[2]
if encoder_hidden_states_mask is None:
encoder_hidden_states_mask = torch.ones(
(encoder_hidden_states.shape[0], encoder_hidden_states.shape[1]),
dtype=torch.bool,
).to(encoder_hidden_states.device)
img = self.img_in(hidden_states).flatten(2).transpose(1, 2)
temb, vec, txt = self.condition_embedder(timestep, encoder_hidden_states)
if vec.shape[-1] > self.hidden_size:
vec = vec.unflatten(1, (6, -1))
txt_seq_len = txt.shape[1]
img_seq_len = img.shape[1]
vis_freqs_cis, txt_freqs_cis = self.get_rotary_pos_embed(
vis_rope_size=(tt, th, tw),
txt_rope_size=txt_seq_len if self.rope_type == 'mrope' else None,
)
for block in self.double_blocks:
img, txt = gradient_checkpoint_forward(
block,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
img=img, txt=txt, vec=vec,
vis_freqs_cis=vis_freqs_cis, txt_freqs_cis=txt_freqs_cis,
attn_kwargs={},
)
img_len = img.shape[1]
x = torch.cat((img, txt), 1)
img = x[:, :img_len, ...]
img = self.proj_out(self.norm_out(img))
img = self.unpatchify(img, tt, th, tw)
if is_multi_item:
img = rearrange(img, 'b c (n t) h w -> b n c t h w', n=num_items)
if num_items > 1:
img = torch.cat([img[:, 1:], img[:, :1]], dim=1)
return img
def unpatchify(self, x, t, h, w):
c = self.out_channels
pt, ph, pw = self.patch_size
assert t * h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
x = torch.einsum("nthwopqc->nctohpwq", x)
return x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))

View File

@@ -0,0 +1,82 @@
import torch
from typing import Optional
class JoyAIImageTextEncoder(torch.nn.Module):
def __init__(self):
super().__init__()
from transformers import Qwen3VLConfig, Qwen3VLForConditionalGeneration
config = Qwen3VLConfig(
text_config={
"attention_bias": False,
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 4096,
"initializer_range": 0.02,
"intermediate_size": 12288,
"max_position_embeddings": 262144,
"model_type": "qwen3_vl_text",
"num_attention_heads": 32,
"num_hidden_layers": 36,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-6,
"rope_scaling": {
"mrope_interleaved": True,
"mrope_section": [24, 20, 20],
"rope_type": "default",
},
"rope_theta": 5000000,
"use_cache": True,
"vocab_size": 151936,
},
vision_config={
"deepstack_visual_indexes": [8, 16, 24],
"depth": 27,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1152,
"in_channels": 3,
"initializer_range": 0.02,
"intermediate_size": 4304,
"model_type": "qwen3_vl",
"num_heads": 16,
"num_position_embeddings": 2304,
"out_hidden_size": 4096,
"patch_size": 16,
"spatial_merge_size": 2,
"temporal_patch_size": 2,
},
image_token_id=151655,
video_token_id=151656,
vision_start_token_id=151652,
vision_end_token_id=151653,
tie_word_embeddings=False,
)
self.model = Qwen3VLForConditionalGeneration(config)
self.config = config
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
pixel_values: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
**kwargs,
):
pre_norm_output = [None]
def hook_fn(module, args, kwargs_output=None):
pre_norm_output[0] = args[0]
self.model.model.language_model.norm.register_forward_hook(hook_fn)
_ = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
attention_mask=attention_mask,
output_hidden_states=True,
**kwargs,
)
return pre_norm_output[0]

View File

@@ -0,0 +1,282 @@
import torch
from PIL import Image
from typing import Union, Optional
from tqdm import tqdm
from einops import rearrange
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
from ..models.joyai_image_dit import JoyAIImageDiT
from ..models.joyai_image_text_encoder import JoyAIImageTextEncoder
from ..models.wan_video_vae import WanVideoVAE
class JoyAIImagePipeline(BasePipeline):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
)
self.scheduler = FlowMatchScheduler("Wan")
self.text_encoder: JoyAIImageTextEncoder = None
self.dit: JoyAIImageDiT = None
self.vae: WanVideoVAE = None
self.processor = None
self.in_iteration_models = ("dit",)
self.units = [
JoyAIImageUnit_ShapeChecker(),
JoyAIImageUnit_EditImageEmbedder(),
JoyAIImageUnit_PromptEmbedder(),
JoyAIImageUnit_NoiseInitializer(),
JoyAIImageUnit_InputImageEmbedder(),
]
self.model_fn = model_fn_joyai_image
self.compilable_models = ["dit"]
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
# Processor
processor_config: ModelConfig = None,
# Optional
vram_limit: float = None,
):
pipe = JoyAIImagePipeline(device=device, torch_dtype=torch_dtype)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
pipe.text_encoder = model_pool.fetch_model("joyai_image_text_encoder")
pipe.dit = model_pool.fetch_model("joyai_image_dit")
pipe.vae = model_pool.fetch_model("wan_video_vae")
if processor_config is not None:
processor_config.download_if_necessary()
from transformers import AutoProcessor
pipe.processor = AutoProcessor.from_pretrained(processor_config.path)
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@torch.no_grad()
def __call__(
self,
# Prompt
prompt: str,
negative_prompt: str = "",
cfg_scale: float = 5.0,
# Image
edit_image: Image.Image = None,
denoising_strength: float = 1.0,
# Shape
height: int = 1024,
width: int = 1024,
# Randomness
seed: int = None,
# Steps
max_sequence_length: int = 4096,
num_inference_steps: int = 30,
# Tiling
tiled: Optional[bool] = False,
tile_size: Optional[tuple[int, int]] = (30, 52),
tile_stride: Optional[tuple[int, int]] = (15, 26),
# Scheduler
shift: Optional[float] = 4.0,
# Progress bar
progress_bar_cmd=tqdm,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=shift)
# Parameters
inputs_posi = {"prompt": prompt}
inputs_nega = {"negative_prompt": negative_prompt}
inputs_shared = {
"cfg_scale": cfg_scale,
"edit_image": edit_image,
"denoising_strength": denoising_strength,
"height": height, "width": width,
"seed": seed, "max_sequence_length": max_sequence_length,
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
}
# Unit chain
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
unit, self, inputs_shared, inputs_posi, inputs_nega
)
# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
# Decode
self.load_models_to_device(['vae'])
latents = rearrange(inputs_shared["latents"], "b n c f h w -> (b n) c f h w")
image = self.vae.decode(latents, device=self.device)[0]
image = self.vae_output_to_image(image, pattern="C 1 H W")
self.load_models_to_device([])
return image
class JoyAIImageUnit_ShapeChecker(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width"),
output_params=("height", "width"),
)
def process(self, pipe: "JoyAIImagePipeline", height, width):
height, width = pipe.check_resize_height_width(height, width)
return {"height": height, "width": width}
class JoyAIImageUnit_PromptEmbedder(PipelineUnit):
prompt_template_encode = {
'image':
"<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
'multiple_images':
"<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n{}<|im_start|>assistant\n",
'video':
"<|im_start|>system\n \\nDescribe the video by detailing the following aspects:\n1. The main content and theme of the video.\n2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.\n3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.\n4. background environment, light, style and atmosphere.\n5. camera angles, movements, and transitions used in the video:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
}
prompt_template_encode_start_idx = {'image': 34, 'multiple_images': 34, 'video': 91}
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt", "positive": "positive"},
input_params_nega={"prompt": "negative_prompt", "positive": "positive"},
input_params=("edit_image", "max_sequence_length"),
output_params=("prompt_embeds", "prompt_embeds_mask"),
onload_model_names=("joyai_image_text_encoder",),
)
def process(self, pipe: "JoyAIImagePipeline", prompt, positive, edit_image, max_sequence_length):
pipe.load_models_to_device(self.onload_model_names)
has_image = edit_image is not None
if has_image:
prompt_embeds, prompt_embeds_mask = self._encode_with_image(pipe, prompt, edit_image, max_sequence_length)
else:
prompt_embeds, prompt_embeds_mask = self._encode_text_only(pipe, prompt, max_sequence_length)
return {"prompt_embeds": prompt_embeds, "prompt_embeds_mask": prompt_embeds_mask}
def _encode_with_image(self, pipe, prompt, edit_image, max_sequence_length):
template = self.prompt_template_encode['multiple_images']
drop_idx = self.prompt_template_encode_start_idx['multiple_images']
image_tokens = '<image>\n'
prompt = f"<|im_start|>user\n{image_tokens}{prompt}<|im_end|>\n"
prompt = prompt.replace('<image>\n', '<|vision_start|><|image_pad|><|vision_end|>')
prompt = template.format(prompt)
inputs = pipe.processor(text=[prompt], images=[edit_image], padding=True, return_tensors="pt").to(pipe.device)
last_hidden_states = pipe.text_encoder(**inputs)
prompt_embeds = last_hidden_states[:, drop_idx:]
prompt_embeds_mask = inputs['attention_mask'][:, drop_idx:]
if max_sequence_length is not None and prompt_embeds.shape[1] > max_sequence_length:
prompt_embeds = prompt_embeds[:, -max_sequence_length:, :]
prompt_embeds_mask = prompt_embeds_mask[:, -max_sequence_length:]
return prompt_embeds, prompt_embeds_mask
def _encode_text_only(self, pipe, prompt, max_sequence_length):
# TODO: may support for text-only encoding in the future.
raise NotImplementedError("Text-only encoding is not implemented yet. Please provide edit_image for now.")
return prompt_embeds, encoder_attention_mask
class JoyAIImageUnit_EditImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("edit_image", "tiled", "tile_size", "tile_stride", "height", "width"),
output_params=("ref_latents", "num_items", "is_multi_item"),
onload_model_names=("wan_video_vae",),
)
def process(self, pipe: "JoyAIImagePipeline", edit_image, tiled, tile_size, tile_stride, height, width):
if edit_image is None:
return {}
pipe.load_models_to_device(self.onload_model_names)
# Resize edit image to match target dimensions (from ShapeChecker) to ensure ref_latents matches latents
edit_image = edit_image.resize((width, height), Image.LANCZOS)
images = [pipe.preprocess_image(edit_image).transpose(0, 1)]
latents = pipe.vae.encode(images, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
ref_vae = rearrange(latents, "(b n) c 1 h w -> b n c 1 h w", n=1).to(device=pipe.device, dtype=pipe.torch_dtype)
return {"ref_latents": ref_vae, "edit_image": edit_image}
class JoyAIImageUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("seed", "height", "width", "rand_device"),
output_params=("noise"),
)
def process(self, pipe: "JoyAIImagePipeline", seed, height, width, rand_device):
latent_h = height // pipe.vae.upsampling_factor
latent_w = width // pipe.vae.upsampling_factor
shape = (1, 1, pipe.vae.z_dim, 1, latent_h, latent_w)
noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
return {"noise": noise}
class JoyAIImageUnit_InputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",),
)
def process(self, pipe: JoyAIImagePipeline, input_image, noise, tiled, tile_size, tile_stride):
if input_image is None:
return {"latents": noise}
pipe.load_models_to_device(self.onload_model_names)
if isinstance(input_image, Image.Image):
input_image = [input_image]
input_image = [pipe.preprocess_image(img).transpose(0, 1) for img in input_image]
latents = pipe.vae.encode(input_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
input_latents = rearrange(latents, "(b n) c 1 h w -> b n c 1 h w", n=(len(input_image)))
return {"latents": noise, "input_latents": input_latents}
def model_fn_joyai_image(
dit,
latents,
timestep,
prompt_embeds,
prompt_embeds_mask,
ref_latents=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
img = torch.cat([ref_latents, latents], dim=1) if ref_latents is not None else latents
img = dit(
hidden_states=img,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
encoder_hidden_states_mask=prompt_embeds_mask,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
img = img[:, -latents.size(1):]
return img

View File

@@ -0,0 +1,20 @@
def JoyAIImageTextEncoderStateDictConverter(state_dict):
"""Convert HuggingFace Qwen3VL checkpoint keys to DiffSynth wrapper keys.
Mapping (checkpoint -> wrapper):
- lm_head.weight -> model.lm_head.weight
- model.language_model.* -> model.model.language_model.*
- model.visual.* -> model.model.visual.*
"""
state_dict_ = {}
for key in state_dict:
if key == "lm_head.weight":
new_key = "model.lm_head.weight"
elif key.startswith("model.language_model."):
new_key = "model.model." + key[len("model."):]
elif key.startswith("model.visual."):
new_key = "model.model." + key[len("model."):]
else:
new_key = key
state_dict_[new_key] = state_dict[key]
return state_dict_

View File

@@ -0,0 +1,154 @@
# JoyAI-Image
JoyAI-Image is a unified multi-modal foundation model open-sourced by JD.com, supporting image understanding, text-to-image generation, and instruction-guided image editing.
## Installation
Before performing model inference and training, please install DiffSynth-Studio first.
```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
For more information on installation, please refer to [Setup Dependencies](../Pipeline_Usage/Setup.md).
## Quick Start
Running the following code will load the [jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit) model for inference. VRAM management is enabled, the framework automatically controls parameter loading based on available VRAM, requiring a minimum of 4GB VRAM.
```python
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
import torch
from PIL import Image
from modelscope import dataset_snapshot_download
# Download dataset
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="joyai_image/JoyAI-Image-Edit/*"
)
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = JoyAIImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config),
],
processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
# Use first sample from dataset
dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit"
prompt = "将裙子改为粉色"
edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB")
output = pipe(
prompt=prompt,
edit_image=edit_image,
height=1024,
width=1024,
seed=0,
num_inference_steps=30,
cfg_scale=5.0,
)
output.save("output_joyai_edit_low_vram.png")
```
## Model Overview
|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|-|-|-|-|-|-|-|
|[jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_inference/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py)|
## Model Inference
The model is loaded via `JoyAIImagePipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models) for details.
The input parameters for `JoyAIImagePipeline` inference include:
* `prompt`: Text prompt describing the desired image editing effect.
* `negative_prompt`: Negative prompt specifying what should not appear in the result, defaults to empty string.
* `cfg_scale`: Classifier-free guidance scale factor, defaults to 5.0. Higher values make the output more closely follow the prompt.
* `edit_image`: Image to be edited.
* `denoising_strength`: Denoising strength controlling how much the input image is repainted, defaults to 1.0.
* `height`: Height of the output image, defaults to 1024. Must be divisible by 16.
* `width`: Width of the output image, defaults to 1024. Must be divisible by 16.
* `seed`: Random seed for reproducibility. Set to `None` for random seed.
* `max_sequence_length`: Maximum sequence length for the text encoder, defaults to 4096.
* `num_inference_steps`: Number of inference steps, defaults to 30. More steps typically yield better quality.
* `tiled`: Whether to enable tiling for reduced VRAM usage, defaults to False.
* `tile_size`: Tile size, defaults to (30, 52).
* `tile_stride`: Tile stride, defaults to (15, 26).
* `shift`: Shift parameter for the scheduler, controlling the Flow Match scheduling curve, defaults to 4.0.
* `progress_bar_cmd`: Progress bar display mode, defaults to tqdm.
## Model Training
Models in the joyai_image series are trained uniformly via `examples/joyai_image/model_training/train.py`. The script parameters include:
* General Training Parameters
* Dataset Configuration
* `--dataset_base_path`: Root directory of the dataset.
* `--dataset_metadata_path`: Path to the dataset metadata file.
* `--dataset_repeat`: Number of dataset repeats per epoch.
* `--dataset_num_workers`: Number of processes per DataLoader.
* `--data_file_keys`: Field names to load from metadata, typically paths to image or video files, separated by `,`.
* Model Loading Configuration
* `--model_paths`: Paths to load models from, in JSON format.
* `--model_id_with_origin_paths`: Model IDs with original paths, separated by commas.
* `--extra_inputs`: Additional input parameters required by the model Pipeline, separated by `,`.
* `--fp8_models`: Models to load in FP8 format, currently only supported for models whose parameters are not updated by gradients.
* Basic Training Configuration
* `--learning_rate`: Learning rate.
* `--num_epochs`: Number of epochs.
* `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.
* `--find_unused_parameters`: Whether unused parameters exist in DDP training.
* `--weight_decay`: Weight decay magnitude.
* `--task`: Training task, defaults to `sft`.
* Output Configuration
* `--output_path`: Path to save the model.
* `--remove_prefix_in_ckpt`: Remove prefix in the model's state dict.
* `--save_steps`: Interval in training steps to save the model.
* LoRA Configuration
* `--lora_base_model`: Which model to add LoRA to.
* `--lora_target_modules`: Which layers to add LoRA to.
* `--lora_rank`: Rank of LoRA.
* `--lora_checkpoint`: Path to LoRA checkpoint.
* `--preset_lora_path`: Path to preset LoRA checkpoint for LoRA differential training.
* `--preset_lora_model`: Which model to integrate preset LoRA into, e.g., `dit`.
* Gradient Configuration
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
* Resolution Configuration
* `--height`: Height of the image/video. Leave empty to enable dynamic resolution.
* `--width`: Width of the image/video. Leave empty to enable dynamic resolution.
* `--max_pixels`: Maximum pixel area, images larger than this will be scaled down during dynamic resolution.
* `--num_frames`: Number of frames for video (video generation models only).
* JoyAI-Image Specific Parameters
* `--processor_path`: Path to the processor for processing text and image encoder inputs.
* `--initialize_model_on_cpu`: Whether to initialize models on CPU. By default, models are initialized on the accelerator device.
```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
```
We provide recommended training scripts for each model, please refer to the table in "Model Overview" above. For guidance on writing model training scripts, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see [Training Framework Overview](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).

View File

@@ -31,6 +31,7 @@ Welcome to DiffSynth-Studio's Documentation
Model_Details/Anima
Model_Details/LTX-2
Model_Details/ERNIE-Image
Model_Details/JoyAI-Image
.. toctree::
:maxdepth: 2

View File

@@ -0,0 +1,154 @@
# JoyAI-Image
JoyAI-Image 是京东开源的统一多模态基础模型,支持图像理解、文生图生成和指令引导的图像编辑。
## 安装
在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。
```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
更多关于安装的信息,请参考[安装依赖](../Pipeline_Usage/Setup.md)。
## 快速开始
运行以下代码可以快速加载 [jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 4G 显存即可运行。
```python
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
import torch
from PIL import Image
from modelscope import dataset_snapshot_download
# Download dataset
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="joyai_image/JoyAI-Image-Edit/*"
)
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = JoyAIImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config),
],
processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
# Use first sample from dataset
dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit"
prompt = "将裙子改为粉色"
edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB")
output = pipe(
prompt=prompt,
edit_image=edit_image,
height=1024,
width=1024,
seed=0,
num_inference_steps=30,
cfg_scale=5.0,
)
output.save("output_joyai_edit_low_vram.png")
```
## 模型总览
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
|[jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_inference/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py)|
## 模型推理
模型通过 `JoyAIImagePipeline.from_pretrained` 加载,详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。
`JoyAIImagePipeline` 推理的输入参数包括:
* `prompt`: 文本提示词,用于描述期望的图像编辑效果。
* `negative_prompt`: 负向提示词,指定不希望出现在结果中的内容,默认为空字符串。
* `cfg_scale`: 分类器自由引导的缩放系数,默认为 5.0。值越大,生成结果越贴近 prompt 描述。
* `edit_image`: 待编辑的单张图像。
* `denoising_strength`: 降噪强度,控制输入图像被重绘的程度,默认为 1.0。
* `height`: 输出图像的高度,默认为 1024。需能被 16 整除。
* `width`: 输出图像的宽度,默认为 1024。需能被 16 整除。
* `seed`: 随机种子,用于控制生成的可复现性。设为 `None` 时使用随机种子。
* `max_sequence_length`: 文本编码器处理的最大序列长度,默认为 4096。
* `num_inference_steps`: 推理步数,默认为 30。步数越多生成质量通常越好。
* `tiled`: 是否启用分块处理,用于降低显存占用,默认为 False。
* `tile_size`: 分块大小,默认为 (30, 52)。
* `tile_stride`: 分块步幅,默认为 (15, 26)。
* `shift`: 调度器的 shift 参数,用于控制 Flow Match 的调度曲线,默认为 4.0。
* `progress_bar_cmd`: 进度条显示方式,默认为 tqdm。
## 模型训练
joyai_image 系列模型统一通过 `examples/joyai_image/model_training/train.py` 进行训练,脚本的参数包括:
* 通用训练参数
* 数据集基础配置
* `--dataset_base_path`: 数据集的根目录。
* `--dataset_metadata_path`: 数据集的元数据文件路径。
* `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
* `--dataset_num_workers`: 每个 Dataloader 的进程数量。
* `--data_file_keys`: 元数据中需要加载的字段名称,通常是图像或视频文件的路径,以 `,` 分隔。
* 模型加载配置
* `--model_paths`: 要加载的模型路径。JSON 格式。
* `--model_id_with_origin_paths`: 带原始路径的模型 ID。用逗号分隔。
* `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,以 `,` 分隔。
* `--fp8_models`: 以 FP8 格式加载的模型,目前仅支持参数不被梯度更新的模型。
* 训练基础配置
* `--learning_rate`: 学习率。
* `--num_epochs`: 轮数Epoch
* `--trainable_models`: 可训练的模型,例如 `dit``vae``text_encoder`
* `--find_unused_parameters`: DDP 训练中是否存在未使用的参数。
* `--weight_decay`: 权重衰减大小。
* `--task`: 训练任务,默认为 `sft`
* 输出配置
* `--output_path`: 模型保存路径。
* `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。
* `--save_steps`: 保存模型的训练步数间隔。
* LoRA 配置
* `--lora_base_model`: LoRA 添加到哪个模型上。
* `--lora_target_modules`: LoRA 添加到哪些层上。
* `--lora_rank`: LoRA 的秩Rank
* `--lora_checkpoint`: LoRA 检查点的路径。
* `--preset_lora_path`: 预置 LoRA 检查点路径,用于 LoRA 差分训练。
* `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`
* 梯度配置
* `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。
* `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。
* `--gradient_accumulation_steps`: 梯度累积步数。
* 分辨率配置
* `--height`: 图像/视频的高度。留空启用动态分辨率。
* `--width`: 图像/视频的宽度。留空启用动态分辨率。
* `--max_pixels`: 最大像素面积,动态分辨率时大于此值的图片会被缩小。
* `--num_frames`: 视频的帧数(仅视频生成模型)。
* JoyAI-Image 专有参数
* `--processor_path`: Processor 路径,用于处理文本和图像的编码器输入。
* `--initialize_model_on_cpu`: 是否在 CPU 上初始化模型,默认在加速设备上初始化。
```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
```
关于如何编写模型训练脚本,请参考[模型训练](../Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。

View File

@@ -31,6 +31,7 @@
Model_Details/Anima
Model_Details/LTX-2
Model_Details/ERNIE-Image
Model_Details/JoyAI-Image
.. toctree::
:maxdepth: 2

View File

@@ -0,0 +1,39 @@
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
import torch
from PIL import Image
from modelscope import dataset_snapshot_download
# Download dataset
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="joyai_image/JoyAI-Image-Edit/*"
)
pipe = JoyAIImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth"),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors"),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth"),
],
processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
)
# Use first sample from dataset
dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit"
prompt = "将裙子改为粉色"
edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB")
output = pipe(
prompt=prompt,
edit_image=edit_image,
height=1024,
width=1024,
seed=1,
num_inference_steps=30,
cfg_scale=5.0,
)
output.save("output_joyai_edit.png")

View File

@@ -0,0 +1,51 @@
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
import torch
from PIL import Image
from modelscope import dataset_snapshot_download
# Download dataset
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="joyai_image/JoyAI-Image-Edit/*"
)
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = JoyAIImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config),
],
processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
# Use first sample from dataset
dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit"
prompt = "将裙子改为粉色"
edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB")
output = pipe(
prompt=prompt,
edit_image=edit_image,
height=1024,
width=1024,
seed=0,
num_inference_steps=30,
cfg_scale=5.0,
)
output.save("output_joyai_edit_low_vram.png")

View File

@@ -0,0 +1,35 @@
# Dataset: data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "joyai_image/JoyAI-Image-Edit/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/joyai_image/model_training/train.py \
--dataset_base_path "./data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit" \
--dataset_metadata_path "./data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/metadata.csv" \
--max_pixels 1048576 \
--dataset_repeat 1 \
--model_id_with_origin_paths "jd-opensource/JoyAI-Image-Edit:JoyAI-Image-Und/model*.safetensors,jd-opensource/JoyAI-Image-Edit:vae/Wan2.1_VAE.pth" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/JoyAI-Image-Edit-full-cache" \
--use_gradient_checkpointing \
--find_unused_parameters \
--data_file_keys "image,edit_image" \
--extra_inputs "edit_image" \
--task "sft:data_process"
accelerate launch --config_file examples/joyai_image/model_training/full/accelerate_config_zero3.yaml \
examples/joyai_image/model_training/train.py \
--dataset_base_path "./models/train/JoyAI-Image-Edit-full-cache" \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "jd-opensource/JoyAI-Image-Edit:transformer/transformer.pth" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/JoyAI-Image-Edit-full" \
--trainable_models "dit" \
--use_gradient_checkpointing \
--find_unused_parameters \
--data_file_keys "image,edit_image" \
--extra_inputs "edit_image" \
--task "sft:train"

View File

@@ -0,0 +1,23 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
gradient_accumulation_steps: 1
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: true
zero3_save_16bit_model: true
zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@@ -0,0 +1,39 @@
# Dataset: data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "joyai_image/JoyAI-Image-Edit/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/joyai_image/model_training/train.py \
--dataset_base_path "./data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit" \
--dataset_metadata_path "./data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/metadata.csv" \
--max_pixels 1048576 \
--dataset_repeat 1 \
--model_id_with_origin_paths "jd-opensource/JoyAI-Image-Edit:JoyAI-Image-Und/model*.safetensors,jd-opensource/JoyAI-Image-Edit:vae/Wan2.1_VAE.pth" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/JoyAI-Image-Edit-split-cache" \
--lora_base_model "dit" \
--lora_target_modules "img_attn_qkv,txt_attn_qkv,img_attn_proj,txt_attn_proj" \
--lora_rank 32 \
--use_gradient_checkpointing \
--find_unused_parameters \
--data_file_keys "image,edit_image" \
--extra_inputs "edit_image" \
--task "sft:data_process"
accelerate launch examples/joyai_image/model_training/train.py \
--dataset_base_path "./models/train/JoyAI-Image-Edit-split-cache" \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "jd-opensource/JoyAI-Image-Edit:transformer/transformer.pth" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/JoyAI-Image-Edit-lora" \
--lora_base_model "dit" \
--lora_target_modules "img_attn_qkv,txt_attn_qkv,img_attn_proj,txt_attn_proj" \
--lora_rank 32 \
--use_gradient_checkpointing \
--find_unused_parameters \
--data_file_keys "image,edit_image" \
--extra_inputs "edit_image" \
--task "sft:train"

View File

@@ -0,0 +1,138 @@
import torch, os, argparse, accelerate
from diffsynth.core import UnifiedDataset
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
from diffsynth.diffusion import *
from diffsynth.core.data.operators import *
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class JoyAIImageTrainingModule(DiffusionTrainingModule):
def __init__(
self,
model_paths=None, model_id_with_origin_paths=None,
processor_path=None,
trainable_models=None,
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
preset_lora_path=None, preset_lora_model=None,
use_gradient_checkpointing=True,
use_gradient_checkpointing_offload=False,
extra_inputs=None,
fp8_models=None,
offload_models=None,
device="cpu",
task="sft",
):
super().__init__()
# Load models
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
processor_config = ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/") if processor_path is None else ModelConfig(processor_path)
self.pipe = JoyAIImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, processor_config=processor_config)
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
# Training mode
self.switch_pipe_to_training_mode(
self.pipe, trainable_models,
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
preset_lora_path, preset_lora_model,
task=task,
)
# Other configs
self.use_gradient_checkpointing = use_gradient_checkpointing
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
self.fp8_models = fp8_models
self.task = task
self.task_to_loss = {
"sft:data_process": lambda pipe, *args: args,
"sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
"sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
}
def get_pipeline_inputs(self, data):
inputs_posi = {"prompt": data["prompt"]}
inputs_nega = {"negative_prompt": ""}
inputs_shared = {
# Assume you are using this pipeline for inference,
# please fill in the input parameters.
"input_image": data["image"],
"height": data["image"].size[1],
"width": data["image"].size[0],
# Please do not modify the following parameters
# unless you clearly know what this will cause.
"cfg_scale": 1,
"rand_device": self.pipe.device,
"use_gradient_checkpointing": self.use_gradient_checkpointing,
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
}
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
return inputs_shared, inputs_posi, inputs_nega
def forward(self, data, inputs=None):
if inputs is None: inputs = self.get_pipeline_inputs(data)
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
for unit in self.pipe.units:
inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)
loss = self.task_to_loss[self.task](self.pipe, *inputs)
return loss
def joyai_image_parser():
parser = argparse.ArgumentParser(description="JoyAI-Image training.")
parser = add_general_config(parser)
parser = add_image_size_config(parser)
parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor.")
parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.")
return parser
if __name__ == "__main__":
parser = joyai_image_parser()
args = parser.parse_args()
accelerator = accelerate.Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
)
dataset = UnifiedDataset(
base_path=args.dataset_base_path,
metadata_path=args.dataset_metadata_path,
repeat=args.dataset_repeat,
data_file_keys=args.data_file_keys.split(","),
main_data_operator=UnifiedDataset.default_image_operator(
base_path=args.dataset_base_path,
max_pixels=args.max_pixels,
height=args.height,
width=args.width,
height_division_factor=16,
width_division_factor=16,
),
)
model = JoyAIImageTrainingModule(
model_paths=args.model_paths,
model_id_with_origin_paths=args.model_id_with_origin_paths,
processor_path=args.processor_path,
trainable_models=args.trainable_models,
lora_base_model=args.lora_base_model,
lora_target_modules=args.lora_target_modules,
lora_rank=args.lora_rank,
lora_checkpoint=args.lora_checkpoint,
preset_lora_path=args.preset_lora_path,
preset_lora_model=args.preset_lora_model,
use_gradient_checkpointing=args.use_gradient_checkpointing,
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
extra_inputs=args.extra_inputs,
fp8_models=args.fp8_models,
offload_models=args.offload_models,
task=args.task,
device="cpu" if args.initialize_model_on_cpu else accelerator.device,
)
model_logger = ModelLogger(
args.output_path,
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
)
launcher_map = {
"sft:data_process": launch_data_process_task,
"sft": launch_training_task,
"sft:train": launch_training_task,
}
launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)

View File

@@ -0,0 +1,32 @@
import torch
from PIL import Image
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
from diffsynth import load_state_dict
pipe = JoyAIImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth"),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors"),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth"),
],
processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
)
state_dict = load_state_dict("models/train/JoyAI-Image-Edit_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
prompt = "将裙子改为粉色"
edit_image = Image.open("data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/edit/image1.jpg").convert("RGB")
image = pipe(
prompt=prompt,
edit_image=edit_image,
height=1024,
width=1024,
seed=0,
num_inference_steps=50,
cfg_scale=5.0,
)
image.save("image_full.jpg")

View File

@@ -0,0 +1,30 @@
import torch
from PIL import Image
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
pipe = JoyAIImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth"),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors"),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth"),
],
processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
)
pipe.load_lora(pipe.dit, "models/train/JoyAI-Image-Edit-lora/epoch-4.safetensors")
prompt = "将裙子改为粉色"
edit_image = Image.open("data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/edit/image1.jpg").convert("RGB")
image = pipe(
prompt=prompt,
edit_image=edit_image,
height=1024,
width=1024,
seed=0,
num_inference_steps=30,
cfg_scale=5.0,
)
image.save("image_lora.jpg")