Compare commits

..

2 Commits
webui ... main

Author SHA1 Message Date
Hong Zhang
960d8c62c0 Support ERNIE-Image (#1389)
* ernie-image pipeline

* ernie-image inference and training

* style fix

* ernie docs

* lowvram

* final style fix

* pr-review

* pr-fix round2

* set uniform training weight

* fix

* update lowvram docs
2026-04-13 14:57:10 +08:00
Zhongjie Duan
f77b6357c5 add Discord link in README (#1390)
* add discord link

* add discord link
2026-04-13 14:50:48 +08:00
25 changed files with 1483 additions and 293 deletions

View File

@@ -7,6 +7,7 @@
[![open issues](https://isitmaintained.com/badge/open/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/issues) [![open issues](https://isitmaintained.com/badge/open/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/issues)
[![GitHub pull-requests](https://img.shields.io/github/issues-pr/modelscope/DiffSynth-Studio.svg)](https://GitHub.com/modelscope/DiffSynth-Studio/pull/) [![GitHub pull-requests](https://img.shields.io/github/issues-pr/modelscope/DiffSynth-Studio.svg)](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
[![GitHub latest commit](https://badgen.net/github/last-commit/modelscope/DiffSynth-Studio)](https://GitHub.com/modelscope/DiffSynth-Studio/commit/) [![GitHub latest commit](https://badgen.net/github/last-commit/modelscope/DiffSynth-Studio)](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
[![Discord](https://badgen.net//discord/members/Mm9suEeUDc)](https://discord.gg/Mm9suEeUDc)
[切换到中文版](./README_zh.md) [切换到中文版](./README_zh.md)
@@ -32,6 +33,7 @@ We believe that a well-developed open-source code framework can lower the thresh
> DiffSynth-Studio has undergone major version updates, and some old features are no longer maintained. If you need to use old features, please switch to the [last historical version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3) before the major version update. > DiffSynth-Studio has undergone major version updates, and some old features are no longer maintained. If you need to use old features, please switch to the [last historical version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3) before the major version update.
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher) and [mi804](https://github.com/mi804). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand. > Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher) and [mi804](https://github.com/mi804). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
- **March 19, 2026**: Added support for [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) and [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/Wan.md) and [example code](/examples/mova/) are now available. - **March 19, 2026**: Added support for [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) and [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/Wan.md) and [example code](/examples/mova/) are now available.
- **March 12, 2026**: We have added support for the [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) audio-video generation model. The features includes text-to-audio/video, image-to-audio/video, IC-LoRA control, audio-to-video, and audio-video inpainting. We have supported the complete inference and training functionalities. For details, please refer to the [documentation](/docs/en/Model_Details/LTX-2.md) and [code](/examples/ltx2/). - **March 12, 2026**: We have added support for the [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) audio-video generation model. The features includes text-to-audio/video, image-to-audio/video, IC-LoRA control, audio-to-video, and audio-video inpainting. We have supported the complete inference and training functionalities. For details, please refer to the [documentation](/docs/en/Model_Details/LTX-2.md) and [code](/examples/ltx2/).
@@ -875,6 +877,66 @@ Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/)
</details> </details>
#### ERNIE-Image: [/docs/en/Model_Details/ERNIE-Image.md](/docs/en/Model_Details/ERNIE-Image.md)
<details>
<summary>Quick Start</summary>
Running the following code will quickly load the [baidu/ERNIE-Image](https://www.modelscope.cn/models/baidu/ERNIE-Image) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 3GB VRAM.
```python
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = ErnieImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device='cuda',
model_configs=[
ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
prompt="一只黑白相间的中华田园犬",
negative_prompt="",
height=1024,
width=1024,
seed=42,
num_inference_steps=50,
cfg_scale=4.0,
)
image.save("output.jpg")
```
</details>
<details>
<summary>Examples</summary>
Example code for ERNIE-Image is available at: [/examples/ernie_image/](/examples/ernie_image/)
| Model ID | Inference | Low VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[baidu/ERNIE-Image: T2I](https://www.modelscope.cn/models/baidu/ERNIE-Image)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference_low_vram/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/full/Ernie-Image-T2I.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_full/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/lora/Ernie-Image-T2I.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_lora/Ernie-Image-T2I.py)|
</details>
## Innovative Achievements ## Innovative Achievements
DiffSynth-Studio is not just an engineered model framework, but also an incubator for innovative achievements. DiffSynth-Studio is not just an engineered model framework, but also an incubator for innovative achievements.
@@ -1029,3 +1091,9 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-47
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
</details> </details>
## Contact Us
|Discordhttps://discord.gg/Mm9suEeUDc|
|-|
|<img width="160" height="160" alt="Image" src="https://github.com/user-attachments/assets/29bdc97b-e35d-4fea-88d6-32e35182e458" />|

View File

@@ -7,6 +7,7 @@
[![open issues](https://isitmaintained.com/badge/open/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/issues) [![open issues](https://isitmaintained.com/badge/open/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/issues)
[![GitHub pull-requests](https://img.shields.io/github/issues-pr/modelscope/DiffSynth-Studio.svg)](https://GitHub.com/modelscope/DiffSynth-Studio/pull/) [![GitHub pull-requests](https://img.shields.io/github/issues-pr/modelscope/DiffSynth-Studio.svg)](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
[![GitHub latest commit](https://badgen.net/github/last-commit/modelscope/DiffSynth-Studio)](https://GitHub.com/modelscope/DiffSynth-Studio/commit/) [![GitHub latest commit](https://badgen.net/github/last-commit/modelscope/DiffSynth-Studio)](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
[![Discord](https://badgen.net//discord/members/Mm9suEeUDc)](https://discord.gg/Mm9suEeUDc)
[Switch to English](./README.md) [Switch to English](./README.md)
@@ -876,6 +877,66 @@ Wan 的示例代码位于:[/examples/wanvideo/](/examples/wanvideo/)
</details> </details>
#### ERNIE-Image: [/docs/zh/Model_Details/ERNIE-Image.md](/docs/zh/Model_Details/ERNIE-Image.md)
<details>
<summary>快速开始</summary>
运行以下代码可以快速加载 [baidu/ERNIE-Image](https://www.modelscope.cn/models/baidu/ERNIE-Image) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 3G 显存即可运行。
```python
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = ErnieImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device='cuda',
model_configs=[
ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
prompt="一只黑白相间的中华田园犬",
negative_prompt="",
height=1024,
width=1024,
seed=42,
num_inference_steps=50,
cfg_scale=4.0,
)
image.save("output.jpg")
```
</details>
<details>
<summary>示例代码</summary>
ERNIE-Image 的示例代码位于:[/examples/ernie_image/](/examples/ernie_image/)
| 模型 ID | 推理 | 低显存推理 | 全量训练 | 全量训练后验证 | LoRA 训练 | LoRA 训练后验证 |
|-|-|-|-|-|-|-|
|[baidu/ERNIE-Image: T2I](https://www.modelscope.cn/models/baidu/ERNIE-Image)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference_low_vram/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/full/Ernie-Image-T2I.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_full/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/lora/Ernie-Image-T2I.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_lora/Ernie-Image-T2I.py)|
</details>
## 创新成果 ## 创新成果
DiffSynth-Studio 不仅仅是一个工程化的模型框架,更是创新成果的孵化器。 DiffSynth-Studio 不仅仅是一个工程化的模型框架,更是创新成果的孵化器。
@@ -1032,3 +1093,9 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-47
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
</details> </details>
## 联系我们
|Discordhttps://discord.gg/Mm9suEeUDc|
|-|
|<img width="160" height="160" alt="Image" src="https://github.com/user-attachments/assets/29bdc97b-e35d-4fea-88d6-32e35182e458" />|

View File

@@ -541,6 +541,22 @@ flux2_series = [
}, },
] ]
ernie_image_series = [
{
# Example: ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
"model_hash": "584c13713849f1af4e03d5f1858b8b7b",
"model_name": "ernie_image_dit",
"model_class": "diffsynth.models.ernie_image_dit.ErnieImageDiT",
},
{
# Example: ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors")
"model_hash": "404ed9f40796a38dd34c1620f1920207",
"model_name": "ernie_image_text_encoder",
"model_class": "diffsynth.models.ernie_image_text_encoder.ErnieImageTextEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ernie_image_text_encoder.ErnieImageTextEncoderStateDictConverter",
},
]
z_image_series = [ z_image_series = [
{ {
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors") # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors")
@@ -884,4 +900,4 @@ mova_series = [
"model_class": "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge", "model_class": "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge",
}, },
] ]
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series + anima_series + mova_series MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series + z_image_series + ltx2_series + anima_series + mova_series

View File

@@ -267,6 +267,18 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule", "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule", "torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
}, },
"diffsynth.models.ernie_image_dit.ErnieImageDiT": {
"diffsynth.models.ernie_image_dit.ErnieImageRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ernie_image_text_encoder.ErnieImageTextEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.ministral3.modeling_ministral3.Ministral3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
} }
def QwenImageTextEncoder_Module_Map_Updater(): def QwenImageTextEncoder_Module_Map_Updater():

View File

@@ -4,7 +4,7 @@ from typing_extensions import Literal
class FlowMatchScheduler(): class FlowMatchScheduler():
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning"] = "FLUX.1"): def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning", "ERNIE-Image"] = "FLUX.1"):
self.set_timesteps_fn = { self.set_timesteps_fn = {
"FLUX.1": FlowMatchScheduler.set_timesteps_flux, "FLUX.1": FlowMatchScheduler.set_timesteps_flux,
"Wan": FlowMatchScheduler.set_timesteps_wan, "Wan": FlowMatchScheduler.set_timesteps_wan,
@@ -13,6 +13,7 @@ class FlowMatchScheduler():
"Z-Image": FlowMatchScheduler.set_timesteps_z_image, "Z-Image": FlowMatchScheduler.set_timesteps_z_image,
"LTX-2": FlowMatchScheduler.set_timesteps_ltx2, "LTX-2": FlowMatchScheduler.set_timesteps_ltx2,
"Qwen-Image-Lightning": FlowMatchScheduler.set_timesteps_qwen_image_lightning, "Qwen-Image-Lightning": FlowMatchScheduler.set_timesteps_qwen_image_lightning,
"ERNIE-Image": FlowMatchScheduler.set_timesteps_ernie_image,
}.get(template, FlowMatchScheduler.set_timesteps_flux) }.get(template, FlowMatchScheduler.set_timesteps_flux)
self.num_train_timesteps = 1000 self.num_train_timesteps = 1000
@@ -129,6 +130,15 @@ class FlowMatchScheduler():
timesteps = sigmas * num_train_timesteps timesteps = sigmas * num_train_timesteps
return sigmas, timesteps return sigmas, timesteps
@staticmethod
def set_timesteps_ernie_image(num_inference_steps=50, denoising_strength=1.0):
"""ERNIE-Image scheduler: pure linear sigmas from 1.0 to 0.0, no shift."""
num_train_timesteps = 1000
sigma_start = denoising_strength
sigmas = torch.linspace(sigma_start, 0.0, num_inference_steps + 1)[:-1]
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@staticmethod @staticmethod
def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None): def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None):
sigma_min = 0.0 sigma_min = 0.0
@@ -175,6 +185,9 @@ class FlowMatchScheduler():
return sigmas, timesteps return sigmas, timesteps
def set_training_weight(self): def set_training_weight(self):
if self.set_timesteps_fn == FlowMatchScheduler.set_timesteps_ernie_image:
self.set_uniform_training_weight()
return
steps = 1000 steps = 1000
x = self.timesteps x = self.timesteps
y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2) y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
@@ -185,6 +198,13 @@ class FlowMatchScheduler():
bsmntw_weighing = bsmntw_weighing * (len(self.timesteps) / steps) bsmntw_weighing = bsmntw_weighing * (len(self.timesteps) / steps)
bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1] bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1]
self.linear_timesteps_weights = bsmntw_weighing self.linear_timesteps_weights = bsmntw_weighing
def set_uniform_training_weight(self):
"""Assign equal weight to every timestep, suitable for linear schedulers like ERNIE-Image."""
steps = 1000
num_steps = len(self.timesteps)
uniform_weight = torch.full((num_steps,), steps / num_steps, dtype=self.timesteps.dtype)
self.linear_timesteps_weights = uniform_weight
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs): def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs):
self.sigmas, self.timesteps = self.set_timesteps_fn( self.sigmas, self.timesteps = self.set_timesteps_fn(

View File

@@ -1,4 +1,4 @@
from transformers import DINOv3ViTModel, DINOv3ViTImageProcessor from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast
from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
import torch import torch
@@ -40,7 +40,7 @@ class DINOv3ImageEncoder(DINOv3ViTModel):
value_bias = False value_bias = False
) )
super().__init__(config) super().__init__(config)
self.processor = DINOv3ViTImageProcessor( self.processor = DINOv3ViTImageProcessorFast(
crop_size = None, crop_size = None,
data_format = "channels_first", data_format = "channels_first",
default_to_square = True, default_to_square = True,
@@ -56,7 +56,7 @@ class DINOv3ImageEncoder(DINOv3ViTModel):
0.456, 0.456,
0.406 0.406
], ],
image_processor_type = "DINOv3ViTImageProcessor", image_processor_type = "DINOv3ViTImageProcessorFast",
image_std = [ image_std = [
0.229, 0.229,
0.224, 0.224,

View File

@@ -0,0 +1,362 @@
"""
Ernie-Image DiT for DiffSynth-Studio.
Refactored from diffusers ErnieImageTransformer2DModel to use DiffSynth core modules.
Default parameters from actual checkpoint config.json (baidu/ERNIE-Image transformer).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
from ..core.attention import attention_forward
from ..core.gradient import gradient_checkpoint_forward
from .flux2_dit import Timesteps, TimestepEmbedding
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta ** scale)
out = torch.einsum("...n,d->...nd", pos, omega)
return out.float()
class ErnieImageEmbedND3(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: Tuple[int, int, int]):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = list(axes_dim)
def forward(self, ids: torch.Tensor) -> torch.Tensor:
emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(3)], dim=-1)
emb = emb.unsqueeze(2)
return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1)
class ErnieImagePatchEmbedDynamic(nn.Module):
def __init__(self, in_channels: int, embed_dim: int, patch_size: int):
super().__init__()
self.patch_size = patch_size
self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
batch_size, dim, height, width = x.shape
return x.reshape(batch_size, dim, height * width).transpose(1, 2).contiguous()
class ErnieImageSingleStreamAttnProcessor:
def __call__(
self,
attn: "ErnieImageAttention",
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[torch.Tensor] = None,
) -> torch.Tensor:
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
query = query.unflatten(-1, (attn.heads, -1))
key = key.unflatten(-1, (attn.heads, -1))
value = value.unflatten(-1, (attn.heads, -1))
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
rot_dim = freqs_cis.shape[-1]
x, x_pass = x_in[..., :rot_dim], x_in[..., rot_dim:]
cos_ = torch.cos(freqs_cis).to(x.dtype)
sin_ = torch.sin(freqs_cis).to(x.dtype)
x1, x2 = x.chunk(2, dim=-1)
x_rotated = torch.cat((-x2, x1), dim=-1)
return torch.cat((x * cos_ + x_rotated * sin_, x_pass), dim=-1)
if freqs_cis is not None:
query = apply_rotary_emb(query, freqs_cis)
key = apply_rotary_emb(key, freqs_cis)
if attention_mask is not None and attention_mask.ndim == 2:
attention_mask = attention_mask[:, None, None, :]
hidden_states = attention_forward(
query, key, value,
q_pattern="b s n d",
k_pattern="b s n d",
v_pattern="b s n d",
out_pattern="b s n d",
attn_mask=attention_mask,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
output = attn.to_out[0](hidden_states)
return output
class ErnieImageAttention(nn.Module):
def __init__(
self,
query_dim: int,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
qk_norm: str = "rms_norm",
out_bias: bool = True,
eps: float = 1e-5,
out_dim: int = None,
elementwise_affine: bool = True,
):
super().__init__()
self.head_dim = dim_head
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.out_dim = out_dim if out_dim is not None else query_dim
self.heads = out_dim // dim_head if out_dim is not None else heads
self.use_bias = bias
self.dropout = dropout
self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
if qk_norm == "layer_norm":
self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
elif qk_norm == "rms_norm":
self.norm_q = nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.norm_k = nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
else:
raise ValueError(
f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'rms_norm'."
)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.processor = ErnieImageSingleStreamAttnProcessor()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.processor(self, hidden_states, attention_mask, image_rotary_emb)
class ErnieImageFeedForward(nn.Module):
def __init__(self, hidden_size: int, ffn_hidden_size: int):
super().__init__()
self.gate_proj = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
self.up_proj = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
self.linear_fc2 = nn.Linear(ffn_hidden_size, hidden_size, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear_fc2(self.up_proj(x) * F.gelu(self.gate_proj(x)))
class ErnieImageRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
hidden_states = hidden_states * self.weight
return hidden_states.to(input_dtype)
class ErnieImageSharedAdaLNBlock(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
ffn_hidden_size: int,
eps: float = 1e-6,
qk_layernorm: bool = True,
):
super().__init__()
self.adaLN_sa_ln = ErnieImageRMSNorm(hidden_size, eps=eps)
self.self_attention = ErnieImageAttention(
query_dim=hidden_size,
dim_head=hidden_size // num_heads,
heads=num_heads,
qk_norm="rms_norm" if qk_layernorm else None,
eps=eps,
bias=False,
out_bias=False,
)
self.adaLN_mlp_ln = ErnieImageRMSNorm(hidden_size, eps=eps)
self.mlp = ErnieImageFeedForward(hidden_size, ffn_hidden_size)
def forward(
self,
x: torch.Tensor,
rotary_pos_emb: torch.Tensor,
temb: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = temb
residual = x
x = self.adaLN_sa_ln(x)
x = (x.float() * (1 + scale_msa.float()) + shift_msa.float()).to(x.dtype)
x_bsh = x.permute(1, 0, 2)
attn_out = self.self_attention(x_bsh, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb)
attn_out = attn_out.permute(1, 0, 2)
x = residual + (gate_msa.float() * attn_out.float()).to(x.dtype)
residual = x
x = self.adaLN_mlp_ln(x)
x = (x.float() * (1 + scale_mlp.float()) + shift_mlp.float()).to(x.dtype)
return residual + (gate_mlp.float() * self.mlp(x).float()).to(x.dtype)
class ErnieImageAdaLNContinuous(nn.Module):
def __init__(self, hidden_size: int, eps: float = 1e-6):
super().__init__()
self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=eps)
self.linear = nn.Linear(hidden_size, hidden_size * 2)
def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
scale, shift = self.linear(conditioning).chunk(2, dim=-1)
x = self.norm(x)
x = x * (1 + scale.unsqueeze(0)) + shift.unsqueeze(0)
return x
class ErnieImageDiT(nn.Module):
"""
Ernie-Image DiT model for DiffSynth-Studio.
Architecture: SharedAdaLN + RoPE 3D + Joint Image-Text Attention.
Internal format: [S, B, H] for transformer blocks, [B, S, H] for attention.
"""
def __init__(
self,
hidden_size: int = 4096,
num_attention_heads: int = 32,
num_layers: int = 36,
ffn_hidden_size: int = 12288,
in_channels: int = 128,
out_channels: int = 128,
patch_size: int = 1,
text_in_dim: int = 3072,
rope_theta: int = 256,
rope_axes_dim: Tuple[int, int, int] = (32, 48, 48),
eps: float = 1e-6,
qk_layernorm: bool = True,
):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_attention_heads
self.head_dim = hidden_size // num_attention_heads
self.num_layers = num_layers
self.patch_size = patch_size
self.in_channels = in_channels
self.out_channels = out_channels
self.text_in_dim = text_in_dim
self.x_embedder = ErnieImagePatchEmbedDynamic(in_channels, hidden_size, patch_size)
self.text_proj = nn.Linear(text_in_dim, hidden_size, bias=False) if text_in_dim != hidden_size else None
self.time_proj = Timesteps(hidden_size, flip_sin_to_cos=False, downscale_freq_shift=0)
self.time_embedding = TimestepEmbedding(hidden_size, hidden_size)
self.pos_embed = ErnieImageEmbedND3(dim=self.head_dim, theta=rope_theta, axes_dim=rope_axes_dim)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size))
nn.init.zeros_(self.adaLN_modulation[-1].weight)
nn.init.zeros_(self.adaLN_modulation[-1].bias)
self.layers = nn.ModuleList([
ErnieImageSharedAdaLNBlock(hidden_size, num_attention_heads, ffn_hidden_size, eps, qk_layernorm=qk_layernorm)
for _ in range(num_layers)
])
self.final_norm = ErnieImageAdaLNContinuous(hidden_size, eps)
self.final_linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels)
nn.init.zeros_(self.final_linear.weight)
nn.init.zeros_(self.final_linear.bias)
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.Tensor,
text_bth: torch.Tensor,
text_lens: torch.Tensor,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
) -> torch.Tensor:
device, dtype = hidden_states.device, hidden_states.dtype
B, C, H, W = hidden_states.shape
p, Hp, Wp = self.patch_size, H // self.patch_size, W // self.patch_size
N_img = Hp * Wp
img_sbh = self.x_embedder(hidden_states).transpose(0, 1).contiguous()
if self.text_proj is not None and text_bth.numel() > 0:
text_bth = self.text_proj(text_bth)
Tmax = text_bth.shape[1]
text_sbh = text_bth.transpose(0, 1).contiguous()
x = torch.cat([img_sbh, text_sbh], dim=0)
S = x.shape[0]
text_ids = torch.cat([
torch.arange(Tmax, device=device, dtype=torch.float32).view(1, Tmax, 1).expand(B, -1, -1),
torch.zeros((B, Tmax, 2), device=device)
], dim=-1) if Tmax > 0 else torch.zeros((B, 0, 3), device=device)
grid_yx = torch.stack(
torch.meshgrid(torch.arange(Hp, device=device, dtype=torch.float32),
torch.arange(Wp, device=device, dtype=torch.float32), indexing="ij"),
dim=-1
).reshape(-1, 2)
image_ids = torch.cat([
text_lens.float().view(B, 1, 1).expand(-1, N_img, -1),
grid_yx.view(1, N_img, 2).expand(B, -1, -1)
], dim=-1)
rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1))
valid_text = torch.arange(Tmax, device=device).view(1, Tmax) < text_lens.view(B, 1) if Tmax > 0 else torch.zeros((B, 0), device=device, dtype=torch.bool)
attention_mask = torch.cat([
torch.ones((B, N_img), device=device, dtype=torch.bool),
valid_text
], dim=1)[:, None, None, :]
sample = self.time_proj(timestep.to(dtype))
sample = sample.to(self.time_embedding.linear_1.weight.dtype)
c = self.time_embedding(sample)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
t.unsqueeze(0).expand(S, -1, -1).contiguous()
for t in self.adaLN_modulation(c).chunk(6, dim=-1)
]
for layer in self.layers:
temb = [shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp]
if torch.is_grad_enabled() and use_gradient_checkpointing:
x = gradient_checkpoint_forward(
layer,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
x,
rotary_pos_emb,
temb,
attention_mask,
)
else:
x = layer(x, rotary_pos_emb, temb, attention_mask)
x = self.final_norm(x, c).type_as(x)
patches = self.final_linear(x)[:N_img].transpose(0, 1).contiguous()
output = patches.view(B, Hp, Wp, p, p, self.out_channels).permute(0, 5, 1, 3, 2, 4).contiguous().view(B, self.out_channels, H, W)
return output

View File

@@ -0,0 +1,76 @@
"""
Ernie-Image TextEncoder for DiffSynth-Studio.
Wraps transformers Ministral3Model to output text embeddings.
Pattern: lazy import + manual config dict + torch.nn.Module wrapper.
Only loads the text (language) model, ignoring vision components.
"""
import torch
class ErnieImageTextEncoder(torch.nn.Module):
"""
Text encoder using Ministral3Model (transformers).
Only the text_config portion of the full Mistral3Model checkpoint.
Uses the base model (no lm_head) since the checkpoint only has embeddings.
"""
def __init__(self):
super().__init__()
from transformers import Ministral3Config, Ministral3Model
text_config = {
"attention_dropout": 0.0,
"bos_token_id": 1,
"dtype": "bfloat16",
"eos_token_id": 2,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 3072,
"initializer_range": 0.02,
"intermediate_size": 9216,
"max_position_embeddings": 262144,
"model_type": "ministral3",
"num_attention_heads": 32,
"num_hidden_layers": 26,
"num_key_value_heads": 8,
"pad_token_id": 11,
"rms_norm_eps": 1e-05,
"rope_parameters": {
"beta_fast": 32.0,
"beta_slow": 1.0,
"factor": 16.0,
"llama_4_scaling_beta": 0.1,
"mscale": 1.0,
"mscale_all_dim": 1.0,
"original_max_position_embeddings": 16384,
"rope_theta": 1000000.0,
"rope_type": "yarn",
"type": "yarn",
},
"sliding_window": None,
"tie_word_embeddings": True,
"use_cache": True,
"vocab_size": 131072,
}
config = Ministral3Config(**text_config)
self.model = Ministral3Model(config)
self.config = config
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
**kwargs,
):
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_hidden_states=True,
return_dict=True,
**kwargs,
)
return (outputs.hidden_states,)

View File

@@ -1,5 +1,5 @@
from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer, SiglipVisionConfig from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer, SiglipVisionConfig
from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessor from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast
import torch import torch
from diffsynth.core.device.npu_compatible_device import get_device_type from diffsynth.core.device.npu_compatible_device import get_device_type
@@ -90,7 +90,7 @@ class Siglip2ImageEncoder428M(Siglip2VisionModel):
transformers_version = "4.57.1" transformers_version = "4.57.1"
) )
super().__init__(config) super().__init__(config)
self.processor = Siglip2ImageProcessor( self.processor = Siglip2ImageProcessorFast(
**{ **{
"data_format": "channels_first", "data_format": "channels_first",
"default_to_square": True, "default_to_square": True,
@@ -106,7 +106,7 @@ class Siglip2ImageEncoder428M(Siglip2VisionModel):
0.5, 0.5,
0.5 0.5
], ],
"image_processor_type": "Siglip2ImageProcessor", "image_processor_type": "Siglip2ImageProcessorFast",
"image_std": [ "image_std": [
0.5, 0.5,
0.5, 0.5,

View File

@@ -0,0 +1,265 @@
"""
ERNIE-Image Text-to-Image Pipeline for DiffSynth-Studio.
Architecture: SharedAdaLN DiT + RoPE 3D + Joint Image-Text Attention.
"""
import torch
from typing import Union, Optional
from tqdm import tqdm
from transformers import AutoTokenizer
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
from ..models.ernie_image_text_encoder import ErnieImageTextEncoder
from ..models.ernie_image_dit import ErnieImageDiT
from ..models.flux2_vae import Flux2VAE
class ErnieImagePipeline(BasePipeline):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
)
self.scheduler = FlowMatchScheduler("ERNIE-Image")
self.text_encoder: ErnieImageTextEncoder = None
self.dit: ErnieImageDiT = None
self.vae: Flux2VAE = None
self.tokenizer: AutoTokenizer = None
self.in_iteration_models = ("dit",)
self.units = [
ErnieImageUnit_ShapeChecker(),
ErnieImageUnit_PromptEmbedder(),
ErnieImageUnit_NoiseInitializer(),
ErnieImageUnit_InputImageEmbedder(),
]
self.model_fn = model_fn_ernie_image
self.compilable_models = ["dit"]
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="tokenizer/"),
vram_limit: float = None,
):
pipe = ErnieImagePipeline(device=device, torch_dtype=torch_dtype)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
pipe.text_encoder = model_pool.fetch_model("ernie_image_text_encoder")
pipe.dit = model_pool.fetch_model("ernie_image_dit")
pipe.vae = model_pool.fetch_model("flux2_vae")
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@torch.no_grad()
def __call__(
self,
# Prompt
prompt: str,
negative_prompt: str = "",
cfg_scale: float = 4.0,
# Shape
height: int = 1024,
width: int = 1024,
# Randomness
seed: int = None,
rand_device: str = "cuda",
# Steps
num_inference_steps: int = 50,
# Progress bar
progress_bar_cmd=tqdm,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps=num_inference_steps)
# Parameters
inputs_posi = {"prompt": prompt}
inputs_nega = {"negative_prompt": negative_prompt}
inputs_shared = {
"height": height, "width": width, "seed": seed,
"cfg_scale": cfg_scale, "num_inference_steps": num_inference_steps,
"rand_device": rand_device,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
# Decode
self.load_models_to_device(['vae'])
latents = inputs_shared["latents"]
image = self.vae.decode(latents)
image = self.vae_output_to_image(image)
self.load_models_to_device([])
return image
class ErnieImageUnit_ShapeChecker(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width"),
output_params=("height", "width"),
)
def process(self, pipe: ErnieImagePipeline, height, width):
height, width = pipe.check_resize_height_width(height, width)
return {"height": height, "width": width}
class ErnieImageUnit_PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"},
output_params=("prompt_embeds", "prompt_embeds_mask"),
onload_model_names=("text_encoder",)
)
def encode_prompt(self, pipe: ErnieImagePipeline, prompt):
if isinstance(prompt, str):
prompt = [prompt]
text_hiddens = []
text_lens_list = []
for p in prompt:
ids = pipe.tokenizer(
p,
add_special_tokens=True,
truncation=True,
padding=False,
)["input_ids"]
if len(ids) == 0:
if pipe.tokenizer.bos_token_id is not None:
ids = [pipe.tokenizer.bos_token_id]
else:
ids = [0]
input_ids = torch.tensor([ids], device=pipe.device)
outputs = pipe.text_encoder(
input_ids=input_ids,
)
# Text encoder returns tuple of (hidden_states_tuple,) where each layer's hidden state is included
all_hidden_states = outputs[0]
hidden = all_hidden_states[-2][0] # [T, H] - second to last layer
text_hiddens.append(hidden)
text_lens_list.append(hidden.shape[0])
# Pad to uniform length
if len(text_hiddens) == 0:
text_in_dim = pipe.text_encoder.config.hidden_size if hasattr(pipe.text_encoder, 'config') else 3072
return {
"prompt_embeds": torch.zeros((0, 0, text_in_dim), device=pipe.device, dtype=pipe.torch_dtype),
"prompt_embeds_mask": torch.zeros((0,), device=pipe.device, dtype=torch.long),
}
normalized = [th.to(pipe.device).to(pipe.torch_dtype) for th in text_hiddens]
text_lens = torch.tensor([t.shape[0] for t in normalized], device=pipe.device, dtype=torch.long)
Tmax = int(text_lens.max().item())
text_in_dim = normalized[0].shape[1]
text_bth = torch.zeros((len(normalized), Tmax, text_in_dim), device=pipe.device, dtype=pipe.torch_dtype)
for i, t in enumerate(normalized):
text_bth[i, :t.shape[0], :] = t
return {"prompt_embeds": text_bth, "prompt_embeds_mask": text_lens}
def process(self, pipe: ErnieImagePipeline, prompt):
pipe.load_models_to_device(self.onload_model_names)
if pipe.text_encoder is not None:
return self.encode_prompt(pipe, prompt)
return {}
class ErnieImageUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "seed", "rand_device"),
output_params=("noise",),
)
def process(self, pipe: ErnieImagePipeline, height, width, seed, rand_device):
latent_h = height // pipe.height_division_factor
latent_w = width // pipe.width_division_factor
latent_channels = pipe.dit.in_channels
# Use pipeline device if rand_device is not specified
if rand_device is None:
rand_device = str(pipe.device)
noise = pipe.generate_noise(
(1, latent_channels, latent_h, latent_w),
seed=seed,
rand_device=rand_device,
rand_torch_dtype=pipe.torch_dtype,
)
return {"noise": noise}
class ErnieImageUnit_InputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "noise"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",)
)
def process(self, pipe: ErnieImagePipeline, input_image, noise):
if input_image is None:
# T2I path: use noise directly as initial latents
return {"latents": noise, "input_latents": None}
# I2I path: VAE encode input image
pipe.load_models_to_device(['vae'])
image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
input_latents = pipe.vae.encode(image)
if pipe.scheduler.training:
return {"latents": noise, "input_latents": input_latents}
else:
# In inference mode, add noise to encoded latents
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
return {"latents": latents}
def model_fn_ernie_image(
dit: ErnieImageDiT,
latents=None,
timestep=None,
prompt_embeds=None,
prompt_embeds_mask=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
output = dit(
hidden_states=latents,
timestep=timestep,
text_bth=prompt_embeds,
text_lens=prompt_embeds_mask,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
return output

View File

@@ -95,7 +95,7 @@ class ZImagePipeline(BasePipeline):
def __call__( def __call__(
self, self,
# Prompt # Prompt
prompt: str = "", prompt: str,
negative_prompt: str = "", negative_prompt: str = "",
cfg_scale: float = 1.0, cfg_scale: float = 1.0,
# Image # Image
@@ -109,7 +109,7 @@ class ZImagePipeline(BasePipeline):
width: int = 1024, width: int = 1024,
# Randomness # Randomness
seed: int = None, seed: int = None,
rand_device: Union[str, torch.device] = "cpu", rand_device: str = "cpu",
# Steps # Steps
num_inference_steps: int = 8, num_inference_steps: int = 8,
sigma_shift: float = None, sigma_shift: float = None,

View File

@@ -0,0 +1,21 @@
def ErnieImageTextEncoderStateDictConverter(state_dict):
"""
Maps checkpoint keys from multimodal Mistral3Model format
to text-only Ministral3Model format.
Checkpoint keys (Mistral3Model):
language_model.model.layers.0.input_layernorm.weight
language_model.model.norm.weight
Model keys (ErnieImageTextEncoder → self.model = Ministral3Model):
model.layers.0.input_layernorm.weight
model.norm.weight
Mapping: language_model. → model.
"""
new_state_dict = {}
for key in state_dict:
if key.startswith("language_model.model."):
new_key = key.replace("language_model.model.", "model.", 1)
new_state_dict[new_key] = state_dict[key]
return new_state_dict

View File

@@ -0,0 +1,133 @@
# ERNIE-Image
ERNIE-Image is a powerful image generation model with 8B parameters developed by Baidu, featuring a compact and efficient architecture with strong instruction-following capability. Based on an 8B DiT backbone, it delivers performance comparable to larger (20B+) models in certain scenarios while maintaining parameter efficiency. It offers reliable performance in instruction understanding and execution, text generation (English/Chinese/Japanese), and overall stability.
## Installation
Before performing model inference and training, please install DiffSynth-Studio first.
```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
For more information on installation, please refer to [Setup Dependencies](../Pipeline_Usage/Setup.md).
## Quick Start
Running the following code will load the [baidu/ERNIE-Image](https://www.modelscope.cn/models/baidu/ERNIE-Image) model for inference. VRAM management is enabled, the framework automatically controls parameter loading based on available VRAM, requiring a minimum of 3G VRAM.
```python
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = ErnieImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device='cuda',
model_configs=[
ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
prompt="一只黑白相间的中华田园犬",
negative_prompt="",
height=1024,
width=1024,
seed=42,
num_inference_steps=50,
cfg_scale=4.0,
)
image.save("output.jpg")
```
## Model Overview
|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|-|-|-|-|-|-|-|
|[baidu/ERNIE-Image: T2I](https://www.modelscope.cn/models/baidu/ERNIE-Image)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference_low_vram/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/full/Ernie-Image-T2I.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_full/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/lora/Ernie-Image-T2I.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_lora/Ernie-Image-T2I.py)|
## Model Inference
The model is loaded via `ErnieImagePipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models) for details.
The input parameters for `ErnieImagePipeline` inference include:
* `prompt`: The prompt describing the content to appear in the image.
* `negative_prompt`: The negative prompt describing what should not appear in the image, default value is `""`.
* `cfg_scale`: Classifier-free guidance parameter, default value is 4.0.
* `height`: Image height, must be a multiple of 16, default value is 1024.
* `width`: Image width, must be a multiple of 16, default value is 1024.
* `seed`: Random seed. Default is `None`, meaning completely random.
* `rand_device`: The computing device for generating random Gaussian noise matrices, default is `"cuda"`. When set to `cuda`, different GPUs will produce different results.
* `num_inference_steps`: Number of inference steps, default value is 50.
If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). We provide recommended low-VRAM configurations for each model in the "Model Overview" table above.
## Model Training
ERNIE-Image series models are trained uniformly via [`examples/ernie_image/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/train.py). The script parameters include:
* General Training Parameters
* Dataset Configuration
* `--dataset_base_path`: Root directory of the dataset.
* `--dataset_metadata_path`: Path to the dataset metadata file.
* `--dataset_repeat`: Number of dataset repeats per epoch.
* `--dataset_num_workers`: Number of processes per DataLoader.
* `--data_file_keys`: Field names to load from metadata, typically paths to image or video files, separated by `,`.
* Model Loading Configuration
* `--model_paths`: Paths to load models from, in JSON format.
* `--model_id_with_origin_paths`: Model IDs with original paths, e.g., `"baidu/ERNIE-Image:transformer/diffusion_pytorch_model*.safetensors"`, separated by commas.
* `--extra_inputs`: Additional input parameters required by the model Pipeline, separated by `,`.
* `--fp8_models`: Models to load in FP8 format, currently only supported for models whose parameters are not updated by gradients.
* Basic Training Configuration
* `--learning_rate`: Learning rate.
* `--num_epochs`: Number of epochs.
* `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.
* `--find_unused_parameters`: Whether unused parameters exist in DDP training.
* `--weight_decay`: Weight decay magnitude.
* `--task`: Training task, defaults to `sft`.
* Output Configuration
* `--output_path`: Path to save the model.
* `--remove_prefix_in_ckpt`: Remove prefix in the model's state dict.
* `--save_steps`: Interval in training steps to save the model.
* LoRA Configuration
* `--lora_base_model`: Which model to add LoRA to.
* `--lora_target_modules`: Which layers to add LoRA to.
* `--lora_rank`: Rank of LoRA.
* `--lora_checkpoint`: Path to LoRA checkpoint.
* `--preset_lora_path`: Path to preset LoRA checkpoint for LoRA differential training.
* `--preset_lora_model`: Which model to integrate preset LoRA into, e.g., `dit`.
* Gradient Configuration
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
* Resolution Configuration
* `--height`: Height of the image. Leave empty to enable dynamic resolution.
* `--width`: Width of the image. Leave empty to enable dynamic resolution.
* `--max_pixels`: Maximum pixel area, images larger than this will be scaled down during dynamic resolution.
* ERNIE-Image Specific Parameters
* `--tokenizer_path`: Path to the tokenizer, leave empty to auto-download from remote.
We provide an example image dataset for testing, which can be downloaded with the following command:
```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
```
We provide recommended training scripts for each model, please refer to the table in "Model Overview" above. For guidance on writing model training scripts, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see [Training Framework Overview](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).

View File

@@ -29,6 +29,7 @@ Welcome to DiffSynth-Studio's Documentation
Model_Details/Z-Image Model_Details/Z-Image
Model_Details/Anima Model_Details/Anima
Model_Details/LTX-2 Model_Details/LTX-2
Model_Details/ERNIE-Image
.. toctree:: .. toctree::
:maxdepth: 2 :maxdepth: 2

View File

@@ -0,0 +1,133 @@
# ERNIE-Image
ERNIE-Image 是百度推出的拥有 8B 参数的图像生成模型,具有紧凑高效的架构和出色的指令跟随能力。基于 8B DiT 主干网络,其在某些场景下的性能可与 20B 以上的更大模型相媲美,同时保持了良好的参数效率。该模型在指令理解与执行、文本生成(如英文/中文/日文)以及整体稳定性方面提供了较为可靠的表现。
## 安装
在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。
```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
更多关于安装的信息,请参考[安装依赖](../Pipeline_Usage/Setup.md)。
## 快速开始
运行以下代码可以快速加载 [baidu/ERNIE-Image](https://www.modelscope.cn/models/baidu/ERNIE-Image) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 3G 显存即可运行。
```python
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = ErnieImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device='cuda',
model_configs=[
ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
prompt="一只黑白相间的中华田园犬",
negative_prompt="",
height=1024,
width=1024,
seed=42,
num_inference_steps=50,
cfg_scale=4.0,
)
image.save("output.jpg")
```
## 模型总览
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
|[baidu/ERNIE-Image: T2I](https://www.modelscope.cn/models/baidu/ERNIE-Image)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference_low_vram/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/full/Ernie-Image-T2I.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_full/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/lora/Ernie-Image-T2I.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_lora/Ernie-Image-T2I.py)|
## 模型推理
模型通过 `ErnieImagePipeline.from_pretrained` 加载,详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。
`ErnieImagePipeline` 推理的输入参数包括:
* `prompt`: 提示词,描述画面中出现的内容。
* `negative_prompt`: 负向提示词,描述画面中不应该出现的内容,默认值为 `""`
* `cfg_scale`: Classifier-free guidance 的参数,默认值为 4.0。
* `height`: 图像高度,需保证高度为 16 的倍数,默认值为 1024。
* `width`: 图像宽度,需保证宽度为 16 的倍数,默认值为 1024。
* `seed`: 随机种子。默认为 `None`,即完全随机。
* `rand_device`: 生成随机高斯噪声矩阵的计算设备,默认为 `"cuda"`。当设置为 `cuda` 时,在不同 GPU 上会导致不同的生成结果。
* `num_inference_steps`: 推理步数,默认值为 50。
如果显存不足,请开启[显存管理](../Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文"模型总览"中的表格。
## 模型训练
ERNIE-Image 系列模型统一通过 [`examples/ernie_image/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/train.py) 进行训练,脚本的参数包括:
* 通用训练参数
* 数据集基础配置
* `--dataset_base_path`: 数据集的根目录。
* `--dataset_metadata_path`: 数据集的元数据文件路径。
* `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
* `--dataset_num_workers`: 每个 Dataloader 的进程数量。
* `--data_file_keys`: 元数据中需要加载的字段名称,通常是图像或视频文件的路径,以 `,` 分隔。
* 模型加载配置
* `--model_paths`: 要加载的模型路径。JSON 格式。
* `--model_id_with_origin_paths`: 带原始路径的模型 ID例如 `"baidu/ERNIE-Image:transformer/diffusion_pytorch_model*.safetensors"`。用逗号分隔。
* `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,以 `,` 分隔。
* `--fp8_models`:以 FP8 格式加载的模型,目前仅支持参数不被梯度更新的模型。
* 训练基础配置
* `--learning_rate`: 学习率。
* `--num_epochs`: 轮数Epoch
* `--trainable_models`: 可训练的模型,例如 `dit``vae``text_encoder`
* `--find_unused_parameters`: DDP 训练中是否存在未使用的参数。
* `--weight_decay`:权重衰减大小。
* `--task`: 训练任务,默认为 `sft`
* 输出配置
* `--output_path`: 模型保存路径。
* `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。
* `--save_steps`: 保存模型的训练步数间隔。
* LoRA 配置
* `--lora_base_model`: LoRA 添加到哪个模型上。
* `--lora_target_modules`: LoRA 添加到哪些层上。
* `--lora_rank`: LoRA 的秩Rank
* `--lora_checkpoint`: LoRA 检查点的路径。
* `--preset_lora_path`: 预置 LoRA 检查点路径,用于 LoRA 差分训练。
* `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`
* 梯度配置
* `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。
* `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。
* `--gradient_accumulation_steps`: 梯度累积步数。
* 分辨率配置
* `--height`: 图像的高度。留空启用动态分辨率。
* `--width`: 图像的宽度。留空启用动态分辨率。
* `--max_pixels`: 最大像素面积,动态分辨率时大于此值的图片会被缩小。
* ERNIE-Image 专有参数
* `--tokenizer_path`: tokenizer 的路径,留空则自动从远程下载。
我们构建了一个样例图像数据集,以方便您进行测试,通过以下命令可以下载这个数据集:
```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
```
我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](../Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。

View File

@@ -29,6 +29,7 @@
Model_Details/Z-Image Model_Details/Z-Image
Model_Details/Anima Model_Details/Anima
Model_Details/LTX-2 Model_Details/LTX-2
Model_Details/ERNIE-Image
.. toctree:: .. toctree::
:maxdepth: 2 :maxdepth: 2

View File

@@ -1,283 +0,0 @@
import importlib, inspect, pkgutil, traceback, torch, os, re
from typing import Union, List, Optional, Tuple, Iterable, Dict
from contextlib import contextmanager
import streamlit as st
from diffsynth import ModelConfig
from diffsynth.diffusion.base_pipeline import ControlNetInput
from PIL import Image
from tqdm import tqdm
st.set_page_config(layout="wide")
class StreamlitTqdmWrapper:
"""Wrapper class that combines tqdm and streamlit progress bar"""
def __init__(self, iterable, st_progress_bar=None):
self.iterable = iterable
self.st_progress_bar = st_progress_bar
self.tqdm_bar = tqdm(iterable)
self.total = len(iterable) if hasattr(iterable, '__len__') else None
self.current = 0
def __iter__(self):
for item in self.tqdm_bar:
if self.st_progress_bar is not None and self.total is not None:
self.current += 1
self.st_progress_bar.progress(self.current / self.total)
yield item
def __enter__(self):
return self
def __exit__(self, *args):
if hasattr(self.tqdm_bar, '__exit__'):
self.tqdm_bar.__exit__(*args)
@contextmanager
def catch_error(error_value):
try:
yield
except Exception as e:
error_message = traceback.format_exc()
print(f"Error {error_value}:\n{error_message}")
def parse_model_configs_from_an_example(path):
model_configs = []
with open(path, "r") as f:
for code in f.readlines():
code = code.strip()
if not code.startswith("ModelConfig"):
continue
pairs = re.findall(r'(\w+)\s*=\s*["\']([^"\']+)["\']', code)
config_dict = {k: v for k, v in pairs}
model_configs.append(ModelConfig(model_id=config_dict["model_id"], origin_file_pattern=config_dict["origin_file_pattern"]))
return model_configs
def list_examples(path, keyword=None):
examples = []
if os.path.isdir(path):
for file_name in os.listdir(path):
examples.extend(list_examples(os.path.join(path, file_name), keyword=keyword))
elif path.endswith(".py"):
with open(path, "r") as f:
code = f.read()
if keyword is None or keyword in code:
examples.extend([path])
return examples
def parse_available_pipelines():
from diffsynth.diffusion.base_pipeline import BasePipeline
import diffsynth.pipelines as _pipelines_pkg
available_pipelines = {}
for _, name, _ in pkgutil.iter_modules(_pipelines_pkg.__path__):
with catch_error(f"Failed: import diffsynth.pipelines.{name}"):
mod = importlib.import_module(f"diffsynth.pipelines.{name}")
classes = {
cls_name: cls for cls_name, cls in inspect.getmembers(mod, inspect.isclass)
if issubclass(cls, BasePipeline) and cls is not BasePipeline and cls.__module__ == mod.__name__
}
available_pipelines.update(classes)
return available_pipelines
def parse_available_examples(path, available_pipelines):
available_examples = {}
for pipeline_name in available_pipelines:
examples = ["None"] + list_examples(path, keyword=f"{pipeline_name}.from_pretrained")
available_examples[pipeline_name] = examples
return available_examples
def draw_selectbox(label, options, option_map, value=None, disabled=False):
default_index = 0 if value is None else tuple(options).index([option for option in option_map if option_map[option]==value][0])
option = st.selectbox(label=label, options=tuple(options), index=default_index, disabled=disabled)
return option_map.get(option)
def parse_params(fn):
params = []
for name, param in inspect.signature(fn).parameters.items():
annotation = param.annotation if param.annotation is not inspect.Parameter.empty else None
default = param.default if param.default is not inspect.Parameter.empty else None
params.append({"name": name, "dtype": annotation, "value": default})
return params
def draw_model_config(model_config=None, key_suffix="", disabled=False):
with st.container(border=True):
if model_config is None:
model_config = ModelConfig()
path = st.text_input(label="path", key="path" + key_suffix, value=model_config.path, disabled=disabled)
col1, col2 = st.columns(2)
with col1:
model_id = st.text_input(label="model_id", key="model_id" + key_suffix, value=model_config.model_id, disabled=disabled)
with col2:
origin_file_pattern = st.text_input(label="origin_file_pattern", key="origin_file_pattern" + key_suffix, value=model_config.origin_file_pattern, disabled=disabled)
model_config = ModelConfig(
path=None if path == "" else path,
model_id=model_id,
origin_file_pattern=origin_file_pattern,
)
return model_config
def draw_multi_model_config(name="", value=None, disabled=False):
model_configs = []
with st.container(border=True):
st.markdown(name)
num = st.number_input(f"num_{name}", min_value=0, max_value=20, value=0 if value is None else len(value), disabled=disabled)
for i in range(num):
model_config = draw_model_config(key_suffix=f"_{name}_{i}", model_config=None if value is None else value[i], disabled=disabled)
model_configs.append(model_config)
return model_configs
def draw_single_model_config(name="", value=None, disabled=False):
with st.container(border=True):
st.markdown(name)
model_config = draw_model_config(value, key_suffix=f"_{name}", disabled=disabled)
return model_config
def draw_multi_images(name="", value=None, disabled=False):
images = []
with st.container(border=True):
st.markdown(name)
num = st.number_input(f"num_{name}", min_value=0, max_value=20, value=0 if value is None else len(value), disabled=disabled)
for i in range(num):
image = st.file_uploader(name, type=["png", "jpg", "jpeg", "webp"], key=f"{name}_{i}", disabled=disabled)
if image is not None: images.append(Image.open(image))
return images
def draw_controlnet_input(name="", value=None, disabled=False):
with st.container(border=True):
st.markdown(name)
controlnet_id = st.number_input("controlnet_id", value=0, min_value=0, max_value=20, step=1, key=f"{name}_controlnet_id")
scale = st.number_input("scale", value=1.0, min_value=0.0, max_value=10.0, key=f"{name}_scale")
image = st.file_uploader("image", type=["png", "jpg", "jpeg", "webp"], disabled=disabled, key=f"{name}_image")
if image is not None: image = Image.open(image)
inpaint_image = st.file_uploader("inpaint_image", type=["png", "jpg", "jpeg", "webp"], disabled=disabled, key=f"{name}_inpaint_image")
if inpaint_image is not None: inpaint_image = Image.open(inpaint_image)
inpaint_mask = st.file_uploader("inpaint_mask", type=["png", "jpg", "jpeg", "webp"], disabled=disabled, key=f"{name}_inpaint_mask")
if inpaint_mask is not None: inpaint_mask = Image.open(inpaint_mask)
return ControlNetInput(controlnet_id=controlnet_id, scale=scale, image=image, inpaint_image=inpaint_image, inpaint_mask=inpaint_mask)
def draw_controlnet_inputs(name, value=None, disabled=False):
controlnet_inputs = []
with st.container(border=True):
st.markdown(name)
num = st.number_input(f"num_{name}", min_value=0, max_value=20, value=0 if value is None else len(value), disabled=disabled)
for i in range(num):
controlnet_input = draw_controlnet_input(name=f"{name}_{i}", value=None, disabled=disabled)
controlnet_inputs.append(controlnet_input)
return controlnet_inputs
def draw_ui_element(name, dtype, value):
unsupported_dtype = [
Dict[str, torch.Tensor],
torch.Tensor,
]
if dtype in unsupported_dtype:
return
if value is None:
with st.container(border=True):
enable = st.checkbox(f"Enable {name}", value=False)
ui = draw_ui_element_safely(name, dtype, value, disabled=not enable)
if enable:
return ui
else:
return None
else:
return draw_ui_element_safely(name, dtype, value)
def draw_ui_element_safely(name, dtype, value, disabled=False):
if dtype == torch.dtype:
option_map = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
ui = draw_selectbox(name, option_map.keys(), option_map, value=value, disabled=disabled)
elif dtype == Union[str, torch.device]:
option_map = {"cuda": "cuda", "cpu": "cpu"}
ui = draw_selectbox(name, option_map.keys(), option_map, value=value, disabled=disabled)
elif dtype == bool:
ui = st.checkbox(name, value, disabled=disabled)
elif dtype == ModelConfig:
ui = draw_single_model_config(name, value, disabled=disabled)
elif dtype == list[ModelConfig]:
if name == "model_configs" and "model_configs_from_example" in st.session_state:
model_configs = st.session_state["model_configs_from_example"]
del st.session_state["model_configs_from_example"]
ui = draw_multi_model_config(name, model_configs, disabled=disabled)
else:
ui = draw_multi_model_config(name, disabled=disabled)
elif dtype == str:
if "prompt" in name:
ui = st.text_area(name, value, height=3, disabled=disabled)
else:
ui = st.text_input(name, value, disabled=disabled)
elif dtype == float:
ui = st.number_input(name, value, disabled=disabled)
elif dtype == int:
ui = st.number_input(name, value, step=1, disabled=disabled)
elif dtype == Image.Image:
ui = st.file_uploader(name, type=["png", "jpg", "jpeg", "webp"], disabled=disabled)
if ui is not None: ui = Image.open(ui)
elif dtype == List[Image.Image]:
ui = draw_multi_images(name, value, disabled=disabled)
elif dtype == List[ControlNetInput]:
ui = draw_controlnet_inputs(name, value, disabled=disabled)
elif dtype is None:
if name == "progress_bar_cmd":
ui = value
else:
st.markdown(f"(`{name}` is not not configurable in WebUI). dtype: `{dtype}`.")
ui = value
return ui
def launch_webui():
input_col, output_col = st.columns(2)
with input_col:
if "available_pipelines" not in st.session_state:
st.session_state["available_pipelines"] = parse_available_pipelines()
if "available_examples" not in st.session_state:
st.session_state["available_examples"] = parse_available_examples("./examples", st.session_state["available_pipelines"])
with st.expander("Pipeline", expanded=True):
pipeline_class = draw_selectbox("Pipeline Class", st.session_state["available_pipelines"].keys(), st.session_state["available_pipelines"], value=st.session_state["available_pipelines"]["ZImagePipeline"])
example = st.selectbox("Parse model configs from an example (optional)", st.session_state["available_examples"][pipeline_class.__name__])
if example != "None":
st.session_state["model_configs_from_example"] = parse_model_configs_from_an_example(example)
if st.button("Step 1: Parse Pipeline", type="primary"):
st.session_state["pipeline_class"] = pipeline_class
if "pipeline_class" not in st.session_state:
return
with st.expander("Model", expanded=True):
input_params = {}
params = parse_params(pipeline_class.from_pretrained)
for param in params:
input_params[param["name"]] = draw_ui_element(**param)
if st.button("Step 2: Load Models", type="primary"):
with st.spinner("Loading models", show_time=True):
if "pipe" in st.session_state:
del st.session_state["pipe"]
torch.cuda.empty_cache()
st.session_state["pipe"] = pipeline_class.from_pretrained(**input_params)
if "pipe" not in st.session_state:
return
with st.expander("Input", expanded=True):
pipe = st.session_state["pipe"]
input_params = {}
params = parse_params(pipe.__call__)
for param in params:
if param["name"] in ["self"]:
continue
input_params[param["name"]] = draw_ui_element(**param)
with output_col:
if st.button("Step 3: Generate", type="primary"):
if "progress_bar_cmd" in input_params:
input_params["progress_bar_cmd"] = lambda iterable: StreamlitTqdmWrapper(iterable, st.progress(0))
result = pipe(**input_params)
st.session_state["result"] = result
if "result" in st.session_state:
result = st.session_state["result"]
if isinstance(result, Image.Image):
st.image(result)
else:
print(f"unsupported result format: {result}")
launch_webui()
# streamlit run examples/dev_tools/webui.py --server.fileWatcherType none

View File

@@ -0,0 +1,24 @@
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
import torch
pipe = ErnieImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device='cuda',
model_configs=[
ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="tokenizer/"),
)
image = pipe(
prompt="一只黑白相间的中华田园犬",
negative_prompt="",
height=1024,
width=1024,
seed=42,
num_inference_steps=50,
cfg_scale=4.0,
)
image.save("output.jpg")

View File

@@ -0,0 +1,36 @@
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = ErnieImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device='cuda',
model_configs=[
ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
prompt="一只黑白相间的中华田园犬",
negative_prompt="",
height=1024,
width=1024,
seed=42,
num_inference_steps=50,
cfg_scale=4.0,
)
image.save("output.jpg")

View File

@@ -0,0 +1,17 @@
# Dataset: data/diffsynth_example_dataset/ernie_image/Ernie-Image-T2I/
accelerate launch --config_file examples/ernie_image/model_training/full/accelerate_config_zero3.yaml \
examples/ernie_image/model_training/train.py \
--dataset_base_path data/diffsynth_example_dataset/ernie_image/Ernie-Image-T2I \
--dataset_metadata_path data/diffsynth_example_dataset/ernie_image/Ernie-Image-T2I/metadata.csv \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "baidu/ERNIE-Image:transformer/diffusion_pytorch_model*.safetensors,baidu/ERNIE-Image:text_encoder/model.safetensors,baidu/ERNIE-Image:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Ernie-Image-T2I_full" \
--trainable_models "dit" \
--use_gradient_checkpointing \
--dataset_num_workers 8 \
--find_unused_parameters

View File

@@ -0,0 +1,23 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
gradient_accumulation_steps: 1
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: true
zero3_save_16bit_model: true
zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@@ -0,0 +1,19 @@
# Dataset: data/diffsynth_example_dataset/ernie_image/Ernie-Image-T2I/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ernie_image/Ernie-Image-T2I/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ernie_image/model_training/train.py \
--dataset_base_path data/diffsynth_example_dataset/ernie_image/Ernie-Image-T2I \
--dataset_metadata_path data/diffsynth_example_dataset/ernie_image/Ernie-Image-T2I/metadata.csv \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "baidu/ERNIE-Image:transformer/diffusion_pytorch_model*.safetensors,baidu/ERNIE-Image:text_encoder/model.safetensors,baidu/ERNIE-Image:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Ernie-Image-T2I_lora" \
--lora_base_model "dit" \
--lora_target_modules "to_q,to_k,to_v,to_out.0" \
--lora_rank 32 \
--use_gradient_checkpointing \
--dataset_num_workers 8 \
--find_unused_parameters

View File

@@ -0,0 +1,129 @@
import torch, os, argparse, accelerate
from diffsynth.core import UnifiedDataset
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
from diffsynth.diffusion import *
from diffsynth.core.data.operators import *
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class ErnieImageTrainingModule(DiffusionTrainingModule):
def __init__(
self,
model_paths=None, model_id_with_origin_paths=None,
tokenizer_path=None,
trainable_models=None,
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
preset_lora_path=None, preset_lora_model=None,
use_gradient_checkpointing=True,
use_gradient_checkpointing_offload=False,
extra_inputs=None,
fp8_models=None,
offload_models=None,
device="cpu",
task="sft",
):
super().__init__()
# Load models
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
tokenizer_config = ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
self.pipe = ErnieImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
# Training mode
self.switch_pipe_to_training_mode(
self.pipe, trainable_models,
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
preset_lora_path, preset_lora_model,
task=task,
)
# Other configs
self.use_gradient_checkpointing = use_gradient_checkpointing
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
self.task = task
self.task_to_loss = {
"sft:data_process": lambda pipe, inputs_shared, inputs_posi, inputs_nega: (inputs_shared, inputs_posi, inputs_nega),
"sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
"sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
}
def get_pipeline_inputs(self, data):
inputs_posi = {"prompt": data["prompt"]}
inputs_nega = {"negative_prompt": ""}
inputs_shared = {
"input_image": data["image"],
"height": data["image"].size[1],
"width": data["image"].size[0],
"cfg_scale": 1,
"rand_device": self.pipe.device,
"use_gradient_checkpointing": self.use_gradient_checkpointing,
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
}
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
return inputs_shared, inputs_posi, inputs_nega
def forward(self, data, inputs=None):
if inputs is None:
inputs = self.get_pipeline_inputs(data)
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
for unit in self.pipe.units:
inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)
loss = self.task_to_loss[self.task](self.pipe, *inputs)
return loss
def ernie_image_parser():
parser = argparse.ArgumentParser(description="ERNIE-Image training.")
parser = add_general_config(parser)
parser = add_image_size_config(parser)
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
return parser
if __name__ == "__main__":
parser = ernie_image_parser()
args = parser.parse_args()
accelerator = accelerate.Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
)
dataset = UnifiedDataset(
base_path=args.dataset_base_path,
metadata_path=args.dataset_metadata_path,
repeat=args.dataset_repeat,
data_file_keys=args.data_file_keys.split(","),
main_data_operator=lambda x: x,
special_operator_map={
"image": ToAbsolutePath(args.dataset_base_path) >> LoadImage() >> ImageCropAndResize(args.height, args.width, args.max_pixels, 16, 16),
},
)
model = ErnieImageTrainingModule(
model_paths=args.model_paths,
model_id_with_origin_paths=args.model_id_with_origin_paths,
tokenizer_path=args.tokenizer_path,
trainable_models=args.trainable_models,
lora_base_model=args.lora_base_model,
lora_target_modules=args.lora_target_modules,
lora_rank=args.lora_rank,
lora_checkpoint=args.lora_checkpoint,
preset_lora_path=args.preset_lora_path,
preset_lora_model=args.preset_lora_model,
use_gradient_checkpointing=args.use_gradient_checkpointing,
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
extra_inputs=args.extra_inputs,
fp8_models=args.fp8_models,
offload_models=args.offload_models,
task=args.task,
device=accelerator.device,
)
model_logger = ModelLogger(
args.output_path,
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
)
launcher_map = {
"sft:data_process": launch_data_process_task,
"sft": launch_training_task,
"sft:train": launch_training_task,
}
launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)

View File

@@ -0,0 +1,25 @@
import torch
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
from diffsynth.core import load_state_dict
pipe = ErnieImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
)
state_dict = load_state_dict("./models/train/Ernie-Image-T2I_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
image = pipe(
prompt="a professional photo of a cute dog",
seed=0,
num_inference_steps=50,
cfg_scale=4.0,
)
image.save("image_full.jpg")
print("Full validation image saved to image_full.jpg")

View File

@@ -0,0 +1,25 @@
import torch
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
from diffsynth.core.loader.file import load_state_dict
pipe = ErnieImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
)
lora_state_dict = load_state_dict("./models/train/Ernie-Image-T2I_lora/epoch-4.safetensors", torch_dtype=torch.bfloat16, device="cuda")
pipe.load_lora(pipe.dit, state_dict=lora_state_dict, alpha=1.0)
image = pipe(
prompt="a professional photo of a cute dog",
seed=0,
num_inference_steps=50,
cfg_scale=4.0,
)
image.save("image_lora.jpg")
print("LoRA validation image saved to image_lora.jpg")