From 960d8c62c03a84ddbf7b91ee68fdb1c69b5ccd4f Mon Sep 17 00:00:00 2001 From: Hong Zhang <41229682+mi804@users.noreply.github.com> Date: Mon, 13 Apr 2026 14:57:10 +0800 Subject: [PATCH] Support ERNIE-Image (#1389) * ernie-image pipeline * ernie-image inference and training * style fix * ernie docs * lowvram * final style fix * pr-review * pr-fix round2 * set uniform training weight * fix * update lowvram docs --- README.md | 61 +++ README_zh.md | 60 +++ diffsynth/configs/model_configs.py | 18 +- .../configs/vram_management_module_maps.py | 12 + diffsynth/diffusion/flow_match.py | 22 +- diffsynth/models/ernie_image_dit.py | 362 ++++++++++++++++++ diffsynth/models/ernie_image_text_encoder.py | 76 ++++ diffsynth/pipelines/ernie_image.py | 265 +++++++++++++ .../ernie_image_text_encoder.py | 21 + docs/en/Model_Details/ERNIE-Image.md | 133 +++++++ docs/en/index.rst | 1 + docs/zh/Model_Details/ERNIE-Image.md | 133 +++++++ docs/zh/index.rst | 1 + .../model_inference/Ernie-Image-T2I.py | 24 ++ .../Ernie-Image-T2I.py | 36 ++ .../model_training/full/Ernie-Image-T2I.sh | 17 + .../full/accelerate_config_zero3.yaml | 23 ++ .../model_training/lora/Ernie-Image-T2I.sh | 19 + examples/ernie_image/model_training/train.py | 129 +++++++ .../validate_full/Ernie-Image-T2I.py | 25 ++ .../validate_lora/Ernie-Image-T2I.py | 25 ++ 21 files changed, 1461 insertions(+), 2 deletions(-) create mode 100644 diffsynth/models/ernie_image_dit.py create mode 100644 diffsynth/models/ernie_image_text_encoder.py create mode 100644 diffsynth/pipelines/ernie_image.py create mode 100644 diffsynth/utils/state_dict_converters/ernie_image_text_encoder.py create mode 100644 docs/en/Model_Details/ERNIE-Image.md create mode 100644 docs/zh/Model_Details/ERNIE-Image.md create mode 100644 examples/ernie_image/model_inference/Ernie-Image-T2I.py create mode 100644 examples/ernie_image/model_inference_low_vram/Ernie-Image-T2I.py create mode 100644 examples/ernie_image/model_training/full/Ernie-Image-T2I.sh create mode 
100644 examples/ernie_image/model_training/full/accelerate_config_zero3.yaml create mode 100644 examples/ernie_image/model_training/lora/Ernie-Image-T2I.sh create mode 100644 examples/ernie_image/model_training/train.py create mode 100644 examples/ernie_image/model_training/validate_full/Ernie-Image-T2I.py create mode 100644 examples/ernie_image/model_training/validate_lora/Ernie-Image-T2I.py diff --git a/README.md b/README.md index 660b8de..1fd1ed0 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,7 @@ We believe that a well-developed open-source code framework can lower the thresh > DiffSynth-Studio has undergone major version updates, and some old features are no longer maintained. If you need to use old features, please switch to the [last historical version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3) before the major version update. > Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher) and [mi804](https://github.com/mi804). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand. + - **March 19, 2026**: Added support for [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) and [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/Wan.md) and [example code](/examples/mova/) are now available. - **March 12, 2026**: We have added support for the [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) audio-video generation model. The features includes text-to-audio/video, image-to-audio/video, IC-LoRA control, audio-to-video, and audio-video inpainting. We have supported the complete inference and training functionalities. 
For details, please refer to the [documentation](/docs/en/Model_Details/LTX-2.md) and [code](/examples/ltx2/). @@ -876,6 +877,66 @@ Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/) +#### ERNIE-Image: [/docs/en/Model_Details/ERNIE-Image.md](/docs/en/Model_Details/ERNIE-Image.md) + +
+ +Quick Start + +Running the following code will quickly load the [baidu/ERNIE-Image](https://www.modelscope.cn/models/baidu/ERNIE-Image) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 3GB VRAM. + +```python +from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = ErnieImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device='cuda', + model_configs=[ + ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config), + ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +image = pipe( + prompt="一只黑白相间的中华田园犬", + negative_prompt="", + height=1024, + width=1024, + seed=42, + num_inference_steps=50, + cfg_scale=4.0, +) +image.save("output.jpg") +``` + +
+ +
+ +Examples + +Example code for ERNIE-Image is available at: [/examples/ernie_image/](/examples/ernie_image/) + +| Model ID | Inference | Low VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation | +|-|-|-|-|-|-|-| +|[baidu/ERNIE-Image: T2I](https://www.modelscope.cn/models/baidu/ERNIE-Image)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference_low_vram/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/full/Ernie-Image-T2I.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_full/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/lora/Ernie-Image-T2I.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_lora/Ernie-Image-T2I.py)| + +
+ ## Innovative Achievements DiffSynth-Studio is not just an engineered model framework, but also an incubator for innovative achievements. diff --git a/README_zh.md b/README_zh.md index 92e230b..29e293a 100644 --- a/README_zh.md +++ b/README_zh.md @@ -877,6 +877,66 @@ Wan 的示例代码位于:[/examples/wanvideo/](/examples/wanvideo/) +#### ERNIE-Image: [/docs/zh/Model_Details/ERNIE-Image.md](/docs/zh/Model_Details/ERNIE-Image.md) + +
+ +快速开始 + +运行以下代码可以快速加载 [baidu/ERNIE-Image](https://www.modelscope.cn/models/baidu/ERNIE-Image) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 3G 显存即可运行。 + +```python +from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = ErnieImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device='cuda', + model_configs=[ + ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config), + ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +image = pipe( + prompt="一只黑白相间的中华田园犬", + negative_prompt="", + height=1024, + width=1024, + seed=42, + num_inference_steps=50, + cfg_scale=4.0, +) +image.save("output.jpg") +``` + +
+ +
+ +示例代码 + +ERNIE-Image 的示例代码位于:[/examples/ernie_image/](/examples/ernie_image/) + +| 模型 ID | 推理 | 低显存推理 | 全量训练 | 全量训练后验证 | LoRA 训练 | LoRA 训练后验证 | +|-|-|-|-|-|-|-| +|[baidu/ERNIE-Image: T2I](https://www.modelscope.cn/models/baidu/ERNIE-Image)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference_low_vram/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/full/Ernie-Image-T2I.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_full/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/lora/Ernie-Image-T2I.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_lora/Ernie-Image-T2I.py)| + +
+ ## 创新成果 DiffSynth-Studio 不仅仅是一个工程化的模型框架,更是创新成果的孵化器。 diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index 4222202..7428962 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -541,6 +541,22 @@ flux2_series = [ }, ] +ernie_image_series = [ + { + # Example: ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors") + "model_hash": "584c13713849f1af4e03d5f1858b8b7b", + "model_name": "ernie_image_dit", + "model_class": "diffsynth.models.ernie_image_dit.ErnieImageDiT", + }, + { + # Example: ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors") + "model_hash": "404ed9f40796a38dd34c1620f1920207", + "model_name": "ernie_image_text_encoder", + "model_class": "diffsynth.models.ernie_image_text_encoder.ErnieImageTextEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ernie_image_text_encoder.ErnieImageTextEncoderStateDictConverter", + }, +] + z_image_series = [ { # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors") @@ -884,4 +900,4 @@ mova_series = [ "model_class": "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge", }, ] -MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series + anima_series + mova_series +MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series + z_image_series + ltx2_series + anima_series + mova_series diff --git a/diffsynth/configs/vram_management_module_maps.py b/diffsynth/configs/vram_management_module_maps.py index de27689..ba7ed77 100644 --- a/diffsynth/configs/vram_management_module_maps.py +++ b/diffsynth/configs/vram_management_module_maps.py @@ -267,6 +267,18 @@ VRAM_MANAGEMENT_MODULE_MAPS = { "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule", "torch.nn.ConvTranspose1d": 
"diffsynth.core.vram.layers.AutoWrappedModule", }, + "diffsynth.models.ernie_image_dit.ErnieImageDiT": { + "diffsynth.models.ernie_image_dit.ErnieImageRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.ernie_image_text_encoder.ErnieImageTextEncoder": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.ministral3.modeling_ministral3.Ministral3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, } def QwenImageTextEncoder_Module_Map_Updater(): diff --git a/diffsynth/diffusion/flow_match.py b/diffsynth/diffusion/flow_match.py index 208fb1e..1ac5c49 100644 --- a/diffsynth/diffusion/flow_match.py +++ b/diffsynth/diffusion/flow_match.py @@ -4,7 +4,7 @@ from typing_extensions import Literal class FlowMatchScheduler(): - def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning"] = "FLUX.1"): + def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning", "ERNIE-Image"] = "FLUX.1"): self.set_timesteps_fn = { "FLUX.1": FlowMatchScheduler.set_timesteps_flux, "Wan": FlowMatchScheduler.set_timesteps_wan, @@ -13,6 +13,7 @@ class FlowMatchScheduler(): "Z-Image": FlowMatchScheduler.set_timesteps_z_image, "LTX-2": FlowMatchScheduler.set_timesteps_ltx2, "Qwen-Image-Lightning": FlowMatchScheduler.set_timesteps_qwen_image_lightning, + "ERNIE-Image": FlowMatchScheduler.set_timesteps_ernie_image, }.get(template, FlowMatchScheduler.set_timesteps_flux) self.num_train_timesteps = 1000 @@ -129,6 +130,15 @@ class FlowMatchScheduler(): 
timesteps = sigmas * num_train_timesteps return sigmas, timesteps + @staticmethod + def set_timesteps_ernie_image(num_inference_steps=50, denoising_strength=1.0): + """ERNIE-Image scheduler: pure linear sigmas from 1.0 to 0.0, no shift.""" + num_train_timesteps = 1000 + sigma_start = denoising_strength + sigmas = torch.linspace(sigma_start, 0.0, num_inference_steps + 1)[:-1] + timesteps = sigmas * num_train_timesteps + return sigmas, timesteps + @staticmethod def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None): sigma_min = 0.0 @@ -175,6 +185,9 @@ class FlowMatchScheduler(): return sigmas, timesteps def set_training_weight(self): + if self.set_timesteps_fn == FlowMatchScheduler.set_timesteps_ernie_image: + self.set_uniform_training_weight() + return steps = 1000 x = self.timesteps y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2) @@ -185,6 +198,13 @@ class FlowMatchScheduler(): bsmntw_weighing = bsmntw_weighing * (len(self.timesteps) / steps) bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1] self.linear_timesteps_weights = bsmntw_weighing + + def set_uniform_training_weight(self): + """Assign equal weight to every timestep, suitable for linear schedulers like ERNIE-Image.""" + steps = 1000 + num_steps = len(self.timesteps) + uniform_weight = torch.full((num_steps,), steps / num_steps, dtype=self.timesteps.dtype) + self.linear_timesteps_weights = uniform_weight def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs): self.sigmas, self.timesteps = self.set_timesteps_fn( diff --git a/diffsynth/models/ernie_image_dit.py b/diffsynth/models/ernie_image_dit.py new file mode 100644 index 0000000..fd0e022 --- /dev/null +++ b/diffsynth/models/ernie_image_dit.py @@ -0,0 +1,362 @@ +""" +Ernie-Image DiT for DiffSynth-Studio. + +Refactored from diffusers ErnieImageTransformer2DModel to use DiffSynth core modules. 
+Default parameters from actual checkpoint config.json (baidu/ERNIE-Image transformer). +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional, Tuple + +from ..core.attention import attention_forward +from ..core.gradient import gradient_checkpoint_forward +from .flux2_dit import Timesteps, TimestepEmbedding + + +def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: + assert dim % 2 == 0 + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta ** scale) + out = torch.einsum("...n,d->...nd", pos, omega) + return out.float() + + +class ErnieImageEmbedND3(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: Tuple[int, int, int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = list(axes_dim) + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(3)], dim=-1) + emb = emb.unsqueeze(2) + return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1) + + +class ErnieImagePatchEmbedDynamic(nn.Module): + def __init__(self, in_channels: int, embed_dim: int, patch_size: int): + super().__init__() + self.patch_size = patch_size + self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + batch_size, dim, height, width = x.shape + return x.reshape(batch_size, dim, height * width).transpose(1, 2).contiguous() + + +class ErnieImageSingleStreamAttnProcessor: + def __call__( + self, + attn: "ErnieImageAttention", + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + freqs_cis: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(-1, (attn.heads, -1)) + key = 
key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + rot_dim = freqs_cis.shape[-1] + x, x_pass = x_in[..., :rot_dim], x_in[..., rot_dim:] + cos_ = torch.cos(freqs_cis).to(x.dtype) + sin_ = torch.sin(freqs_cis).to(x.dtype) + x1, x2 = x.chunk(2, dim=-1) + x_rotated = torch.cat((-x2, x1), dim=-1) + return torch.cat((x * cos_ + x_rotated * sin_, x_pass), dim=-1) + + if freqs_cis is not None: + query = apply_rotary_emb(query, freqs_cis) + key = apply_rotary_emb(key, freqs_cis) + + if attention_mask is not None and attention_mask.ndim == 2: + attention_mask = attention_mask[:, None, None, :] + + hidden_states = attention_forward( + query, key, value, + q_pattern="b s n d", + k_pattern="b s n d", + v_pattern="b s n d", + out_pattern="b s n d", + attn_mask=attention_mask, + ) + + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + output = attn.to_out[0](hidden_states) + + return output + + +class ErnieImageAttention(nn.Module): + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + qk_norm: str = "rms_norm", + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + elementwise_affine: bool = True, + ): + super().__init__() + + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.out_dim = out_dim if out_dim is not None else query_dim + self.heads = out_dim // dim_head if out_dim is not None else heads + + self.use_bias = bias + self.dropout = dropout + + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias) 
+ + if qk_norm == "layer_norm": + self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + elif qk_norm == "rms_norm": + self.norm_q = nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + else: + raise ValueError( + f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'rms_norm'." + ) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + + self.processor = ErnieImageSingleStreamAttnProcessor() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.processor(self, hidden_states, attention_mask, image_rotary_emb) + + +class ErnieImageFeedForward(nn.Module): + def __init__(self, hidden_size: int, ffn_hidden_size: int): + super().__init__() + self.gate_proj = nn.Linear(hidden_size, ffn_hidden_size, bias=False) + self.up_proj = nn.Linear(hidden_size, ffn_hidden_size, bias=False) + self.linear_fc2 = nn.Linear(ffn_hidden_size, hidden_size, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear_fc2(self.up_proj(x) * F.gelu(self.gate_proj(x))) + + +class ErnieImageRMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + hidden_states = hidden_states * self.weight + return hidden_states.to(input_dtype) + + +class ErnieImageSharedAdaLNBlock(nn.Module): + def __init__( + self, + 
hidden_size: int, + num_heads: int, + ffn_hidden_size: int, + eps: float = 1e-6, + qk_layernorm: bool = True, + ): + super().__init__() + self.adaLN_sa_ln = ErnieImageRMSNorm(hidden_size, eps=eps) + self.self_attention = ErnieImageAttention( + query_dim=hidden_size, + dim_head=hidden_size // num_heads, + heads=num_heads, + qk_norm="rms_norm" if qk_layernorm else None, + eps=eps, + bias=False, + out_bias=False, + ) + self.adaLN_mlp_ln = ErnieImageRMSNorm(hidden_size, eps=eps) + self.mlp = ErnieImageFeedForward(hidden_size, ffn_hidden_size) + + def forward( + self, + x: torch.Tensor, + rotary_pos_emb: torch.Tensor, + temb: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = temb + residual = x + x = self.adaLN_sa_ln(x) + x = (x.float() * (1 + scale_msa.float()) + shift_msa.float()).to(x.dtype) + x_bsh = x.permute(1, 0, 2) + attn_out = self.self_attention(x_bsh, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb) + attn_out = attn_out.permute(1, 0, 2) + x = residual + (gate_msa.float() * attn_out.float()).to(x.dtype) + residual = x + x = self.adaLN_mlp_ln(x) + x = (x.float() * (1 + scale_mlp.float()) + shift_mlp.float()).to(x.dtype) + return residual + (gate_mlp.float() * self.mlp(x).float()).to(x.dtype) + + +class ErnieImageAdaLNContinuous(nn.Module): + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=eps) + self.linear = nn.Linear(hidden_size, hidden_size * 2) + + def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor: + scale, shift = self.linear(conditioning).chunk(2, dim=-1) + x = self.norm(x) + x = x * (1 + scale.unsqueeze(0)) + shift.unsqueeze(0) + return x + + +class ErnieImageDiT(nn.Module): + """ + Ernie-Image DiT model for DiffSynth-Studio. 
+ + Architecture: SharedAdaLN + RoPE 3D + Joint Image-Text Attention. + Internal format: [S, B, H] for transformer blocks, [B, S, H] for attention. + """ + + def __init__( + self, + hidden_size: int = 4096, + num_attention_heads: int = 32, + num_layers: int = 36, + ffn_hidden_size: int = 12288, + in_channels: int = 128, + out_channels: int = 128, + patch_size: int = 1, + text_in_dim: int = 3072, + rope_theta: int = 256, + rope_axes_dim: Tuple[int, int, int] = (32, 48, 48), + eps: float = 1e-6, + qk_layernorm: bool = True, + ): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_attention_heads + self.head_dim = hidden_size // num_attention_heads + self.num_layers = num_layers + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = out_channels + self.text_in_dim = text_in_dim + + self.x_embedder = ErnieImagePatchEmbedDynamic(in_channels, hidden_size, patch_size) + self.text_proj = nn.Linear(text_in_dim, hidden_size, bias=False) if text_in_dim != hidden_size else None + self.time_proj = Timesteps(hidden_size, flip_sin_to_cos=False, downscale_freq_shift=0) + self.time_embedding = TimestepEmbedding(hidden_size, hidden_size) + self.pos_embed = ErnieImageEmbedND3(dim=self.head_dim, theta=rope_theta, axes_dim=rope_axes_dim) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size)) + nn.init.zeros_(self.adaLN_modulation[-1].weight) + nn.init.zeros_(self.adaLN_modulation[-1].bias) + self.layers = nn.ModuleList([ + ErnieImageSharedAdaLNBlock(hidden_size, num_attention_heads, ffn_hidden_size, eps, qk_layernorm=qk_layernorm) + for _ in range(num_layers) + ]) + self.final_norm = ErnieImageAdaLNContinuous(hidden_size, eps) + self.final_linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels) + nn.init.zeros_(self.final_linear.weight) + nn.init.zeros_(self.final_linear.bias) + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + text_bth: 
torch.Tensor, + text_lens: torch.Tensor, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, + ) -> torch.Tensor: + device, dtype = hidden_states.device, hidden_states.dtype + B, C, H, W = hidden_states.shape + p, Hp, Wp = self.patch_size, H // self.patch_size, W // self.patch_size + N_img = Hp * Wp + + img_sbh = self.x_embedder(hidden_states).transpose(0, 1).contiguous() + + if self.text_proj is not None and text_bth.numel() > 0: + text_bth = self.text_proj(text_bth) + Tmax = text_bth.shape[1] + text_sbh = text_bth.transpose(0, 1).contiguous() + + x = torch.cat([img_sbh, text_sbh], dim=0) + S = x.shape[0] + + text_ids = torch.cat([ + torch.arange(Tmax, device=device, dtype=torch.float32).view(1, Tmax, 1).expand(B, -1, -1), + torch.zeros((B, Tmax, 2), device=device) + ], dim=-1) if Tmax > 0 else torch.zeros((B, 0, 3), device=device) + grid_yx = torch.stack( + torch.meshgrid(torch.arange(Hp, device=device, dtype=torch.float32), + torch.arange(Wp, device=device, dtype=torch.float32), indexing="ij"), + dim=-1 + ).reshape(-1, 2) + image_ids = torch.cat([ + text_lens.float().view(B, 1, 1).expand(-1, N_img, -1), + grid_yx.view(1, N_img, 2).expand(B, -1, -1) + ], dim=-1) + rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1)) + + valid_text = torch.arange(Tmax, device=device).view(1, Tmax) < text_lens.view(B, 1) if Tmax > 0 else torch.zeros((B, 0), device=device, dtype=torch.bool) + attention_mask = torch.cat([ + torch.ones((B, N_img), device=device, dtype=torch.bool), + valid_text + ], dim=1)[:, None, None, :] + + sample = self.time_proj(timestep.to(dtype)) + sample = sample.to(self.time_embedding.linear_1.weight.dtype) + c = self.time_embedding(sample) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [ + t.unsqueeze(0).expand(S, -1, -1).contiguous() + for t in self.adaLN_modulation(c).chunk(6, dim=-1) + ] + + for layer in self.layers: + temb = [shift_msa, scale_msa, gate_msa, shift_mlp, 
scale_mlp, gate_mlp] + if torch.is_grad_enabled() and use_gradient_checkpointing: + x = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + x, + rotary_pos_emb, + temb, + attention_mask, + ) + else: + x = layer(x, rotary_pos_emb, temb, attention_mask) + + x = self.final_norm(x, c).type_as(x) + patches = self.final_linear(x)[:N_img].transpose(0, 1).contiguous() + output = patches.view(B, Hp, Wp, p, p, self.out_channels).permute(0, 5, 1, 3, 2, 4).contiguous().view(B, self.out_channels, H, W) + + return output diff --git a/diffsynth/models/ernie_image_text_encoder.py b/diffsynth/models/ernie_image_text_encoder.py new file mode 100644 index 0000000..17460b2 --- /dev/null +++ b/diffsynth/models/ernie_image_text_encoder.py @@ -0,0 +1,76 @@ +""" +Ernie-Image TextEncoder for DiffSynth-Studio. + +Wraps transformers Ministral3Model to output text embeddings. +Pattern: lazy import + manual config dict + torch.nn.Module wrapper. +Only loads the text (language) model, ignoring vision components. +""" + +import torch + + +class ErnieImageTextEncoder(torch.nn.Module): + """ + Text encoder using Ministral3Model (transformers). + Only the text_config portion of the full Mistral3Model checkpoint. + Uses the base model (no lm_head) since the checkpoint only has embeddings. 
+ """ + + def __init__(self): + super().__init__() + from transformers import Ministral3Config, Ministral3Model + + text_config = { + "attention_dropout": 0.0, + "bos_token_id": 1, + "dtype": "bfloat16", + "eos_token_id": 2, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 3072, + "initializer_range": 0.02, + "intermediate_size": 9216, + "max_position_embeddings": 262144, + "model_type": "ministral3", + "num_attention_heads": 32, + "num_hidden_layers": 26, + "num_key_value_heads": 8, + "pad_token_id": 11, + "rms_norm_eps": 1e-05, + "rope_parameters": { + "beta_fast": 32.0, + "beta_slow": 1.0, + "factor": 16.0, + "llama_4_scaling_beta": 0.1, + "mscale": 1.0, + "mscale_all_dim": 1.0, + "original_max_position_embeddings": 16384, + "rope_theta": 1000000.0, + "rope_type": "yarn", + "type": "yarn", + }, + "sliding_window": None, + "tie_word_embeddings": True, + "use_cache": True, + "vocab_size": 131072, + } + config = Ministral3Config(**text_config) + self.model = Ministral3Model(config) + self.config = config + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + **kwargs, + ): + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_hidden_states=True, + return_dict=True, + **kwargs, + ) + return (outputs.hidden_states,) diff --git a/diffsynth/pipelines/ernie_image.py b/diffsynth/pipelines/ernie_image.py new file mode 100644 index 0000000..a2b411d --- /dev/null +++ b/diffsynth/pipelines/ernie_image.py @@ -0,0 +1,265 @@ +""" +ERNIE-Image Text-to-Image Pipeline for DiffSynth-Studio. + +Architecture: SharedAdaLN DiT + RoPE 3D + Joint Image-Text Attention. 
+""" + +import torch +from typing import Union, Optional +from tqdm import tqdm +from transformers import AutoTokenizer + +from ..core.device.npu_compatible_device import get_device_type +from ..diffusion import FlowMatchScheduler +from ..core import ModelConfig +from ..diffusion.base_pipeline import BasePipeline, PipelineUnit +from ..models.ernie_image_text_encoder import ErnieImageTextEncoder +from ..models.ernie_image_dit import ErnieImageDiT +from ..models.flux2_vae import Flux2VAE + + +class ErnieImagePipeline(BasePipeline): + + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, + ) + self.scheduler = FlowMatchScheduler("ERNIE-Image") + self.text_encoder: ErnieImageTextEncoder = None + self.dit: ErnieImageDiT = None + self.vae: Flux2VAE = None + self.tokenizer: AutoTokenizer = None + + self.in_iteration_models = ("dit",) + self.units = [ + ErnieImageUnit_ShapeChecker(), + ErnieImageUnit_PromptEmbedder(), + ErnieImageUnit_NoiseInitializer(), + ErnieImageUnit_InputImageEmbedder(), + ] + self.model_fn = model_fn_ernie_image + self.compilable_models = ["dit"] + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = get_device_type(), + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="tokenizer/"), + vram_limit: float = None, + ): + pipe = ErnieImagePipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + pipe.text_encoder = model_pool.fetch_model("ernie_image_text_encoder") + pipe.dit = model_pool.fetch_model("ernie_image_dit") + pipe.vae = model_pool.fetch_model("flux2_vae") + + if tokenizer_config is not None: + tokenizer_config.download_if_necessary() + pipe.tokenizer = 
AutoTokenizer.from_pretrained(tokenizer_config.path) + + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: str = "", + cfg_scale: float = 4.0, + # Shape + height: int = 1024, + width: int = 1024, + # Randomness + seed: int = None, + rand_device: str = "cuda", + # Steps + num_inference_steps: int = 50, + # Progress bar + progress_bar_cmd=tqdm, + ): + # Scheduler + self.scheduler.set_timesteps(num_inference_steps=num_inference_steps) + + # Parameters + inputs_posi = {"prompt": prompt} + inputs_nega = {"negative_prompt": negative_prompt} + inputs_shared = { + "height": height, "width": width, "seed": seed, + "cfg_scale": cfg_scale, "num_inference_steps": num_inference_steps, + "rand_device": rand_device, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared) + + # Decode + self.load_models_to_device(['vae']) + latents = inputs_shared["latents"] + image = self.vae.decode(latents) + image = self.vae_output_to_image(image) + self.load_models_to_device([]) + return image + + +class ErnieImageUnit_ShapeChecker(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width"), + output_params=("height", "width"), + ) + + def process(self, pipe: 
ErnieImagePipeline, height, width): + height, width = pipe.check_resize_height_width(height, width) + return {"height": height, "width": width} + + +class ErnieImageUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt"}, + input_params_nega={"prompt": "negative_prompt"}, + output_params=("prompt_embeds", "prompt_embeds_mask"), + onload_model_names=("text_encoder",) + ) + + def encode_prompt(self, pipe: ErnieImagePipeline, prompt): + if isinstance(prompt, str): + prompt = [prompt] + + text_hiddens = [] + text_lens_list = [] + for p in prompt: + ids = pipe.tokenizer( + p, + add_special_tokens=True, + truncation=True, + padding=False, + )["input_ids"] + + if len(ids) == 0: + if pipe.tokenizer.bos_token_id is not None: + ids = [pipe.tokenizer.bos_token_id] + else: + ids = [0] + + input_ids = torch.tensor([ids], device=pipe.device) + outputs = pipe.text_encoder( + input_ids=input_ids, + ) + # Text encoder returns tuple of (hidden_states_tuple,) where each layer's hidden state is included + all_hidden_states = outputs[0] + hidden = all_hidden_states[-2][0] # [T, H] - second to last layer + text_hiddens.append(hidden) + text_lens_list.append(hidden.shape[0]) + + # Pad to uniform length + if len(text_hiddens) == 0: + text_in_dim = pipe.text_encoder.config.hidden_size if hasattr(pipe.text_encoder, 'config') else 3072 + return { + "prompt_embeds": torch.zeros((0, 0, text_in_dim), device=pipe.device, dtype=pipe.torch_dtype), + "prompt_embeds_mask": torch.zeros((0,), device=pipe.device, dtype=torch.long), + } + + normalized = [th.to(pipe.device).to(pipe.torch_dtype) for th in text_hiddens] + text_lens = torch.tensor([t.shape[0] for t in normalized], device=pipe.device, dtype=torch.long) + Tmax = int(text_lens.max().item()) + text_in_dim = normalized[0].shape[1] + text_bth = torch.zeros((len(normalized), Tmax, text_in_dim), device=pipe.device, dtype=pipe.torch_dtype) + for i, t in 
enumerate(normalized): + text_bth[i, :t.shape[0], :] = t + + return {"prompt_embeds": text_bth, "prompt_embeds_mask": text_lens} + + def process(self, pipe: ErnieImagePipeline, prompt): + pipe.load_models_to_device(self.onload_model_names) + if pipe.text_encoder is not None: + return self.encode_prompt(pipe, prompt) + return {} + + +class ErnieImageUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "seed", "rand_device"), + output_params=("noise",), + ) + + def process(self, pipe: ErnieImagePipeline, height, width, seed, rand_device): + latent_h = height // pipe.height_division_factor + latent_w = width // pipe.width_division_factor + latent_channels = pipe.dit.in_channels + + # Use pipeline device if rand_device is not specified + if rand_device is None: + rand_device = str(pipe.device) + + noise = pipe.generate_noise( + (1, latent_channels, latent_h, latent_w), + seed=seed, + rand_device=rand_device, + rand_torch_dtype=pipe.torch_dtype, + ) + return {"noise": noise} + + +class ErnieImageUnit_InputImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "noise"), + output_params=("latents", "input_latents"), + onload_model_names=("vae",) + ) + + def process(self, pipe: ErnieImagePipeline, input_image, noise): + if input_image is None: + # T2I path: use noise directly as initial latents + return {"latents": noise, "input_latents": None} + + # I2I path: VAE encode input image + pipe.load_models_to_device(['vae']) + image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype) + input_latents = pipe.vae.encode(image) + + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + # In inference mode, add noise to encoded latents + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents} + + +def model_fn_ernie_image( + dit: 
ErnieImageDiT, + latents=None, + timestep=None, + prompt_embeds=None, + prompt_embeds_mask=None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + **kwargs, +): + output = dit( + hidden_states=latents, + timestep=timestep, + text_bth=prompt_embeds, + text_lens=prompt_embeds_mask, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + return output diff --git a/diffsynth/utils/state_dict_converters/ernie_image_text_encoder.py b/diffsynth/utils/state_dict_converters/ernie_image_text_encoder.py new file mode 100644 index 0000000..f444d76 --- /dev/null +++ b/diffsynth/utils/state_dict_converters/ernie_image_text_encoder.py @@ -0,0 +1,21 @@ +def ErnieImageTextEncoderStateDictConverter(state_dict): + """ + Maps checkpoint keys from multimodal Mistral3Model format + to text-only Ministral3Model format. + + Checkpoint keys (Mistral3Model): + language_model.model.layers.0.input_layernorm.weight + language_model.model.norm.weight + + Model keys (ErnieImageTextEncoder → self.model = Ministral3Model): + model.layers.0.input_layernorm.weight + model.norm.weight + + Mapping: language_model.model. → model. + """ + new_state_dict = {} + for key in state_dict: + if key.startswith("language_model.model."): + new_key = key.replace("language_model.model.", "model.", 1) + new_state_dict[new_key] = state_dict[key] + return new_state_dict diff --git a/docs/en/Model_Details/ERNIE-Image.md b/docs/en/Model_Details/ERNIE-Image.md new file mode 100644 index 0000000..601b26c --- /dev/null +++ b/docs/en/Model_Details/ERNIE-Image.md @@ -0,0 +1,133 @@ +# ERNIE-Image + +ERNIE-Image is a powerful image generation model with 8B parameters developed by Baidu, featuring a compact and efficient architecture with strong instruction-following capability.
Based on an 8B DiT backbone, it delivers performance comparable to larger (20B+) models in certain scenarios while maintaining parameter efficiency. It offers reliable performance in instruction understanding and execution, text generation (English/Chinese/Japanese), and overall stability. + +## Installation + +Before performing model inference and training, please install DiffSynth-Studio first. + +```shell +git clone https://github.com/modelscope/DiffSynth-Studio.git +cd DiffSynth-Studio +pip install -e . +``` + +For more information on installation, please refer to [Setup Dependencies](../Pipeline_Usage/Setup.md). + +## Quick Start + +Running the following code will load the [baidu/ERNIE-Image](https://www.modelscope.cn/models/baidu/ERNIE-Image) model for inference. With VRAM management enabled, the framework automatically controls parameter loading based on the available VRAM, so inference requires as little as 3 GB of VRAM. + +```python +from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = ErnieImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device='cuda', + model_configs=[ + ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config), + ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +image = pipe( + prompt="一只黑白相间的中华田园犬",
+ negative_prompt="", + height=1024, + width=1024, + seed=42, + num_inference_steps=50, + cfg_scale=4.0, +) +image.save("output.jpg") +``` + +## Model Overview + +|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation| +|-|-|-|-|-|-|-| +|[baidu/ERNIE-Image: T2I](https://www.modelscope.cn/models/baidu/ERNIE-Image)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference_low_vram/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/full/Ernie-Image-T2I.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_full/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/lora/Ernie-Image-T2I.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_lora/Ernie-Image-T2I.py)| + +## Model Inference + +The model is loaded via `ErnieImagePipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models) for details. + +The input parameters for `ErnieImagePipeline` inference include: + +* `prompt`: The prompt describing the content to appear in the image. +* `negative_prompt`: The negative prompt describing what should not appear in the image, default value is `""`. +* `cfg_scale`: Classifier-free guidance parameter, default value is 4.0. +* `height`: Image height, must be a multiple of 16, default value is 1024. +* `width`: Image width, must be a multiple of 16, default value is 1024. +* `seed`: Random seed. Default is `None`, meaning completely random. +* `rand_device`: The computing device for generating random Gaussian noise matrices, default is `"cuda"`. 
When set to `cuda`, different GPUs will produce different results. +* `num_inference_steps`: Number of inference steps, default value is 50. + +If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). We provide recommended low-VRAM configurations for each model in the "Model Overview" table above. + +## Model Training + +ERNIE-Image series models are trained uniformly via [`examples/ernie_image/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/train.py). The script parameters include: + +* General Training Parameters + * Dataset Configuration + * `--dataset_base_path`: Root directory of the dataset. + * `--dataset_metadata_path`: Path to the dataset metadata file. + * `--dataset_repeat`: Number of dataset repeats per epoch. + * `--dataset_num_workers`: Number of processes per DataLoader. + * `--data_file_keys`: Field names to load from metadata, typically paths to image or video files, separated by `,`. + * Model Loading Configuration + * `--model_paths`: Paths to load models from, in JSON format. + * `--model_id_with_origin_paths`: Model IDs with original paths, e.g., `"baidu/ERNIE-Image:transformer/diffusion_pytorch_model*.safetensors"`, separated by commas. + * `--extra_inputs`: Additional input parameters required by the model Pipeline, separated by `,`. + * `--fp8_models`: Models to load in FP8 format, currently only supported for models whose parameters are not updated by gradients. + * Basic Training Configuration + * `--learning_rate`: Learning rate. + * `--num_epochs`: Number of epochs. + * `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`. + * `--find_unused_parameters`: Whether unused parameters exist in DDP training. + * `--weight_decay`: Weight decay magnitude. + * `--task`: Training task, defaults to `sft`. + * Output Configuration + * `--output_path`: Path to save the model. 
+ * `--remove_prefix_in_ckpt`: Remove prefix in the model's state dict. + * `--save_steps`: Interval in training steps to save the model. + * LoRA Configuration + * `--lora_base_model`: Which model to add LoRA to. + * `--lora_target_modules`: Which layers to add LoRA to. + * `--lora_rank`: Rank of LoRA. + * `--lora_checkpoint`: Path to LoRA checkpoint. + * `--preset_lora_path`: Path to preset LoRA checkpoint for LoRA differential training. + * `--preset_lora_model`: Which model to integrate preset LoRA into, e.g., `dit`. + * Gradient Configuration + * `--use_gradient_checkpointing`: Whether to enable gradient checkpointing. + * `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory. + * `--gradient_accumulation_steps`: Number of gradient accumulation steps. + * Resolution Configuration + * `--height`: Height of the image. Leave empty to enable dynamic resolution. + * `--width`: Width of the image. Leave empty to enable dynamic resolution. + * `--max_pixels`: Maximum pixel area, images larger than this will be scaled down during dynamic resolution. +* ERNIE-Image Specific Parameters + * `--tokenizer_path`: Path to the tokenizer, leave empty to auto-download from remote. + +We provide an example image dataset for testing, which can be downloaded with the following command: + +```shell +modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset +``` + +We provide recommended training scripts for each model, please refer to the table in "Model Overview" above. For guidance on writing model training scripts, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see [Training Framework Overview](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/). 
diff --git a/docs/en/index.rst b/docs/en/index.rst index 4b933ca..16c47ee 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -29,6 +29,7 @@ Welcome to DiffSynth-Studio's Documentation Model_Details/Z-Image Model_Details/Anima Model_Details/LTX-2 + Model_Details/ERNIE-Image .. toctree:: :maxdepth: 2 diff --git a/docs/zh/Model_Details/ERNIE-Image.md b/docs/zh/Model_Details/ERNIE-Image.md new file mode 100644 index 0000000..f7acbe7 --- /dev/null +++ b/docs/zh/Model_Details/ERNIE-Image.md @@ -0,0 +1,133 @@ +# ERNIE-Image + +ERNIE-Image 是百度推出的拥有 8B 参数的图像生成模型,具有紧凑高效的架构和出色的指令跟随能力。基于 8B DiT 主干网络,其在某些场景下的性能可与 20B 以上的更大模型相媲美,同时保持了良好的参数效率。该模型在指令理解与执行、文本生成(如英文/中文/日文)以及整体稳定性方面提供了较为可靠的表现。 + +## 安装 + +在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。 + +```shell +git clone https://github.com/modelscope/DiffSynth-Studio.git +cd DiffSynth-Studio +pip install -e . +``` + +更多关于安装的信息,请参考[安装依赖](../Pipeline_Usage/Setup.md)。 + +## 快速开始 + +运行以下代码可以快速加载 [baidu/ERNIE-Image](https://www.modelscope.cn/models/baidu/ERNIE-Image) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 3G 显存即可运行。 + +```python +from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = ErnieImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device='cuda', + model_configs=[ + ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config), + ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="baidu/ERNIE-Image", 
origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +image = pipe( + prompt="一只黑白相间的中华田园犬", + negative_prompt="", + height=1024, + width=1024, + seed=42, + num_inference_steps=50, + cfg_scale=4.0, +) +image.save("output.jpg") +``` + +## 模型总览 + +|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| +|-|-|-|-|-|-|-| +|[baidu/ERNIE-Image: T2I](https://www.modelscope.cn/models/baidu/ERNIE-Image)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference_low_vram/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/full/Ernie-Image-T2I.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_full/Ernie-Image-T2I.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/lora/Ernie-Image-T2I.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_lora/Ernie-Image-T2I.py)| + +## 模型推理 + +模型通过 `ErnieImagePipeline.from_pretrained` 加载,详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。 + +`ErnieImagePipeline` 推理的输入参数包括: + +* `prompt`: 提示词,描述画面中出现的内容。 +* `negative_prompt`: 负向提示词,描述画面中不应该出现的内容,默认值为 `""`。 +* `cfg_scale`: Classifier-free guidance 的参数,默认值为 4.0。 +* `height`: 图像高度,需保证高度为 16 的倍数,默认值为 1024。 +* `width`: 图像宽度,需保证宽度为 16 的倍数,默认值为 1024。 +* `seed`: 随机种子。默认为 `None`,即完全随机。 +* `rand_device`: 生成随机高斯噪声矩阵的计算设备,默认为 `"cuda"`。当设置为 `cuda` 时,在不同 GPU 上会导致不同的生成结果。 +* `num_inference_steps`: 推理步数,默认值为 50。 + +如果显存不足,请开启[显存管理](../Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文"模型总览"中的表格。 + +## 模型训练 + +ERNIE-Image 系列模型统一通过 
[`examples/ernie_image/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/train.py) 进行训练,脚本的参数包括: + +* 通用训练参数 + * 数据集基础配置 + * `--dataset_base_path`: 数据集的根目录。 + * `--dataset_metadata_path`: 数据集的元数据文件路径。 + * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。 + * `--dataset_num_workers`: 每个 Dataloader 的进程数量。 + * `--data_file_keys`: 元数据中需要加载的字段名称,通常是图像或视频文件的路径,以 `,` 分隔。 + * 模型加载配置 + * `--model_paths`: 要加载的模型路径。JSON 格式。 + * `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 `"baidu/ERNIE-Image:transformer/diffusion_pytorch_model*.safetensors"`。用逗号分隔。 + * `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,以 `,` 分隔。 + * `--fp8_models`:以 FP8 格式加载的模型,目前仅支持参数不被梯度更新的模型。 + * 训练基础配置 + * `--learning_rate`: 学习率。 + * `--num_epochs`: 轮数(Epoch)。 + * `--trainable_models`: 可训练的模型,例如 `dit`、`vae`、`text_encoder`。 + * `--find_unused_parameters`: DDP 训练中是否存在未使用的参数。 + * `--weight_decay`:权重衰减大小。 + * `--task`: 训练任务,默认为 `sft`。 + * 输出配置 + * `--output_path`: 模型保存路径。 + * `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。 + * `--save_steps`: 保存模型的训练步数间隔。 + * LoRA 配置 + * `--lora_base_model`: LoRA 添加到哪个模型上。 + * `--lora_target_modules`: LoRA 添加到哪些层上。 + * `--lora_rank`: LoRA 的秩(Rank)。 + * `--lora_checkpoint`: LoRA 检查点的路径。 + * `--preset_lora_path`: 预置 LoRA 检查点路径,用于 LoRA 差分训练。 + * `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`。 + * 梯度配置 + * `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。 + * `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。 + * `--gradient_accumulation_steps`: 梯度累积步数。 + * 分辨率配置 + * `--height`: 图像的高度。留空启用动态分辨率。 + * `--width`: 图像的宽度。留空启用动态分辨率。 + * `--max_pixels`: 最大像素面积,动态分辨率时大于此值的图片会被缩小。 +* ERNIE-Image 专有参数 + * `--tokenizer_path`: tokenizer 的路径,留空则自动从远程下载。 + +我们构建了一个样例图像数据集,以方便您进行测试,通过以下命令可以下载这个数据集: + +```shell +modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset +``` + 
+我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](../Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。 diff --git a/docs/zh/index.rst b/docs/zh/index.rst index 42256b3..05fb09d 100644 --- a/docs/zh/index.rst +++ b/docs/zh/index.rst @@ -29,6 +29,7 @@ Model_Details/Z-Image Model_Details/Anima Model_Details/LTX-2 + Model_Details/ERNIE-Image .. toctree:: :maxdepth: 2 diff --git a/examples/ernie_image/model_inference/Ernie-Image-T2I.py b/examples/ernie_image/model_inference/Ernie-Image-T2I.py new file mode 100644 index 0000000..25332cf --- /dev/null +++ b/examples/ernie_image/model_inference/Ernie-Image-T2I.py @@ -0,0 +1,24 @@ +from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig +import torch + +pipe = ErnieImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device='cuda', + model_configs=[ + ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="tokenizer/"), +) + +image = pipe( + prompt="一只黑白相间的中华田园犬", + negative_prompt="", + height=1024, + width=1024, + seed=42, + num_inference_steps=50, + cfg_scale=4.0, +) +image.save("output.jpg") diff --git a/examples/ernie_image/model_inference_low_vram/Ernie-Image-T2I.py b/examples/ernie_image/model_inference_low_vram/Ernie-Image-T2I.py new file mode 100644 index 0000000..26b427d --- /dev/null +++ b/examples/ernie_image/model_inference_low_vram/Ernie-Image-T2I.py @@ -0,0 +1,36 @@ +from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + 
"onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} + +pipe = ErnieImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device='cuda', + model_configs=[ + ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config), + ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +image = pipe( + prompt="一只黑白相间的中华田园犬", + negative_prompt="", + height=1024, + width=1024, + seed=42, + num_inference_steps=50, + cfg_scale=4.0, +) +image.save("output.jpg") diff --git a/examples/ernie_image/model_training/full/Ernie-Image-T2I.sh b/examples/ernie_image/model_training/full/Ernie-Image-T2I.sh new file mode 100644 index 0000000..550dde5 --- /dev/null +++ b/examples/ernie_image/model_training/full/Ernie-Image-T2I.sh @@ -0,0 +1,17 @@ +# Dataset: data/diffsynth_example_dataset/ernie_image/Ernie-Image-T2I/ + +accelerate launch --config_file examples/ernie_image/model_training/full/accelerate_config_zero3.yaml \ + examples/ernie_image/model_training/train.py \ + --dataset_base_path data/diffsynth_example_dataset/ernie_image/Ernie-Image-T2I \ + --dataset_metadata_path data/diffsynth_example_dataset/ernie_image/Ernie-Image-T2I/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "baidu/ERNIE-Image:transformer/diffusion_pytorch_model*.safetensors,baidu/ERNIE-Image:text_encoder/model.safetensors,baidu/ERNIE-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-5 \ + 
--num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Ernie-Image-T2I_full" \ + --trainable_models "dit" \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --find_unused_parameters diff --git a/examples/ernie_image/model_training/full/accelerate_config_zero3.yaml b/examples/ernie_image/model_training/full/accelerate_config_zero3.yaml new file mode 100644 index 0000000..e6a8d27 --- /dev/null +++ b/examples/ernie_image/model_training/full/accelerate_config_zero3.yaml @@ -0,0 +1,23 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 1 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 3 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/ernie_image/model_training/lora/Ernie-Image-T2I.sh b/examples/ernie_image/model_training/lora/Ernie-Image-T2I.sh new file mode 100644 index 0000000..5c8732e --- /dev/null +++ b/examples/ernie_image/model_training/lora/Ernie-Image-T2I.sh @@ -0,0 +1,19 @@ +# Dataset: data/diffsynth_example_dataset/ernie_image/Ernie-Image-T2I/ +# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ernie_image/Ernie-Image-T2I/*" --local_dir ./data/diffsynth_example_dataset + +accelerate launch examples/ernie_image/model_training/train.py \ + --dataset_base_path data/diffsynth_example_dataset/ernie_image/Ernie-Image-T2I \ + --dataset_metadata_path data/diffsynth_example_dataset/ernie_image/Ernie-Image-T2I/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths 
"baidu/ERNIE-Image:transformer/diffusion_pytorch_model*.safetensors,baidu/ERNIE-Image:text_encoder/model.safetensors,baidu/ERNIE-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Ernie-Image-T2I_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,to_out.0" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --find_unused_parameters diff --git a/examples/ernie_image/model_training/train.py b/examples/ernie_image/model_training/train.py new file mode 100644 index 0000000..5fa0bc8 --- /dev/null +++ b/examples/ernie_image/model_training/train.py @@ -0,0 +1,129 @@ +import torch, os, argparse, accelerate +from diffsynth.core import UnifiedDataset +from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig +from diffsynth.diffusion import * +from diffsynth.core.data.operators import * +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +class ErnieImageTrainingModule(DiffusionTrainingModule): + def __init__( + self, + model_paths=None, model_id_with_origin_paths=None, + tokenizer_path=None, + trainable_models=None, + lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None, + preset_lora_path=None, preset_lora_model=None, + use_gradient_checkpointing=True, + use_gradient_checkpointing_offload=False, + extra_inputs=None, + fp8_models=None, + offload_models=None, + device="cpu", + task="sft", + ): + super().__init__() + # Load models + model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device) + tokenizer_config = ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path) + self.pipe = ErnieImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, 
tokenizer_config=tokenizer_config)
+        self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
+
+        # Training mode
+        self.switch_pipe_to_training_mode(
+            self.pipe, trainable_models,
+            lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
+            preset_lora_path, preset_lora_model,
+            task=task,
+        )
+
+        # Other configs
+        self.use_gradient_checkpointing = use_gradient_checkpointing
+        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
+        self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
+        self.task = task
+        self.task_to_loss = {
+            "sft:data_process": lambda pipe, inputs_shared, inputs_posi, inputs_nega: (inputs_shared, inputs_posi, inputs_nega),
+            "sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
+            "sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
+        }
+
+    def get_pipeline_inputs(self, data):
+        inputs_posi = {"prompt": data["prompt"]}
+        inputs_nega = {"negative_prompt": ""}
+        inputs_shared = {
+            "input_image": data["image"],
+            "height": data["image"].size[1],
+            "width": data["image"].size[0],
+            "cfg_scale": 1,
+            "rand_device": self.pipe.device,
+            "use_gradient_checkpointing": self.use_gradient_checkpointing,
+            "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
+        }
+        inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
+        return inputs_shared, inputs_posi, inputs_nega
+
+    def forward(self, data, inputs=None):
+        if inputs is None:
+            inputs = self.get_pipeline_inputs(data)
+        inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
+        for unit in self.pipe.units:
+            inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)
+        loss = self.task_to_loss[self.task](self.pipe, *inputs)
+        return loss
+
+
+def ernie_image_parser():
+    parser = argparse.ArgumentParser(description="ERNIE-Image training.")
+    parser = add_general_config(parser)
+    parser = add_image_size_config(parser)
+    parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
+    return parser
+
+
+if __name__ == "__main__":
+    parser = ernie_image_parser()
+    args = parser.parse_args()
+    accelerator = accelerate.Accelerator(
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
+    )
+    dataset = UnifiedDataset(
+        base_path=args.dataset_base_path,
+        metadata_path=args.dataset_metadata_path,
+        repeat=args.dataset_repeat,
+        data_file_keys=args.data_file_keys.split(","),
+        main_data_operator=lambda x: x,
+        special_operator_map={
+            "image": ToAbsolutePath(args.dataset_base_path) >> LoadImage() >> ImageCropAndResize(args.height, args.width, args.max_pixels, 16, 16),
+        },
+    )
+    model = ErnieImageTrainingModule(
+        model_paths=args.model_paths,
+        model_id_with_origin_paths=args.model_id_with_origin_paths,
+        tokenizer_path=args.tokenizer_path,
+        trainable_models=args.trainable_models,
+        lora_base_model=args.lora_base_model,
+        lora_target_modules=args.lora_target_modules,
+        lora_rank=args.lora_rank,
+        lora_checkpoint=args.lora_checkpoint,
+        preset_lora_path=args.preset_lora_path,
+        preset_lora_model=args.preset_lora_model,
+        use_gradient_checkpointing=args.use_gradient_checkpointing,
+        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
+        extra_inputs=args.extra_inputs,
+        fp8_models=args.fp8_models,
+        offload_models=args.offload_models,
+        task=args.task,
+        device=accelerator.device,
+    )
+    model_logger = ModelLogger(
+        args.output_path,
+        remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
+    )
+    launcher_map = {
+        "sft:data_process": launch_data_process_task,
+        "sft": launch_training_task,
+        "sft:train": launch_training_task,
+    }
+    launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)
diff --git a/examples/ernie_image/model_training/validate_full/Ernie-Image-T2I.py b/examples/ernie_image/model_training/validate_full/Ernie-Image-T2I.py
new file mode 100644
index 0000000..4664126
--- /dev/null
+++ b/examples/ernie_image/model_training/validate_full/Ernie-Image-T2I.py
@@ -0,0 +1,25 @@
+import torch
+from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
+from diffsynth.core import load_state_dict
+
+pipe = ErnieImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
+        ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors"),
+        ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+    ],
+)
+
+state_dict = load_state_dict("./models/train/Ernie-Image-T2I_full/epoch-1.safetensors")
+pipe.dit.load_state_dict(state_dict)
+
+image = pipe(
+    prompt="a professional photo of a cute dog",
+    seed=0,
+    num_inference_steps=50,
+    cfg_scale=4.0,
+)
+image.save("image_full.jpg")
+print("Full validation image saved to image_full.jpg")
diff --git a/examples/ernie_image/model_training/validate_lora/Ernie-Image-T2I.py b/examples/ernie_image/model_training/validate_lora/Ernie-Image-T2I.py
new file mode 100644
index 0000000..20f84eb
--- /dev/null
+++ b/examples/ernie_image/model_training/validate_lora/Ernie-Image-T2I.py
@@ -0,0 +1,25 @@
+import torch
+from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
+from diffsynth.core.loader.file import load_state_dict
+
+pipe = ErnieImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
+        ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors"),
+        ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+    ],
+)
+
+lora_state_dict = load_state_dict("./models/train/Ernie-Image-T2I_lora/epoch-4.safetensors", torch_dtype=torch.bfloat16, device="cuda")
+pipe.load_lora(pipe.dit, state_dict=lora_state_dict, alpha=1.0)
+
+image = pipe(
+    prompt="a professional photo of a cute dog",
+    seed=0,
+    num_inference_steps=50,
+    cfg_scale=4.0,
+)
+image.save("image_lora.jpg")
+print("LoRA validation image saved to image_lora.jpg")