diff --git a/README.md b/README.md
index 660b8de..1fd1ed0 100644
--- a/README.md
+++ b/README.md
@@ -33,6 +33,7 @@ We believe that a well-developed open-source code framework can lower the thresh
> DiffSynth-Studio has undergone major version updates, and some old features are no longer maintained. If you need to use old features, please switch to the [last historical version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3) before the major version update.
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher) and [mi804](https://github.com/mi804). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
+
- **March 19, 2026**: Added support for [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) and [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/Wan.md) and [example code](/examples/mova/) are now available.
- **March 12, 2026**: We have added support for the [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) audio-video generation model. The features includes text-to-audio/video, image-to-audio/video, IC-LoRA control, audio-to-video, and audio-video inpainting. We have supported the complete inference and training functionalities. For details, please refer to the [documentation](/docs/en/Model_Details/LTX-2.md) and [code](/examples/ltx2/).
@@ -876,6 +877,66 @@ Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/)
+#### ERNIE-Image: [/docs/en/Model_Details/ERNIE-Image.md](/docs/en/Model_Details/ERNIE-Image.md)
+
+
+
+Quick Start
+
+Running the following code will quickly load the [baidu/ERNIE-Image](https://www.modelscope.cn/models/baidu/ERNIE-Image) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 3GB VRAM.
+
+```python
+from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
+import torch
+
+vram_config = {
+ "offload_dtype": torch.bfloat16,
+ "offload_device": "cpu",
+ "onload_dtype": torch.bfloat16,
+ "onload_device": "cpu",
+ "preparing_dtype": torch.bfloat16,
+ "preparing_device": "cuda",
+ "computation_dtype": torch.bfloat16,
+ "computation_device": "cuda",
+}
+pipe = ErnieImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device='cuda',
+ model_configs=[
+ ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
+ ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
+ ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
+ ],
+ tokenizer_config=ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="tokenizer/"),
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
+)
+
+image = pipe(
+ prompt="一只黑白相间的中华田园犬",
+ negative_prompt="",
+ height=1024,
+ width=1024,
+ seed=42,
+ num_inference_steps=50,
+ cfg_scale=4.0,
+)
+image.save("output.jpg")
+```
+
+
+
+
+
+Examples
+
+Example code for ERNIE-Image is available at: [/examples/ernie_image/](/examples/ernie_image/)
+
+| Model ID | Inference | Low VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
+|-|-|-|-|-|-|-|
+|[baidu/ERNIE-Image: T2I](https://www.modelscope.cn/models/baidu/ERNIE-Image)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference_low_vram/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/full/Ernie-Image-T2I.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_full/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/lora/Ernie-Image-T2I.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_lora/Ernie-Image-T2I.py)|
+
+
+
## Innovative Achievements
DiffSynth-Studio is not just an engineered model framework, but also an incubator for innovative achievements.
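
Note on the `vram_limit` expression in the README quick start above: `torch.cuda.mem_get_info` returns a `(free, total)` pair in bytes, so index `1` is the device's total memory, and the expression leaves 0.5 GiB of headroom below it. The arithmetic can be sketched without a GPU as follows. The helper name `vram_limit_gib` and the `reserve_gib` parameter are hypothetical and are not part of DiffSynth-Studio.

```python
def vram_limit_gib(total_bytes, reserve_gib=0.5):
    """Convert a total-memory figure in bytes to a GiB limit with some headroom.

    Mirrors the README expression `total / (1024 ** 3) - 0.5`; `reserve_gib`
    is an illustrative parameter, not a framework option.
    """
    return total_bytes / (1024 ** 3) - reserve_gib


# A 24 GiB card would be given a limit of 23.5 GiB.
print(vram_limit_gib(24 * 1024 ** 3))  # 23.5
```
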
diff --git a/README_zh.md b/README_zh.md
index 92e230b..29e293a 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -877,6 +877,66 @@ Wan 的示例代码位于:[/examples/wanvideo/](/examples/wanvideo/)
+#### ERNIE-Image: [/docs/zh/Model_Details/ERNIE-Image.md](/docs/zh/Model_Details/ERNIE-Image.md)
+
+
+
+快速开始
+
+运行以下代码可以快速加载 [baidu/ERNIE-Image](https://www.modelscope.cn/models/baidu/ERNIE-Image) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 3G 显存即可运行。
+
+```python
+from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
+import torch
+
+vram_config = {
+ "offload_dtype": torch.bfloat16,
+ "offload_device": "cpu",
+ "onload_dtype": torch.bfloat16,
+ "onload_device": "cpu",
+ "preparing_dtype": torch.bfloat16,
+ "preparing_device": "cuda",
+ "computation_dtype": torch.bfloat16,
+ "computation_device": "cuda",
+}
+pipe = ErnieImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device='cuda',
+ model_configs=[
+ ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
+ ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
+ ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
+ ],
+ tokenizer_config=ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="tokenizer/"),
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
+)
+
+image = pipe(
+ prompt="一只黑白相间的中华田园犬",
+ negative_prompt="",
+ height=1024,
+ width=1024,
+ seed=42,
+ num_inference_steps=50,
+ cfg_scale=4.0,
+)
+image.save("output.jpg")
+```
+
+
+
+
+
+示例代码
+
+ERNIE-Image 的示例代码位于:[/examples/ernie_image/](/examples/ernie_image/)
+
+| 模型 ID | 推理 | 低显存推理 | 全量训练 | 全量训练后验证 | LoRA 训练 | LoRA 训练后验证 |
+|-|-|-|-|-|-|-|
+|[baidu/ERNIE-Image: T2I](https://www.modelscope.cn/models/baidu/ERNIE-Image)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference_low_vram/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/full/Ernie-Image-T2I.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_full/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/lora/Ernie-Image-T2I.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_lora/Ernie-Image-T2I.py)|
+
+
+
## 创新成果
DiffSynth-Studio 不仅仅是一个工程化的模型框架,更是创新成果的孵化器。
diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py
index 4222202..7428962 100644
--- a/diffsynth/configs/model_configs.py
+++ b/diffsynth/configs/model_configs.py
@@ -541,6 +541,22 @@ flux2_series = [
},
]
+ernie_image_series = [
+ {
+ # Example: ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
+ "model_hash": "584c13713849f1af4e03d5f1858b8b7b",
+ "model_name": "ernie_image_dit",
+ "model_class": "diffsynth.models.ernie_image_dit.ErnieImageDiT",
+ },
+ {
+ # Example: ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors")
+ "model_hash": "404ed9f40796a38dd34c1620f1920207",
+ "model_name": "ernie_image_text_encoder",
+ "model_class": "diffsynth.models.ernie_image_text_encoder.ErnieImageTextEncoder",
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.ernie_image_text_encoder.ErnieImageTextEncoderStateDictConverter",
+ },
+]
+
z_image_series = [
{
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors")
@@ -884,4 +900,4 @@ mova_series = [
"model_class": "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge",
},
]
-MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series + anima_series + mova_series
+MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series + z_image_series + ltx2_series + anima_series + mova_series
diff --git a/diffsynth/configs/vram_management_module_maps.py b/diffsynth/configs/vram_management_module_maps.py
index de27689..ba7ed77 100644
--- a/diffsynth/configs/vram_management_module_maps.py
+++ b/diffsynth/configs/vram_management_module_maps.py
@@ -267,6 +267,18 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
+ "diffsynth.models.ernie_image_dit.ErnieImageDiT": {
+ "diffsynth.models.ernie_image_dit.ErnieImageRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
+ "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
+ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
+ "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
+ },
+ "diffsynth.models.ernie_image_text_encoder.ErnieImageTextEncoder": {
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
+ "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
+ "transformers.models.ministral3.modeling_ministral3.Ministral3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
+ },
}
def QwenImageTextEncoder_Module_Map_Updater():
diff --git a/diffsynth/diffusion/flow_match.py b/diffsynth/diffusion/flow_match.py
index 208fb1e..1ac5c49 100644
--- a/diffsynth/diffusion/flow_match.py
+++ b/diffsynth/diffusion/flow_match.py
@@ -4,7 +4,7 @@ from typing_extensions import Literal
class FlowMatchScheduler():
- def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning"] = "FLUX.1"):
+ def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning", "ERNIE-Image"] = "FLUX.1"):
self.set_timesteps_fn = {
"FLUX.1": FlowMatchScheduler.set_timesteps_flux,
"Wan": FlowMatchScheduler.set_timesteps_wan,
@@ -13,6 +13,7 @@ class FlowMatchScheduler():
"Z-Image": FlowMatchScheduler.set_timesteps_z_image,
"LTX-2": FlowMatchScheduler.set_timesteps_ltx2,
"Qwen-Image-Lightning": FlowMatchScheduler.set_timesteps_qwen_image_lightning,
+ "ERNIE-Image": FlowMatchScheduler.set_timesteps_ernie_image,
}.get(template, FlowMatchScheduler.set_timesteps_flux)
self.num_train_timesteps = 1000
@@ -129,6 +130,15 @@ class FlowMatchScheduler():
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
+ @staticmethod
+    def set_timesteps_ernie_image(num_inference_steps=50, denoising_strength=1.0):
+        """ERNIE-Image scheduler: linear sigmas from `denoising_strength` (1.0 by default) toward 0.0, with the final 0.0 excluded and no shift."""
+ num_train_timesteps = 1000
+ sigma_start = denoising_strength
+ sigmas = torch.linspace(sigma_start, 0.0, num_inference_steps + 1)[:-1]
+ timesteps = sigmas * num_train_timesteps
+ return sigmas, timesteps
+
@staticmethod
def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None):
sigma_min = 0.0
@@ -175,6 +185,9 @@ class FlowMatchScheduler():
return sigmas, timesteps
def set_training_weight(self):
+ if self.set_timesteps_fn == FlowMatchScheduler.set_timesteps_ernie_image:
+ self.set_uniform_training_weight()
+ return
steps = 1000
x = self.timesteps
y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
@@ -185,6 +198,13 @@ class FlowMatchScheduler():
bsmntw_weighing = bsmntw_weighing * (len(self.timesteps) / steps)
bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1]
self.linear_timesteps_weights = bsmntw_weighing
+
+ def set_uniform_training_weight(self):
+ """Assign equal weight to every timestep, suitable for linear schedulers like ERNIE-Image."""
+ steps = 1000
+ num_steps = len(self.timesteps)
+ uniform_weight = torch.full((num_steps,), steps / num_steps, dtype=self.timesteps.dtype)
+ self.linear_timesteps_weights = uniform_weight
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs):
self.sigmas, self.timesteps = self.set_timesteps_fn(
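
Note on the ERNIE-Image scheduler changes above: the schedule is a plain linear ramp, and the training weights are constant. A minimal pure-Python sketch of both, assuming list arithmetic as a stand-in for `torch.linspace` and `torch.full`; the function names `linear_sigmas` and `uniform_weights` are illustrative and do not exist in the patch.

```python
def linear_sigmas(num_inference_steps=50, denoising_strength=1.0, num_train_timesteps=1000):
    """Stand-in for torch.linspace(denoising_strength, 0.0, n + 1)[:-1]."""
    n = num_inference_steps
    # n + 1 evenly spaced points from denoising_strength down to 0.0,
    # with the trailing 0.0 dropped, leaves n sigmas.
    sigmas = [denoising_strength * (1 - i / n) for i in range(n)]
    timesteps = [s * num_train_timesteps for s in sigmas]
    return sigmas, timesteps


def uniform_weights(num_steps, steps=1000):
    """Every timestep gets the same weight, steps / num_steps."""
    return [steps / num_steps] * num_steps


sigmas, timesteps = linear_sigmas(num_inference_steps=4)
print(sigmas)              # [1.0, 0.75, 0.5, 0.25]
print(timesteps)           # [1000.0, 750.0, 500.0, 250.0]
print(uniform_weights(4))  # [250.0, 250.0, 250.0, 250.0]
```

With 1000 scheduled steps, the same formula gives a weight of 1.0 for every timestep.
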
diff --git a/diffsynth/models/ernie_image_dit.py b/diffsynth/models/ernie_image_dit.py
new file mode 100644
index 0000000..fd0e022
--- /dev/null
+++ b/diffsynth/models/ernie_image_dit.py
@@ -0,0 +1,362 @@
+"""
+Ernie-Image DiT for DiffSynth-Studio.
+
+Refactored from diffusers ErnieImageTransformer2DModel to use DiffSynth core modules.
+Default parameters follow the checkpoint's config.json (baidu/ERNIE-Image transformer).
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Optional, Tuple
+
+from ..core.attention import attention_forward
+from ..core.gradient import gradient_checkpoint_forward
+from .flux2_dit import Timesteps, TimestepEmbedding
+
+
+def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
+ assert dim % 2 == 0
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+ omega = 1.0 / (theta ** scale)
+ out = torch.einsum("...n,d->...nd", pos, omega)
+ return out.float()
+
+
+class ErnieImageEmbedND3(nn.Module):
+ def __init__(self, dim: int, theta: int, axes_dim: Tuple[int, int, int]):
+ super().__init__()
+ self.dim = dim
+ self.theta = theta
+ self.axes_dim = list(axes_dim)
+
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
+ emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(3)], dim=-1)
+ emb = emb.unsqueeze(2)
+ return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1)
+
+
+class ErnieImagePatchEmbedDynamic(nn.Module):
+ def __init__(self, in_channels: int, embed_dim: int, patch_size: int):
+ super().__init__()
+ self.patch_size = patch_size
+ self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.proj(x)
+ batch_size, dim, height, width = x.shape
+ return x.reshape(batch_size, dim, height * width).transpose(1, 2).contiguous()
+
+
+class ErnieImageSingleStreamAttnProcessor:
+ def __call__(
+ self,
+ attn: "ErnieImageAttention",
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ freqs_cis: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ query = query.unflatten(-1, (attn.heads, -1))
+ key = key.unflatten(-1, (attn.heads, -1))
+ value = value.unflatten(-1, (attn.heads, -1))
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
+ rot_dim = freqs_cis.shape[-1]
+ x, x_pass = x_in[..., :rot_dim], x_in[..., rot_dim:]
+ cos_ = torch.cos(freqs_cis).to(x.dtype)
+ sin_ = torch.sin(freqs_cis).to(x.dtype)
+ x1, x2 = x.chunk(2, dim=-1)
+ x_rotated = torch.cat((-x2, x1), dim=-1)
+ return torch.cat((x * cos_ + x_rotated * sin_, x_pass), dim=-1)
+
+ if freqs_cis is not None:
+ query = apply_rotary_emb(query, freqs_cis)
+ key = apply_rotary_emb(key, freqs_cis)
+
+ if attention_mask is not None and attention_mask.ndim == 2:
+ attention_mask = attention_mask[:, None, None, :]
+
+ hidden_states = attention_forward(
+ query, key, value,
+ q_pattern="b s n d",
+ k_pattern="b s n d",
+ v_pattern="b s n d",
+ out_pattern="b s n d",
+ attn_mask=attention_mask,
+ )
+
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+ output = attn.to_out[0](hidden_states)
+
+ return output
+
+
+class ErnieImageAttention(nn.Module):
+ def __init__(
+ self,
+ query_dim: int,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias: bool = False,
+        qk_norm: Optional[str] = "rms_norm",
+ out_bias: bool = True,
+ eps: float = 1e-5,
+ out_dim: int = None,
+ elementwise_affine: bool = True,
+ ):
+ super().__init__()
+
+ self.head_dim = dim_head
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+ self.query_dim = query_dim
+ self.out_dim = out_dim if out_dim is not None else query_dim
+ self.heads = out_dim // dim_head if out_dim is not None else heads
+
+ self.use_bias = bias
+ self.dropout = dropout
+
+ self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
+ self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
+ self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+        if qk_norm is None:
+            # No QK normalization; the attention processor skips norm_q/norm_k when they are None.
+            self.norm_q = None
+            self.norm_k = None
+        elif qk_norm == "layer_norm":
+            self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+            self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+        elif qk_norm == "rms_norm":
+            self.norm_q = nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+            self.norm_k = nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+        else:
+            raise ValueError(
+                f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'rms_norm'."
+            )
+
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+
+ self.processor = ErnieImageSingleStreamAttnProcessor()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ return self.processor(self, hidden_states, attention_mask, image_rotary_emb)
+
+
+class ErnieImageFeedForward(nn.Module):
+ def __init__(self, hidden_size: int, ffn_hidden_size: int):
+ super().__init__()
+ self.gate_proj = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
+ self.up_proj = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
+ self.linear_fc2 = nn.Linear(ffn_hidden_size, hidden_size, bias=False)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.linear_fc2(self.up_proj(x) * F.gelu(self.gate_proj(x)))
+
+
+class ErnieImageRMSNorm(nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-6):
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ input_dtype = hidden_states.dtype
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+ hidden_states = hidden_states * self.weight
+ return hidden_states.to(input_dtype)
+
+
+class ErnieImageSharedAdaLNBlock(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ num_heads: int,
+ ffn_hidden_size: int,
+ eps: float = 1e-6,
+ qk_layernorm: bool = True,
+ ):
+ super().__init__()
+ self.adaLN_sa_ln = ErnieImageRMSNorm(hidden_size, eps=eps)
+ self.self_attention = ErnieImageAttention(
+ query_dim=hidden_size,
+ dim_head=hidden_size // num_heads,
+ heads=num_heads,
+ qk_norm="rms_norm" if qk_layernorm else None,
+ eps=eps,
+ bias=False,
+ out_bias=False,
+ )
+ self.adaLN_mlp_ln = ErnieImageRMSNorm(hidden_size, eps=eps)
+ self.mlp = ErnieImageFeedForward(hidden_size, ffn_hidden_size)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ rotary_pos_emb: torch.Tensor,
+ temb: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = temb
+ residual = x
+ x = self.adaLN_sa_ln(x)
+ x = (x.float() * (1 + scale_msa.float()) + shift_msa.float()).to(x.dtype)
+ x_bsh = x.permute(1, 0, 2)
+ attn_out = self.self_attention(x_bsh, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb)
+ attn_out = attn_out.permute(1, 0, 2)
+ x = residual + (gate_msa.float() * attn_out.float()).to(x.dtype)
+ residual = x
+ x = self.adaLN_mlp_ln(x)
+ x = (x.float() * (1 + scale_mlp.float()) + shift_mlp.float()).to(x.dtype)
+ return residual + (gate_mlp.float() * self.mlp(x).float()).to(x.dtype)
+
+
+class ErnieImageAdaLNContinuous(nn.Module):
+ def __init__(self, hidden_size: int, eps: float = 1e-6):
+ super().__init__()
+ self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=eps)
+ self.linear = nn.Linear(hidden_size, hidden_size * 2)
+
+ def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
+ scale, shift = self.linear(conditioning).chunk(2, dim=-1)
+ x = self.norm(x)
+ x = x * (1 + scale.unsqueeze(0)) + shift.unsqueeze(0)
+ return x
+
+
+class ErnieImageDiT(nn.Module):
+ """
+ Ernie-Image DiT model for DiffSynth-Studio.
+
+ Architecture: SharedAdaLN + RoPE 3D + Joint Image-Text Attention.
+ Internal format: [S, B, H] for transformer blocks, [B, S, H] for attention.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int = 4096,
+ num_attention_heads: int = 32,
+ num_layers: int = 36,
+ ffn_hidden_size: int = 12288,
+ in_channels: int = 128,
+ out_channels: int = 128,
+ patch_size: int = 1,
+ text_in_dim: int = 3072,
+ rope_theta: int = 256,
+ rope_axes_dim: Tuple[int, int, int] = (32, 48, 48),
+ eps: float = 1e-6,
+ qk_layernorm: bool = True,
+ ):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.num_heads = num_attention_heads
+ self.head_dim = hidden_size // num_attention_heads
+ self.num_layers = num_layers
+ self.patch_size = patch_size
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.text_in_dim = text_in_dim
+
+ self.x_embedder = ErnieImagePatchEmbedDynamic(in_channels, hidden_size, patch_size)
+ self.text_proj = nn.Linear(text_in_dim, hidden_size, bias=False) if text_in_dim != hidden_size else None
+ self.time_proj = Timesteps(hidden_size, flip_sin_to_cos=False, downscale_freq_shift=0)
+ self.time_embedding = TimestepEmbedding(hidden_size, hidden_size)
+ self.pos_embed = ErnieImageEmbedND3(dim=self.head_dim, theta=rope_theta, axes_dim=rope_axes_dim)
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size))
+ nn.init.zeros_(self.adaLN_modulation[-1].weight)
+ nn.init.zeros_(self.adaLN_modulation[-1].bias)
+ self.layers = nn.ModuleList([
+ ErnieImageSharedAdaLNBlock(hidden_size, num_attention_heads, ffn_hidden_size, eps, qk_layernorm=qk_layernorm)
+ for _ in range(num_layers)
+ ])
+ self.final_norm = ErnieImageAdaLNContinuous(hidden_size, eps)
+ self.final_linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels)
+ nn.init.zeros_(self.final_linear.weight)
+ nn.init.zeros_(self.final_linear.bias)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.Tensor,
+ text_bth: torch.Tensor,
+ text_lens: torch.Tensor,
+ use_gradient_checkpointing: bool = False,
+ use_gradient_checkpointing_offload: bool = False,
+ ) -> torch.Tensor:
+ device, dtype = hidden_states.device, hidden_states.dtype
+ B, C, H, W = hidden_states.shape
+ p, Hp, Wp = self.patch_size, H // self.patch_size, W // self.patch_size
+ N_img = Hp * Wp
+
+ img_sbh = self.x_embedder(hidden_states).transpose(0, 1).contiguous()
+
+ if self.text_proj is not None and text_bth.numel() > 0:
+ text_bth = self.text_proj(text_bth)
+ Tmax = text_bth.shape[1]
+ text_sbh = text_bth.transpose(0, 1).contiguous()
+
+ x = torch.cat([img_sbh, text_sbh], dim=0)
+ S = x.shape[0]
+
+ text_ids = torch.cat([
+ torch.arange(Tmax, device=device, dtype=torch.float32).view(1, Tmax, 1).expand(B, -1, -1),
+ torch.zeros((B, Tmax, 2), device=device)
+ ], dim=-1) if Tmax > 0 else torch.zeros((B, 0, 3), device=device)
+ grid_yx = torch.stack(
+ torch.meshgrid(torch.arange(Hp, device=device, dtype=torch.float32),
+ torch.arange(Wp, device=device, dtype=torch.float32), indexing="ij"),
+ dim=-1
+ ).reshape(-1, 2)
+ image_ids = torch.cat([
+ text_lens.float().view(B, 1, 1).expand(-1, N_img, -1),
+ grid_yx.view(1, N_img, 2).expand(B, -1, -1)
+ ], dim=-1)
+ rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1))
+
+ valid_text = torch.arange(Tmax, device=device).view(1, Tmax) < text_lens.view(B, 1) if Tmax > 0 else torch.zeros((B, 0), device=device, dtype=torch.bool)
+ attention_mask = torch.cat([
+ torch.ones((B, N_img), device=device, dtype=torch.bool),
+ valid_text
+ ], dim=1)[:, None, None, :]
+
+ sample = self.time_proj(timestep.to(dtype))
+ sample = sample.to(self.time_embedding.linear_1.weight.dtype)
+ c = self.time_embedding(sample)
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
+ t.unsqueeze(0).expand(S, -1, -1).contiguous()
+ for t in self.adaLN_modulation(c).chunk(6, dim=-1)
+ ]
+
+ for layer in self.layers:
+ temb = [shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp]
+ if torch.is_grad_enabled() and use_gradient_checkpointing:
+ x = gradient_checkpoint_forward(
+ layer,
+ use_gradient_checkpointing,
+ use_gradient_checkpointing_offload,
+ x,
+ rotary_pos_emb,
+ temb,
+ attention_mask,
+ )
+ else:
+ x = layer(x, rotary_pos_emb, temb, attention_mask)
+
+ x = self.final_norm(x, c).type_as(x)
+ patches = self.final_linear(x)[:N_img].transpose(0, 1).contiguous()
+ output = patches.view(B, Hp, Wp, p, p, self.out_channels).permute(0, 5, 1, 3, 2, 4).contiguous().view(B, self.out_channels, H, W)
+
+ return output
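
Note on `ErnieImageDiT.forward` above: image tokens come first in the joint sequence and text tokens follow. Each image token gets the position ID `(text_len, row, col)`, each text token gets `(index, 0, 0)`, and the attention mask hides text positions at or beyond the sample's real length. A minimal pure-Python sketch of that layout, assuming nested lists in place of tensors; the helper name `build_ids_and_mask` is illustrative and does not exist in the patch.

```python
def build_ids_and_mask(text_lens, t_max, hp, wp):
    """Return per-sample position IDs and key masks, image tokens first."""
    all_ids, all_masks = [], []
    for text_len in text_lens:
        # Image tokens: first axis is the sample's text length, then (row, col)
        # in row-major order, matching meshgrid(..., indexing="ij").
        image_ids = [(float(text_len), float(y), float(x))
                     for y in range(hp) for x in range(wp)]
        # Text tokens: first axis is the token index, spatial axes stay 0.
        text_ids = [(float(t), 0.0, 0.0) for t in range(t_max)]
        # Image tokens are always visible; padded text tokens are masked out.
        mask = [True] * (hp * wp) + [t < text_len for t in range(t_max)]
        all_ids.append(image_ids + text_ids)
        all_masks.append(mask)
    return all_ids, all_masks


ids, mask = build_ids_and_mask(text_lens=[2, 3], t_max=3, hp=2, wp=2)
print(ids[0][:4])  # [(2.0, 0.0, 0.0), (2.0, 0.0, 1.0), (2.0, 1.0, 0.0), (2.0, 1.0, 1.0)]
print(ids[0][4:])  # [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (2.0, 0.0, 0.0)]
print(mask[0])     # [True, True, True, True, True, True, False]
```
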
diff --git a/diffsynth/models/ernie_image_text_encoder.py b/diffsynth/models/ernie_image_text_encoder.py
new file mode 100644
index 0000000..17460b2
--- /dev/null
+++ b/diffsynth/models/ernie_image_text_encoder.py
@@ -0,0 +1,76 @@
+"""
+Ernie-Image TextEncoder for DiffSynth-Studio.
+
+Wraps transformers Ministral3Model to output text embeddings.
+Pattern: lazy import + manual config dict + torch.nn.Module wrapper.
+Only loads the text (language) model, ignoring vision components.
+"""
+
+import torch
+
+
+class ErnieImageTextEncoder(torch.nn.Module):
+ """
+    Text encoder built on the transformers Ministral3Model.
+    Loads only the text_config portion of the full Mistral3Model checkpoint.
+ Uses the base model (no lm_head) since the checkpoint only has embeddings.
+ """
+
+ def __init__(self):
+ super().__init__()
+ from transformers import Ministral3Config, Ministral3Model
+
+ text_config = {
+ "attention_dropout": 0.0,
+ "bos_token_id": 1,
+ "dtype": "bfloat16",
+ "eos_token_id": 2,
+ "head_dim": 128,
+ "hidden_act": "silu",
+ "hidden_size": 3072,
+ "initializer_range": 0.02,
+ "intermediate_size": 9216,
+ "max_position_embeddings": 262144,
+ "model_type": "ministral3",
+ "num_attention_heads": 32,
+ "num_hidden_layers": 26,
+ "num_key_value_heads": 8,
+ "pad_token_id": 11,
+ "rms_norm_eps": 1e-05,
+ "rope_parameters": {
+ "beta_fast": 32.0,
+ "beta_slow": 1.0,
+ "factor": 16.0,
+ "llama_4_scaling_beta": 0.1,
+ "mscale": 1.0,
+ "mscale_all_dim": 1.0,
+ "original_max_position_embeddings": 16384,
+ "rope_theta": 1000000.0,
+ "rope_type": "yarn",
+ "type": "yarn",
+ },
+ "sliding_window": None,
+ "tie_word_embeddings": True,
+ "use_cache": True,
+ "vocab_size": 131072,
+ }
+ config = Ministral3Config(**text_config)
+ self.model = Ministral3Model(config)
+ self.config = config
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ **kwargs,
+ ):
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_hidden_states=True,
+ return_dict=True,
+ **kwargs,
+ )
+ return (outputs.hidden_states,)
diff --git a/diffsynth/pipelines/ernie_image.py b/diffsynth/pipelines/ernie_image.py
new file mode 100644
index 0000000..a2b411d
--- /dev/null
+++ b/diffsynth/pipelines/ernie_image.py
@@ -0,0 +1,265 @@
+"""
+ERNIE-Image Text-to-Image Pipeline for DiffSynth-Studio.
+
+Architecture: SharedAdaLN DiT + RoPE 3D + Joint Image-Text Attention.
+"""
+
+import torch
+from typing import Union, Optional
+from tqdm import tqdm
+from transformers import AutoTokenizer
+
+from ..core.device.npu_compatible_device import get_device_type
+from ..diffusion import FlowMatchScheduler
+from ..core import ModelConfig
+from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
+from ..models.ernie_image_text_encoder import ErnieImageTextEncoder
+from ..models.ernie_image_dit import ErnieImageDiT
+from ..models.flux2_vae import Flux2VAE
+
+
+class ErnieImagePipeline(BasePipeline):
+
+ def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
+ super().__init__(
+ device=device, torch_dtype=torch_dtype,
+ height_division_factor=16, width_division_factor=16,
+ )
+ self.scheduler = FlowMatchScheduler("ERNIE-Image")
+ self.text_encoder: ErnieImageTextEncoder = None
+ self.dit: ErnieImageDiT = None
+ self.vae: Flux2VAE = None
+ self.tokenizer: AutoTokenizer = None
+
+ self.in_iteration_models = ("dit",)
+ self.units = [
+ ErnieImageUnit_ShapeChecker(),
+ ErnieImageUnit_PromptEmbedder(),
+ ErnieImageUnit_NoiseInitializer(),
+ ErnieImageUnit_InputImageEmbedder(),
+ ]
+ self.model_fn = model_fn_ernie_image
+ self.compilable_models = ["dit"]
+
+ @staticmethod
+ def from_pretrained(
+ torch_dtype: torch.dtype = torch.bfloat16,
+ device: Union[str, torch.device] = get_device_type(),
+ model_configs: list[ModelConfig] = [],
+ tokenizer_config: ModelConfig = ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="tokenizer/"),
+ vram_limit: float = None,
+ ):
+ pipe = ErnieImagePipeline(device=device, torch_dtype=torch_dtype)
+ model_pool = pipe.download_and_load_models(model_configs, vram_limit)
+
+        pipe.text_encoder = model_pool.fetch_model("ernie_image_text_encoder")
+        pipe.dit = model_pool.fetch_model("ernie_image_dit")
+        pipe.vae = model_pool.fetch_model("flux2_vae")
+
+        if tokenizer_config is not None:
+            tokenizer_config.download_if_necessary()
+            pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
+
+        pipe.vram_management_enabled = pipe.check_vram_management_state()
+        return pipe
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        # Prompt
+        prompt: str,
+        negative_prompt: str = "",
+        cfg_scale: float = 4.0,
+        # Shape
+        height: int = 1024,
+        width: int = 1024,
+        # Randomness
+        seed: int = None,
+        rand_device: str = "cuda",
+        # Steps
+        num_inference_steps: int = 50,
+        # Progress bar
+        progress_bar_cmd=tqdm,
+    ):
+        # Scheduler
+        self.scheduler.set_timesteps(num_inference_steps=num_inference_steps)
+
+        # Parameters
+        inputs_posi = {"prompt": prompt}
+        inputs_nega = {"negative_prompt": negative_prompt}
+        inputs_shared = {
+            "height": height, "width": width, "seed": seed,
+            "cfg_scale": cfg_scale, "num_inference_steps": num_inference_steps,
+            "rand_device": rand_device,
+        }
+        for unit in self.units:
+            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
+
+        # Denoise
+        self.load_models_to_device(self.in_iteration_models)
+        models = {name: getattr(self, name) for name in self.in_iteration_models}
+        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
+            noise_pred = self.cfg_guided_model_fn(
+                self.model_fn, cfg_scale,
+                inputs_shared, inputs_posi, inputs_nega,
+                **models, timestep=timestep, progress_id=progress_id
+            )
+            inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
+
+        # Decode
+        self.load_models_to_device(['vae'])
+        latents = inputs_shared["latents"]
+        image = self.vae.decode(latents)
+        image = self.vae_output_to_image(image)
+        self.load_models_to_device([])
+        return image
+
+
+class ErnieImageUnit_ShapeChecker(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("height", "width"),
+            output_params=("height", "width"),
+        )
+
+    def process(self, pipe: ErnieImagePipeline, height, width):
+        height, width = pipe.check_resize_height_width(height, width)
+        return {"height": height, "width": width}
+
+
+class ErnieImageUnit_PromptEmbedder(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            seperate_cfg=True,
+            input_params_posi={"prompt": "prompt"},
+            input_params_nega={"prompt": "negative_prompt"},
+            output_params=("prompt_embeds", "prompt_embeds_mask"),
+            onload_model_names=("text_encoder",)
+        )
+
+    def encode_prompt(self, pipe: ErnieImagePipeline, prompt):
+        if isinstance(prompt, str):
+            prompt = [prompt]
+
+        text_hiddens = []
+        text_lens_list = []
+        for p in prompt:
+            ids = pipe.tokenizer(
+                p,
+                add_special_tokens=True,
+                truncation=True,
+                padding=False,
+            )["input_ids"]
+
+            if len(ids) == 0:
+                if pipe.tokenizer.bos_token_id is not None:
+                    ids = [pipe.tokenizer.bos_token_id]
+                else:
+                    ids = [0]
+
+            input_ids = torch.tensor([ids], device=pipe.device)
+            outputs = pipe.text_encoder(
+                input_ids=input_ids,
+            )
+            # Text encoder returns tuple of (hidden_states_tuple,) where each layer's hidden state is included
+            all_hidden_states = outputs[0]
+            hidden = all_hidden_states[-2][0]  # [T, H] - second to last layer
+            text_hiddens.append(hidden)
+            text_lens_list.append(hidden.shape[0])
+
+        # Pad to uniform length
+        if len(text_hiddens) == 0:
+            text_in_dim = pipe.text_encoder.config.hidden_size if hasattr(pipe.text_encoder, 'config') else 3072
+            return {
+                "prompt_embeds": torch.zeros((0, 0, text_in_dim), device=pipe.device, dtype=pipe.torch_dtype),
+                "prompt_embeds_mask": torch.zeros((0,), device=pipe.device, dtype=torch.long),
+            }
+
+        normalized = [th.to(pipe.device).to(pipe.torch_dtype) for th in text_hiddens]
+        text_lens = torch.tensor([t.shape[0] for t in normalized], device=pipe.device, dtype=torch.long)
+        Tmax = int(text_lens.max().item())
+        text_in_dim = normalized[0].shape[1]
+        text_bth = torch.zeros((len(normalized), Tmax, text_in_dim), device=pipe.device, dtype=pipe.torch_dtype)
+        for i, t in enumerate(normalized):
+            text_bth[i, :t.shape[0], :] = t
+
+        return {"prompt_embeds": text_bth, "prompt_embeds_mask": text_lens}
+
+    def process(self, pipe: ErnieImagePipeline, prompt):
+        pipe.load_models_to_device(self.onload_model_names)
+        if pipe.text_encoder is not None:
+            return self.encode_prompt(pipe, prompt)
+        return {}
+
+
+class ErnieImageUnit_NoiseInitializer(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("height", "width", "seed", "rand_device"),
+            output_params=("noise",),
+        )
+
+    def process(self, pipe: ErnieImagePipeline, height, width, seed, rand_device):
+        latent_h = height // pipe.height_division_factor
+        latent_w = width // pipe.width_division_factor
+        latent_channels = pipe.dit.in_channels
+
+        # Use pipeline device if rand_device is not specified
+        if rand_device is None:
+            rand_device = str(pipe.device)
+
+        noise = pipe.generate_noise(
+            (1, latent_channels, latent_h, latent_w),
+            seed=seed,
+            rand_device=rand_device,
+            rand_torch_dtype=pipe.torch_dtype,
+        )
+        return {"noise": noise}
+
+
+class ErnieImageUnit_InputImageEmbedder(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("input_image", "noise"),
+            output_params=("latents", "input_latents"),
+            onload_model_names=("vae",)
+        )
+
+    def process(self, pipe: ErnieImagePipeline, input_image, noise):
+        if input_image is None:
+            # T2I path: use noise directly as initial latents
+            return {"latents": noise, "input_latents": None}
+
+        # I2I path: VAE encode input image
+        pipe.load_models_to_device(['vae'])
+        image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
+        input_latents = pipe.vae.encode(image)
+
+        if pipe.scheduler.training:
+            return {"latents": noise, "input_latents": input_latents}
+        else:
+            # In inference mode, add noise to encoded latents
+            latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
+            return {"latents": latents}
+
+
+def model_fn_ernie_image(
+    dit: ErnieImageDiT,
+    latents=None,
+    timestep=None,
+    prompt_embeds=None,
+    prompt_embeds_mask=None,
+    use_gradient_checkpointing=False,
+    use_gradient_checkpointing_offload=False,
+    **kwargs,
+):
+    output = dit(
+        hidden_states=latents,
+        timestep=timestep,
+        text_bth=prompt_embeds,
+        text_lens=prompt_embeds_mask,
+        use_gradient_checkpointing=use_gradient_checkpointing,
+        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
+    )
+    return output
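
The padding step in `ErnieImageUnit_PromptEmbedder.encode_prompt` above can be illustrated without the framework. The following is a minimal sketch, assuming plain nested lists stand in for tensors; `pad_embeddings` and `lengths_to_mask` are hypothetical helpers, not part of DiffSynth-Studio. It also makes visible that the value returned under `prompt_embeds_mask` holds per-prompt token counts (it is passed on as `text_lens`), from which a boolean mask can be derived if needed.

```python
from typing import List, Tuple

Vector = List[float]


def pad_embeddings(seqs: List[List[Vector]]) -> Tuple[List[List[Vector]], List[int]]:
    """Zero-pad variable-length [T, H] sequences to a common [B, Tmax, H] shape."""
    if not seqs:
        return [], []
    lengths = [len(s) for s in seqs]
    t_max = max(lengths)
    # Hidden size is taken from the first non-empty sequence (0 if all are empty).
    hidden = next((len(s[0]) for s in seqs if s), 0)
    padded = []
    for s in seqs:
        pad_rows = [[0.0] * hidden for _ in range(t_max - len(s))]
        padded.append([list(row) for row in s] + pad_rows)
    return padded, lengths


def lengths_to_mask(lengths: List[int]) -> List[List[bool]]:
    """Derive a boolean [B, Tmax] mask from per-sequence lengths."""
    t_max = max(lengths, default=0)
    return [[t < n for t in range(t_max)] for n in lengths]


# Two hypothetical "prompts" with 3 and 1 tokens, hidden size 2.
batch = [
    [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]],
    [[4.0, 4.0]],
]
padded, lengths = pad_embeddings(batch)
mask = lengths_to_mask(lengths)
print(lengths, mask)
```

Returning lengths instead of a dense mask keeps the unit's output small; the consumer can rebuild the mask at whatever sequence length it uses internally.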
diff --git a/diffsynth/utils/state_dict_converters/ernie_image_text_encoder.py b/diffsynth/utils/state_dict_converters/ernie_image_text_encoder.py
new file mode 100644
index 0000000..f444d76
--- /dev/null
+++ b/diffsynth/utils/state_dict_converters/ernie_image_text_encoder.py
@@ -0,0 +1,21 @@
+def ErnieImageTextEncoderStateDictConverter(state_dict):
+    """
+    Maps checkpoint keys from multimodal Mistral3Model format
+    to text-only Ministral3Model format.
+
+    Checkpoint keys (Mistral3Model):
+        language_model.model.layers.0.input_layernorm.weight
+        language_model.model.norm.weight
+
+    Model keys (ErnieImageTextEncoder → self.model = Ministral3Model):
+        model.layers.0.input_layernorm.weight
+        model.norm.weight
+
+    Mapping: language_model.model. → model. (keys without this prefix are dropped)
+    """
+    new_state_dict = {}
+    for key in state_dict:
+        if key.startswith("language_model.model."):
+            new_key = key.replace("language_model.model.", "model.", 1)
+            new_state_dict[new_key] = state_dict[key]
+    return new_state_dict
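
The effect of this converter on a checkpoint can be seen on a toy dictionary. The sketch below is an equivalent restatement for illustration only: `remap_text_encoder_keys` is a hypothetical name, the values are placeholder strings rather than tensors, and the `vision_tower.*` key is an invented example of a non-text entry.

```python
def remap_text_encoder_keys(state_dict: dict) -> dict:
    """Keep only `language_model.model.*` entries and rename them to `model.*`."""
    prefix = "language_model.model."
    return {
        "model." + key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# Toy checkpoint: values are placeholders instead of tensors.
checkpoint = {
    "language_model.model.layers.0.input_layernorm.weight": "w0",
    "language_model.model.norm.weight": "w1",
    "vision_tower.patch_conv.weight": "w2",  # hypothetical non-text key
}
converted = remap_text_encoder_keys(checkpoint)
print(sorted(converted))
```

Because entries that lack the prefix are silently skipped, a checkpoint in an unexpected layout yields an empty result rather than an error, which is worth checking when loading fails.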
diff --git a/docs/en/Model_Details/ERNIE-Image.md b/docs/en/Model_Details/ERNIE-Image.md
new file mode 100644
index 0000000..601b26c
--- /dev/null
+++ b/docs/en/Model_Details/ERNIE-Image.md
@@ -0,0 +1,133 @@
+# ERNIE-Image
+
+ERNIE-Image is a powerful image generation model with 8B parameters developed by Baidu, featuring a compact and efficient architecture with strong instruction-following capability. Based on an 8B DiT backbone, it delivers performance comparable to larger (20B+) models in certain scenarios while maintaining parameter efficiency. It offers reliable performance in instruction understanding and execution, text generation (English/Chinese/Japanese), and overall stability.
+
+## Installation
+
+Before performing model inference and training, please install DiffSynth-Studio first.
+
+```shell
+git clone https://github.com/modelscope/DiffSynth-Studio.git
+cd DiffSynth-Studio
+pip install -e .
+```
+
+For more information on installation, please refer to [Setup Dependencies](../Pipeline_Usage/Setup.md).
+
+## Quick Start
+
+Running the following code will load the [baidu/ERNIE-Image](https://www.modelscope.cn/models/baidu/ERNIE-Image) model and run inference. VRAM management is enabled: the framework automatically controls parameter loading based on available VRAM, and the model can run with as little as 3 GB of VRAM.
+
+```python
+from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
+import torch
+
+vram_config = {
+    "offload_dtype": torch.bfloat16,
+    "offload_device": "cpu",
+    "onload_dtype": torch.bfloat16,
+    "onload_device": "cpu",
+    "preparing_dtype": torch.bfloat16,
+    "preparing_device": "cuda",
+    "computation_dtype": torch.bfloat16,
+    "computation_device": "cuda",
+}
+pipe = ErnieImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device='cuda',
+    model_configs=[
+        ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
+        ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
+        ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
+    ],
+    tokenizer_config=ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="tokenizer/"),
+    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
+)
+
+image = pipe(
+    prompt="一只黑白相间的中华田园犬",
+    negative_prompt="",
+    height=1024,
+    width=1024,
+    seed=42,
+    num_inference_steps=50,
+    cfg_scale=4.0,
+)
+image.save("output.jpg")
+```
+
+## Model Overview
+
+|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
+|-|-|-|-|-|-|-|
+|[baidu/ERNIE-Image: T2I](https://www.modelscope.cn/models/baidu/ERNIE-Image)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference_low_vram/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/full/Ernie-Image-T2I.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_full/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/lora/Ernie-Image-T2I.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_lora/Ernie-Image-T2I.py)|
+
+## Model Inference
+
+The model is loaded via `ErnieImagePipeline.from_pretrained`; see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models) for details.
+
+The input parameters for `ErnieImagePipeline` inference include:
+
+* `prompt`: The prompt describing the content to appear in the image.
+* `negative_prompt`: The negative prompt describing what should not appear in the image, default value is `""`.
+* `cfg_scale`: Classifier-free guidance parameter, default value is 4.0.
+* `height`: Image height, must be a multiple of 16, default value is 1024.
+* `width`: Image width, must be a multiple of 16, default value is 1024.
+* `seed`: Random seed. Default is `None`, meaning completely random.
+* `rand_device`: The computing device for generating random Gaussian noise matrices, default is `"cuda"`. When set to `cuda`, different GPUs will produce different results.
+* `num_inference_steps`: Number of inference steps, default value is 50.
+
+If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). We provide recommended low-VRAM configurations for each model in the "Model Overview" table above.
+
+## Model Training
+
+ERNIE-Image series models are trained uniformly via [`examples/ernie_image/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/train.py). The script parameters include:
+
+* General Training Parameters
+    * Dataset Configuration
+        * `--dataset_base_path`: Root directory of the dataset.
+        * `--dataset_metadata_path`: Path to the dataset metadata file.
+        * `--dataset_repeat`: Number of dataset repeats per epoch.
+        * `--dataset_num_workers`: Number of processes per DataLoader.
+        * `--data_file_keys`: Field names to load from metadata, typically paths to image or video files, separated by `,`.
+    * Model Loading Configuration
+        * `--model_paths`: Paths to load models from, in JSON format.
+        * `--model_id_with_origin_paths`: Model IDs with original paths, e.g., `"baidu/ERNIE-Image:transformer/diffusion_pytorch_model*.safetensors"`, separated by commas.
+        * `--extra_inputs`: Additional input parameters required by the model Pipeline, separated by `,`.
+        * `--fp8_models`: Models to load in FP8 format. Currently only supported for models whose parameters are not updated by gradients.
+    * Basic Training Configuration
+        * `--learning_rate`: Learning rate.
+        * `--num_epochs`: Number of epochs.
+        * `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.
+        * `--find_unused_parameters`: Whether unused parameters exist in DDP training.
+        * `--weight_decay`: Weight decay magnitude.
+        * `--task`: Training task, defaults to `sft`.
+    * Output Configuration
+        * `--output_path`: Path to save the model.
+        * `--remove_prefix_in_ckpt`: Remove prefix in the model's state dict.
+        * `--save_steps`: Interval in training steps to save the model.
+    * LoRA Configuration
+        * `--lora_base_model`: Which model to add LoRA to.
+        * `--lora_target_modules`: Which layers to add LoRA to.
+        * `--lora_rank`: Rank of LoRA.
+        * `--lora_checkpoint`: Path to LoRA checkpoint.
+        * `--preset_lora_path`: Path to preset LoRA checkpoint for LoRA differential training.
+        * `--preset_lora_model`: Which model to integrate preset LoRA into, e.g., `dit`.
+    * Gradient Configuration
+        * `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
+        * `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
+        * `--gradient_accumulation_steps`: Number of gradient accumulation steps.
+    * Resolution Configuration
+        * `--height`: Height of the image. Leave empty to enable dynamic resolution.
+        * `--width`: Width of the image. Leave empty to enable dynamic resolution.
+        * `--max_pixels`: Maximum pixel area. With dynamic resolution, images larger than this are scaled down.
+* ERNIE-Image Specific Parameters
+    * `--tokenizer_path`: Path to the tokenizer. Leave empty to download it automatically from the remote repository.
+
+We provide an example image dataset for testing, which can be downloaded with the following command:
+
+```shell
+modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
+```
+
+We provide recommended training scripts for each model; please refer to the table in "Model Overview" above. For guidance on writing model training scripts, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see [Training Framework Overview](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).
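
The interaction between `--max_pixels` and the multiple-of-16 requirement on `height` and `width` can be sketched as follows. This is a minimal illustration under stated assumptions: `fit_resolution` is a hypothetical helper, and the scale-then-round-down rule is an assumption about how dynamic resolution behaves, not a copy of the framework's implementation. The default of 1048576 pixels mirrors the value used by the example training scripts in this change.

```python
import math
from typing import Tuple


def fit_resolution(width: int, height: int, max_pixels: int = 1048576, multiple: int = 16) -> Tuple[int, int]:
    """Shrink (width, height) to at most max_pixels, then round down to a multiple."""
    if width * height > max_pixels:
        # Uniform scale factor that brings the area down to max_pixels.
        scale = math.sqrt(max_pixels / (width * height))
        width, height = int(width * scale), int(height * scale)
    # Round each side down to the nearest multiple, but never below one multiple.
    width = max(multiple, width // multiple * multiple)
    height = max(multiple, height // multiple * multiple)
    return width, height


for size in [(1024, 1024), (3000, 2000), (500, 333)]:
    print(size, "->", fit_resolution(*size))
```

Rounding down rather than to the nearest multiple guarantees that the area never exceeds `max_pixels` after scaling, at the cost of discarding up to 15 pixels per side.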
diff --git a/docs/en/index.rst b/docs/en/index.rst
index 4b933ca..16c47ee 100644
--- a/docs/en/index.rst
+++ b/docs/en/index.rst
@@ -29,6 +29,7 @@ Welcome to DiffSynth-Studio's Documentation
    Model_Details/Z-Image
    Model_Details/Anima
    Model_Details/LTX-2
+   Model_Details/ERNIE-Image
 
 .. toctree::
    :maxdepth: 2
diff --git a/docs/zh/Model_Details/ERNIE-Image.md b/docs/zh/Model_Details/ERNIE-Image.md
new file mode 100644
index 0000000..f7acbe7
--- /dev/null
+++ b/docs/zh/Model_Details/ERNIE-Image.md
@@ -0,0 +1,133 @@
+# ERNIE-Image
+
+ERNIE-Image 是百度推出的拥有 8B 参数的图像生成模型,具有紧凑高效的架构和出色的指令跟随能力。基于 8B DiT 主干网络,其在某些场景下的性能可与 20B 以上的更大模型相媲美,同时保持了良好的参数效率。该模型在指令理解与执行、文本生成(如英文/中文/日文)以及整体稳定性方面提供了较为可靠的表现。
+
+## 安装
+
+在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。
+
+```shell
+git clone https://github.com/modelscope/DiffSynth-Studio.git
+cd DiffSynth-Studio
+pip install -e .
+```
+
+更多关于安装的信息,请参考[安装依赖](../Pipeline_Usage/Setup.md)。
+
+## 快速开始
+
+运行以下代码可以快速加载 [baidu/ERNIE-Image](https://www.modelscope.cn/models/baidu/ERNIE-Image) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 3G 显存即可运行。
+
+```python
+from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
+import torch
+
+vram_config = {
+    "offload_dtype": torch.bfloat16,
+    "offload_device": "cpu",
+    "onload_dtype": torch.bfloat16,
+    "onload_device": "cpu",
+    "preparing_dtype": torch.bfloat16,
+    "preparing_device": "cuda",
+    "computation_dtype": torch.bfloat16,
+    "computation_device": "cuda",
+}
+pipe = ErnieImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device='cuda',
+    model_configs=[
+        ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
+        ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
+        ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
+    ],
+    tokenizer_config=ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="tokenizer/"),
+    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
+)
+
+image = pipe(
+    prompt="一只黑白相间的中华田园犬",
+    negative_prompt="",
+    height=1024,
+    width=1024,
+    seed=42,
+    num_inference_steps=50,
+    cfg_scale=4.0,
+)
+image.save("output.jpg")
+```
+
+## 模型总览
+
+|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
+|-|-|-|-|-|-|-|
+|[baidu/ERNIE-Image: T2I](https://www.modelscope.cn/models/baidu/ERNIE-Image)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference_low_vram/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/full/Ernie-Image-T2I.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_full/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/lora/Ernie-Image-T2I.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_lora/Ernie-Image-T2I.py)|
+
+## 模型推理
+
+模型通过 `ErnieImagePipeline.from_pretrained` 加载,详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。
+
+`ErnieImagePipeline` 推理的输入参数包括:
+
+* `prompt`: 提示词,描述画面中出现的内容。
+* `negative_prompt`: 负向提示词,描述画面中不应该出现的内容,默认值为 `""`。
+* `cfg_scale`: Classifier-free guidance 的参数,默认值为 4.0。
+* `height`: 图像高度,需保证高度为 16 的倍数,默认值为 1024。
+* `width`: 图像宽度,需保证宽度为 16 的倍数,默认值为 1024。
+* `seed`: 随机种子。默认为 `None`,即完全随机。
+* `rand_device`: 生成随机高斯噪声矩阵的计算设备,默认为 `"cuda"`。当设置为 `cuda` 时,在不同 GPU 上会导致不同的生成结果。
+* `num_inference_steps`: 推理步数,默认值为 50。
+
+如果显存不足,请开启[显存管理](../Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文"模型总览"中的表格。
+
+## 模型训练
+
+ERNIE-Image 系列模型统一通过 [`examples/ernie_image/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/train.py) 进行训练,脚本的参数包括:
+
+* 通用训练参数
+    * 数据集基础配置
+        * `--dataset_base_path`: 数据集的根目录。
+        * `--dataset_metadata_path`: 数据集的元数据文件路径。
+        * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
+        * `--dataset_num_workers`: 每个 Dataloader 的进程数量。
+        * `--data_file_keys`: 元数据中需要加载的字段名称,通常是图像或视频文件的路径,以 `,` 分隔。
+    * 模型加载配置
+        * `--model_paths`: 要加载的模型路径。JSON 格式。
+        * `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 `"baidu/ERNIE-Image:transformer/diffusion_pytorch_model*.safetensors"`。用逗号分隔。
+        * `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,以 `,` 分隔。
+        * `--fp8_models`: 以 FP8 格式加载的模型,目前仅支持参数不被梯度更新的模型。
+    * 训练基础配置
+        * `--learning_rate`: 学习率。
+        * `--num_epochs`: 轮数(Epoch)。
+        * `--trainable_models`: 可训练的模型,例如 `dit`、`vae`、`text_encoder`。
+        * `--find_unused_parameters`: DDP 训练中是否存在未使用的参数。
+        * `--weight_decay`: 权重衰减大小。
+        * `--task`: 训练任务,默认为 `sft`。
+    * 输出配置
+        * `--output_path`: 模型保存路径。
+        * `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。
+        * `--save_steps`: 保存模型的训练步数间隔。
+    * LoRA 配置
+        * `--lora_base_model`: LoRA 添加到哪个模型上。
+        * `--lora_target_modules`: LoRA 添加到哪些层上。
+        * `--lora_rank`: LoRA 的秩(Rank)。
+        * `--lora_checkpoint`: LoRA 检查点的路径。
+        * `--preset_lora_path`: 预置 LoRA 检查点路径,用于 LoRA 差分训练。
+        * `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`。
+    * 梯度配置
+        * `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。
+        * `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。
+        * `--gradient_accumulation_steps`: 梯度累积步数。
+    * 分辨率配置
+        * `--height`: 图像的高度。留空启用动态分辨率。
+        * `--width`: 图像的宽度。留空启用动态分辨率。
+        * `--max_pixels`: 最大像素面积,动态分辨率时大于此值的图片会被缩小。
+* ERNIE-Image 专有参数
+    * `--tokenizer_path`: tokenizer 的路径,留空则自动从远程下载。
+
+我们构建了一个样例图像数据集,以方便您进行测试,通过以下命令可以下载这个数据集:
+
+```shell
+modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
+```
+
+我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](../Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。
diff --git a/docs/zh/index.rst b/docs/zh/index.rst
index 42256b3..05fb09d 100644
--- a/docs/zh/index.rst
+++ b/docs/zh/index.rst
@@ -29,6 +29,7 @@
    Model_Details/Z-Image
    Model_Details/Anima
    Model_Details/LTX-2
+   Model_Details/ERNIE-Image
 
 .. toctree::
    :maxdepth: 2
diff --git a/examples/ernie_image/model_inference/Ernie-Image-T2I.py b/examples/ernie_image/model_inference/Ernie-Image-T2I.py
new file mode 100644
index 0000000..25332cf
--- /dev/null
+++ b/examples/ernie_image/model_inference/Ernie-Image-T2I.py
@@ -0,0 +1,24 @@
+from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
+import torch
+
+pipe = ErnieImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device='cuda',
+    model_configs=[
+        ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
+        ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors"),
+        ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+    ],
+    tokenizer_config=ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="tokenizer/"),
+)
+
+image = pipe(
+    prompt="一只黑白相间的中华田园犬",
+    negative_prompt="",
+    height=1024,
+    width=1024,
+    seed=42,
+    num_inference_steps=50,
+    cfg_scale=4.0,
+)
+image.save("output.jpg")
diff --git a/examples/ernie_image/model_inference_low_vram/Ernie-Image-T2I.py b/examples/ernie_image/model_inference_low_vram/Ernie-Image-T2I.py
new file mode 100644
index 0000000..26b427d
--- /dev/null
+++ b/examples/ernie_image/model_inference_low_vram/Ernie-Image-T2I.py
@@ -0,0 +1,36 @@
+from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
+import torch
+
+vram_config = {
+    "offload_dtype": torch.bfloat16,
+    "offload_device": "cpu",
+    "onload_dtype": torch.bfloat16,
+    "onload_device": "cpu",
+    "preparing_dtype": torch.bfloat16,
+    "preparing_device": "cuda",
+    "computation_dtype": torch.bfloat16,
+    "computation_device": "cuda",
+}
+
+pipe = ErnieImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device='cuda',
+    model_configs=[
+        ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
+        ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
+        ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
+    ],
+    tokenizer_config=ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="tokenizer/"),
+    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
+)
+
+image = pipe(
+    prompt="一只黑白相间的中华田园犬",
+    negative_prompt="",
+    height=1024,
+    width=1024,
+    seed=42,
+    num_inference_steps=50,
+    cfg_scale=4.0,
+)
+image.save("output.jpg")
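
The `vram_limit` expression in the script above converts a byte count into GiB and subtracts a fixed headroom. Here is a minimal sketch of that arithmetic, assuming a hypothetical helper named `vram_limit_gib` and a made-up 24 GiB device so that it runs without a GPU. Note that `torch.cuda.mem_get_info` returns a `(free, total)` tuple in bytes, so index `[1]` selects total device memory rather than currently free memory.

```python
GIB = 1024 ** 3


def vram_limit_gib(total_bytes: int, headroom_gib: float = 0.5) -> float:
    """Convert total device memory in bytes to a GiB budget minus headroom."""
    return total_bytes / GIB - headroom_gib


# Stand-in for torch.cuda.mem_get_info("cuda"), which returns (free, total) in bytes.
free_bytes, total_bytes = 20 * GIB, 24 * GIB  # hypothetical 24 GiB GPU
limit = vram_limit_gib(total_bytes)
print(f"vram_limit = {limit:.1f} GiB")
```

If other processes already occupy part of the GPU, a budget derived from index `[0]` (free memory) may be the safer choice; that variation is a suggestion and is not part of the example script.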
diff --git a/examples/ernie_image/model_training/full/Ernie-Image-T2I.sh b/examples/ernie_image/model_training/full/Ernie-Image-T2I.sh
new file mode 100644
index 0000000..550dde5
--- /dev/null
+++ b/examples/ernie_image/model_training/full/Ernie-Image-T2I.sh
@@ -0,0 +1,17 @@
+# Dataset: data/diffsynth_example_dataset/ernie_image/Ernie-Image-T2I/
+
+accelerate launch --config_file examples/ernie_image/model_training/full/accelerate_config_zero3.yaml \
+ examples/ernie_image/model_training/train.py \
+ --dataset_base_path data/diffsynth_example_dataset/ernie_image/Ernie-Image-T2I \
+ --dataset_metadata_path data/diffsynth_example_dataset/ernie_image/Ernie-Image-T2I/metadata.csv \
+ --max_pixels 1048576 \
+ --dataset_repeat 50 \
+ --model_id_with_origin_paths "baidu/ERNIE-Image:transformer/diffusion_pytorch_model*.safetensors,baidu/ERNIE-Image:text_encoder/model.safetensors,baidu/ERNIE-Image:vae/diffusion_pytorch_model.safetensors" \
+ --learning_rate 1e-5 \
+ --num_epochs 2 \
+ --remove_prefix_in_ckpt "pipe.dit." \
+ --output_path "./models/train/Ernie-Image-T2I_full" \
+ --trainable_models "dit" \
+ --use_gradient_checkpointing \
+ --dataset_num_workers 8 \
+ --find_unused_parameters
diff --git a/examples/ernie_image/model_training/full/accelerate_config_zero3.yaml b/examples/ernie_image/model_training/full/accelerate_config_zero3.yaml
new file mode 100644
index 0000000..e6a8d27
--- /dev/null
+++ b/examples/ernie_image/model_training/full/accelerate_config_zero3.yaml
@@ -0,0 +1,23 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+  gradient_accumulation_steps: 1
+  offload_optimizer_device: none
+  offload_param_device: none
+  zero3_init_flag: true
+  zero3_save_16bit_model: true
+  zero_stage: 3
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/examples/ernie_image/model_training/lora/Ernie-Image-T2I.sh b/examples/ernie_image/model_training/lora/Ernie-Image-T2I.sh
new file mode 100644
index 0000000..5c8732e
--- /dev/null
+++ b/examples/ernie_image/model_training/lora/Ernie-Image-T2I.sh
@@ -0,0 +1,19 @@
+# Dataset: data/diffsynth_example_dataset/ernie_image/Ernie-Image-T2I/
+# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ernie_image/Ernie-Image-T2I/*" --local_dir ./data/diffsynth_example_dataset
+
+accelerate launch examples/ernie_image/model_training/train.py \
+ --dataset_base_path data/diffsynth_example_dataset/ernie_image/Ernie-Image-T2I \
+ --dataset_metadata_path data/diffsynth_example_dataset/ernie_image/Ernie-Image-T2I/metadata.csv \
+ --max_pixels 1048576 \
+ --dataset_repeat 50 \
+ --model_id_with_origin_paths "baidu/ERNIE-Image:transformer/diffusion_pytorch_model*.safetensors,baidu/ERNIE-Image:text_encoder/model.safetensors,baidu/ERNIE-Image:vae/diffusion_pytorch_model.safetensors" \
+ --learning_rate 1e-4 \
+ --num_epochs 5 \
+ --remove_prefix_in_ckpt "pipe.dit." \
+ --output_path "./models/train/Ernie-Image-T2I_lora" \
+ --lora_base_model "dit" \
+ --lora_target_modules "to_q,to_k,to_v,to_out.0" \
+ --lora_rank 32 \
+ --use_gradient_checkpointing \
+ --dataset_num_workers 8 \
+ --find_unused_parameters
diff --git a/examples/ernie_image/model_training/train.py b/examples/ernie_image/model_training/train.py
new file mode 100644
index 0000000..5fa0bc8
--- /dev/null
+++ b/examples/ernie_image/model_training/train.py
@@ -0,0 +1,129 @@
+import torch, os, argparse, accelerate
+from diffsynth.core import UnifiedDataset
+from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
+from diffsynth.diffusion import *
+from diffsynth.core.data.operators import *
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+class ErnieImageTrainingModule(DiffusionTrainingModule):
+    def __init__(
+        self,
+        model_paths=None, model_id_with_origin_paths=None,
+        tokenizer_path=None,
+        trainable_models=None,
+        lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
+        preset_lora_path=None, preset_lora_model=None,
+        use_gradient_checkpointing=True,
+        use_gradient_checkpointing_offload=False,
+        extra_inputs=None,
+        fp8_models=None,
+        offload_models=None,
+        device="cpu",
+        task="sft",
+    ):
+        super().__init__()
+        # Load models
+        model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
+        tokenizer_config = ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
+        self.pipe = ErnieImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
+        self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
+
+        # Training mode
+        self.switch_pipe_to_training_mode(
+            self.pipe, trainable_models,
+            lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
+            preset_lora_path, preset_lora_model,
+            task=task,
+        )
+
+        # Other configs
+        self.use_gradient_checkpointing = use_gradient_checkpointing
+        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
+        self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
+        self.task = task
+        self.task_to_loss = {
+            "sft:data_process": lambda pipe, inputs_shared, inputs_posi, inputs_nega: (inputs_shared, inputs_posi, inputs_nega),
+            "sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
+            "sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
+        }
+
+    def get_pipeline_inputs(self, data):
+        inputs_posi = {"prompt": data["prompt"]}
+        inputs_nega = {"negative_prompt": ""}
+        inputs_shared = {
+            "input_image": data["image"],
+            "height": data["image"].size[1],
+            "width": data["image"].size[0],
+            "cfg_scale": 1,
+            "rand_device": self.pipe.device,
+            "use_gradient_checkpointing": self.use_gradient_checkpointing,
+            "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
+        }
+        inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
+        return inputs_shared, inputs_posi, inputs_nega
+
+    def forward(self, data, inputs=None):
+        if inputs is None:
+            inputs = self.get_pipeline_inputs(data)
+        inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
+        for unit in self.pipe.units:
+            inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)
+        loss = self.task_to_loss[self.task](self.pipe, *inputs)
+        return loss
+
+
+def ernie_image_parser():
+    parser = argparse.ArgumentParser(description="ERNIE-Image training.")
+    parser = add_general_config(parser)
+    parser = add_image_size_config(parser)
+    parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
+    return parser
+
+
+if __name__ == "__main__":
+    parser = ernie_image_parser()
+    args = parser.parse_args()
+    accelerator = accelerate.Accelerator(
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
+    )
+    dataset = UnifiedDataset(
+        base_path=args.dataset_base_path,
+        metadata_path=args.dataset_metadata_path,
+        repeat=args.dataset_repeat,
+        data_file_keys=args.data_file_keys.split(","),
+        main_data_operator=lambda x: x,
+        special_operator_map={
+            "image": ToAbsolutePath(args.dataset_base_path) >> LoadImage() >> ImageCropAndResize(args.height, args.width, args.max_pixels, 16, 16),
+        },
+    )
+    model = ErnieImageTrainingModule(
+        model_paths=args.model_paths,
+        model_id_with_origin_paths=args.model_id_with_origin_paths,
+        tokenizer_path=args.tokenizer_path,
+        trainable_models=args.trainable_models,
+        lora_base_model=args.lora_base_model,
+        lora_target_modules=args.lora_target_modules,
+        lora_rank=args.lora_rank,
+        lora_checkpoint=args.lora_checkpoint,
+        preset_lora_path=args.preset_lora_path,
+        preset_lora_model=args.preset_lora_model,
+        use_gradient_checkpointing=args.use_gradient_checkpointing,
+        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
+        extra_inputs=args.extra_inputs,
+        fp8_models=args.fp8_models,
+        offload_models=args.offload_models,
+        task=args.task,
+        device=accelerator.device,
+    )
+    model_logger = ModelLogger(
+        args.output_path,
+        remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
+    )
+    launcher_map = {
+        "sft:data_process": launch_data_process_task,
+        "sft": launch_training_task,
+        "sft:train": launch_training_task,
+    }
+    launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)
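
`FlowMatchSFTLoss` is imported from `diffsynth.diffusion`, and its body is not part of this diff. As background only, the sketch below shows the standard flow-matching (rectified flow) regression objective on plain Python lists. Treating this as what `FlowMatchSFTLoss` computes is an assumption; `interpolate`, `velocity_target`, and `mse` are hypothetical names, and the real loss may add timestep sampling and per-timestep weighting.

```python
import random
from typing import List


def interpolate(x0: List[float], noise: List[float], sigma: float) -> List[float]:
    """Noisy sample on the straight path from data (sigma=0) to noise (sigma=1)."""
    return [(1.0 - sigma) * a + sigma * n for a, n in zip(x0, noise)]


def velocity_target(x0: List[float], noise: List[float]) -> List[float]:
    """Regression target: the constant velocity of that path, noise minus data."""
    return [n - a for a, n in zip(x0, noise)]


def mse(pred: List[float], target: List[float]) -> float:
    """Mean squared error between a prediction and the target."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


rng = random.Random(0)
x0 = [0.5, -1.0, 2.0]                      # stands in for VAE latents of a training image
noise = [rng.gauss(0.0, 1.0) for _ in x0]  # Gaussian noise of the same shape
sigma = 0.3                                # noise level for one sampled timestep

noisy = interpolate(x0, noise, sigma)      # what the model would receive as input
target = velocity_target(x0, noise)        # what the model would be trained to predict
perfect_loss = mse(target, target)
zero_pred_loss = mse([0.0] * len(x0), target)
print(perfect_loss, zero_pred_loss)
```

Under this formulation the training module only needs the encoded image latents and the prompt embeddings, which matches the inputs assembled in `get_pipeline_inputs` (an image, a prompt, and `cfg_scale` fixed to 1).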
diff --git a/examples/ernie_image/model_training/validate_full/Ernie-Image-T2I.py b/examples/ernie_image/model_training/validate_full/Ernie-Image-T2I.py
new file mode 100644
index 0000000..4664126
--- /dev/null
+++ b/examples/ernie_image/model_training/validate_full/Ernie-Image-T2I.py
@@ -0,0 +1,25 @@
+import torch
+from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
+from diffsynth.core import load_state_dict
+
+pipe = ErnieImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
+        ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors"),
+        ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+    ],
+)
+
+state_dict = load_state_dict("./models/train/Ernie-Image-T2I_full/epoch-1.safetensors")
+pipe.dit.load_state_dict(state_dict)
+
+image = pipe(
+    prompt="a professional photo of a cute dog",
+    seed=0,
+    num_inference_steps=50,
+    cfg_scale=4.0,
+)
+image.save("image_full.jpg")
+print("Full validation image saved to image_full.jpg")
diff --git a/examples/ernie_image/model_training/validate_lora/Ernie-Image-T2I.py b/examples/ernie_image/model_training/validate_lora/Ernie-Image-T2I.py
new file mode 100644
index 0000000..20f84eb
--- /dev/null
+++ b/examples/ernie_image/model_training/validate_lora/Ernie-Image-T2I.py
@@ -0,0 +1,25 @@
+import torch
+from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
+from diffsynth.core.loader.file import load_state_dict
+
+pipe = ErnieImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
+        ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors"),
+        ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+    ],
+)
+
+lora_state_dict = load_state_dict("./models/train/Ernie-Image-T2I_lora/epoch-4.safetensors", torch_dtype=torch.bfloat16, device="cuda")
+pipe.load_lora(pipe.dit, state_dict=lora_state_dict, alpha=1.0)
+
+image = pipe(
+    prompt="a professional photo of a cute dog",
+    seed=0,
+    num_inference_steps=50,
+    cfg_scale=4.0,
+)
+image.save("image_lora.jpg")
+print("LoRA validation image saved to image_lora.jpg")