Files
DiffSynth-Studio/diffsynth/models/ernie_image_dit.py
Hong Zhang 960d8c62c0 Support ERNIE-Image (#1389)
* ernie-image pipeline

* ernie-image inference and training

* style fix

* ernie docs

* lowvram

* final style fix

* pr-review

* pr-fix round2

* set uniform training weight

* fix

* update lowvram docs
2026-04-13 14:57:10 +08:00

363 lines
14 KiB
Python

"""
Ernie-Image DiT for DiffSynth-Studio.
Refactored from diffusers ErnieImageTransformer2DModel to use DiffSynth core modules.
Default parameters from actual checkpoint config.json (baidu/ERNIE-Image transformer).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
from ..core.attention import attention_forward
from ..core.gradient import gradient_checkpoint_forward
from .flux2_dit import Timesteps, TimestepEmbedding
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    """Compute rotary-embedding angles for a single positional axis.

    Args:
        pos: Positions along one axis, arbitrary leading shape ``(..., n)``.
        dim: Channels this axis occupies in the rotary embedding; must be even
            because every frequency is shared by a pair of channels.
        theta: Base of the geometric frequency progression.

    Returns:
        A float32 tensor of shape ``(..., n, dim // 2)`` containing the raw
        angles ``pos * omega_k`` (not cos/sin), where
        ``omega_k = 1 / theta ** (2k / dim)``.
    """
    assert dim % 2 == 0
    # Frequencies are built in float64 and only cast to float32 at the end.
    exponents = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    inv_freq = 1.0 / (theta ** exponents)
    # Outer product: each position times each frequency (no reduction involved).
    angles = pos.unsqueeze(-1) * inv_freq
    return angles.float()
class ErnieImageEmbedND3(nn.Module):
    """Rotary angle table for position ids with three axes.

    Axis ``i`` of the ids is turned into ``axes_dim[i] // 2`` angles with
    :func:`rope`; the per-axis angles are concatenated and every angle is then
    duplicated into two adjacent slots, so the last dimension of the result is
    ``sum(axes_dim)``.
    """

    def __init__(self, dim: int, theta: int, axes_dim: Tuple[int, int, int]):
        super().__init__()
        self.dim = dim  # kept as an attribute; forward does not read it
        self.theta = theta
        self.axes_dim = list(axes_dim)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """Map ids of shape ``(B, S, 3)`` to angles of shape ``(B, S, 1, sum(axes_dim))``.

        The singleton dimension inserted at position 2 presumably lets the
        table broadcast over the heads axis of ``(B, S, heads, head_dim)``
        queries/keys — confirm against the consumer.
        """
        per_axis = []
        for axis in range(3):
            per_axis.append(rope(ids[..., axis], self.axes_dim[axis], self.theta))
        angles = torch.cat(per_axis, dim=-1).unsqueeze(2)
        # [a0, a1, a2, ...] -> [a0, a0, a1, a1, a2, a2, ...]
        return angles.repeat_interleave(2, dim=-1)
class ErnieImagePatchEmbedDynamic(nn.Module):
    """Turn a latent image into a sequence of patch tokens.

    A convolution with ``kernel_size == stride == patch_size`` projects every
    non-overlapping patch to ``embed_dim`` channels; the resulting spatial grid
    is flattened row-major into the sequence axis.
    """

    def __init__(self, in_channels: int, embed_dim: int, patch_size: int):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(B, C, H, W)`` to a contiguous ``(B, (H // p) * (W // p), embed_dim)`` tensor."""
        # (B, D, Hp, Wp) -> (B, D, Hp * Wp) -> (B, Hp * Wp, D)
        tokens = self.proj(x).flatten(start_dim=2)
        return tokens.transpose(1, 2).contiguous()
class ErnieImageSingleStreamAttnProcessor:
    """Self-attention over a single token stream.

    Steps: Q/K/V projection -> split into heads -> optional per-head QK-norm ->
    optional rotary embedding on Q and K -> ``attention_forward`` -> merge
    heads -> output projection ``attn.to_out[0]``.
    """

    @staticmethod
    def _apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        """Rotate the first ``freqs_cis.shape[-1]`` channels of ``x_in``; pass the rest through.

        ``freqs_cis`` holds angles (real-valued); cos/sin are taken here and cast
        to the input dtype. Channel ``i`` is paired with channel
        ``i + rot_dim // 2`` (half-split "rotate half" layout).
        """
        rot_dim = freqs_cis.shape[-1]
        rotated_part = x_in[..., :rot_dim]
        passthrough = x_in[..., rot_dim:]
        cos_ = torch.cos(freqs_cis).to(rotated_part.dtype)
        sin_ = torch.sin(freqs_cis).to(rotated_part.dtype)
        first_half, second_half = rotated_part.chunk(2, dim=-1)
        half_rotated = torch.cat((-second_half, first_half), dim=-1)
        return torch.cat((rotated_part * cos_ + half_rotated * sin_, passthrough), dim=-1)

    def __call__(
        self,
        attn: "ErnieImageAttention",
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend ``hidden_states`` (batch-first ``(B, S, query_dim)``) to itself.

        Args:
            attn: Owning attention module; supplies ``to_q``/``to_k``/``to_v``,
                ``heads``, ``norm_q``/``norm_k`` (either may be None) and ``to_out``.
            hidden_states: Token features.
            attention_mask: Optional mask. A 2-D ``(B, S)`` mask is reshaped to
                ``(B, 1, 1, S)``; how it is interpreted is up to
                ``attention_forward`` (presumably True means "may attend" — confirm).
            freqs_cis: Optional rotary angles; the last dim is the number of
                rotated channels per head.

        Returns:
            Output of ``attn.to_out[0]`` on the merged heads.
        """

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # (B, S, heads * head_dim) -> (B, S, heads, head_dim)
            return t.unflatten(-1, (attn.heads, -1))

        q = split_heads(attn.to_q(hidden_states))
        k = split_heads(attn.to_k(hidden_states))
        v = split_heads(attn.to_v(hidden_states))

        if attn.norm_q is not None:
            q = attn.norm_q(q)
        if attn.norm_k is not None:
            k = attn.norm_k(k)

        if freqs_cis is not None:
            q = self._apply_rotary_emb(q, freqs_cis)
            k = self._apply_rotary_emb(k, freqs_cis)

        if attention_mask is not None and attention_mask.ndim == 2:
            attention_mask = attention_mask[:, None, None, :]

        out = attention_forward(
            q, k, v,
            q_pattern="b s n d",
            k_pattern="b s n d",
            v_pattern="b s n d",
            out_pattern="b s n d",
            attn_mask=attention_mask,
        )
        # (B, S, heads, head_dim) -> (B, S, heads * head_dim), back in the query dtype.
        out = out.flatten(2, 3).to(q.dtype)
        return attn.to_out[0](out)
class ErnieImageAttention(nn.Module):
    """Multi-head self-attention projections with optional per-head QK normalization.

    The attention computation itself is delegated to ``self.processor``.

    Args:
        query_dim: Input feature size.
        heads: Number of heads; ignored when ``out_dim`` is given (then
            ``out_dim // dim_head`` heads are used).
        dim_head: Channels per head; also the size normalized by QK-norm.
        dropout: Stored on the module; not read by the processor in this file.
        bias: Whether the Q/K/V projections have a bias.
        qk_norm: ``None`` (no QK normalization), ``"layer_norm"`` or ``"rms_norm"``.
        out_bias: Whether the output projection has a bias.
        eps: Epsilon of the QK-norm layers.
        out_dim: Optional override for both the Q/K/V width and the output width.
        elementwise_affine: Whether the QK-norm layers have learnable affine params.

    Raises:
        ValueError: If ``qk_norm`` is not one of the supported values.
    """

    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        qk_norm: Optional[str] = "rms_norm",
        out_bias: bool = True,
        eps: float = 1e-5,
        out_dim: Optional[int] = None,
        elementwise_affine: bool = True,
    ):
        super().__init__()
        self.head_dim = dim_head
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.heads = out_dim // dim_head if out_dim is not None else heads
        self.use_bias = bias
        self.dropout = dropout
        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
        if qk_norm is None:
            # QK-norm disabled (e.g. a block built with qk_layernorm=False);
            # the attention processor skips normalization when these are None.
            self.norm_q = None
            self.norm_k = None
        elif qk_norm == "layer_norm":
            self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
            self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        elif qk_norm == "rms_norm":
            self.norm_q = nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
            self.norm_k = nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        else:
            raise ValueError(
                f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'rms_norm'."
            )
        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
        self.processor = ErnieImageSingleStreamAttnProcessor()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run self-attention via the processor; ``image_rotary_emb`` is passed as its ``freqs_cis``."""
        return self.processor(self, hidden_states, attention_mask, image_rotary_emb)
class ErnieImageFeedForward(nn.Module):
    """Gated-GELU MLP: ``linear_fc2(up_proj(x) * gelu(gate_proj(x)))``.

    All three projections are bias-free.
    """

    def __init__(self, hidden_size: int, ffn_hidden_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
        self.linear_fc2 = nn.Linear(ffn_hidden_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        up = self.up_proj(x)
        gate = F.gelu(self.gate_proj(x))
        return self.linear_fc2(up * gate)
class ErnieImageRMSNorm(nn.Module):
    """RMS normalization over the last dimension with a learnable per-channel scale.

    The mean-square statistic is computed in float32; the final result is cast
    back to the input dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        out_dtype = hidden_states.dtype
        mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        # Type promotion carries the product into (at least) float32 before the cast back.
        scaled = hidden_states * inv_rms * self.weight
        return scaled.to(out_dtype)
class ErnieImageSharedAdaLNBlock(nn.Module):
    """Transformer block (attention + gated MLP) modulated by caller-supplied AdaLN vectors.

    The block owns no modulation projection: ``temb`` carries the six
    shift/scale/gate tensors. ``x`` is sequence-first ``(S, B, H)``; it is
    permuted to batch-first only around the attention call.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        ffn_hidden_size: int,
        eps: float = 1e-6,
        qk_layernorm: bool = True,
    ):
        super().__init__()
        self.adaLN_sa_ln = ErnieImageRMSNorm(hidden_size, eps=eps)
        self.self_attention = ErnieImageAttention(
            query_dim=hidden_size,
            dim_head=hidden_size // num_heads,
            heads=num_heads,
            qk_norm="rms_norm" if qk_layernorm else None,
            eps=eps,
            bias=False,
            out_bias=False,
        )
        self.adaLN_mlp_ln = ErnieImageRMSNorm(hidden_size, eps=eps)
        self.mlp = ErnieImageFeedForward(hidden_size, ffn_hidden_size)

    @staticmethod
    def _modulate(h: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        """Return ``h * (1 + scale) + shift``, computed in float32 and cast back to ``h``'s dtype."""
        return (h.float() * (1 + scale.float()) + shift.float()).to(h.dtype)

    def forward(
        self,
        x: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        temb: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the block.

        Args:
            x: Sequence-first hidden states ``(S, B, H)``.
            rotary_pos_emb: Rotary angles forwarded to the attention processor.
            temb: ``(shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp)``,
                each broadcastable against ``x``.
            attention_mask: Optional mask forwarded to attention.

        Returns:
            Updated hidden states, same shape as ``x``.
        """
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = temb

        # Attention branch; the gated residual add is done in float32.
        h = self._modulate(self.adaLN_sa_ln(x), shift_msa, scale_msa)
        attn_out = self.self_attention(
            h.permute(1, 0, 2),
            attention_mask=attention_mask,
            image_rotary_emb=rotary_pos_emb,
        ).permute(1, 0, 2)
        x = x + (gate_msa.float() * attn_out.float()).to(h.dtype)

        # MLP branch.
        h = self._modulate(self.adaLN_mlp_ln(x), shift_mlp, scale_mlp)
        return x + (gate_mlp.float() * self.mlp(h).float()).to(h.dtype)
class ErnieImageAdaLNContinuous(nn.Module):
    """Affine-free LayerNorm modulated by a conditioning vector.

    ``conditioning`` ``(B, H)`` is projected to ``scale`` and ``shift`` (in
    that order), which are broadcast over the leading sequence axis of a
    sequence-first ``x`` ``(S, B, H)``.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=eps)
        self.linear = nn.Linear(hidden_size, hidden_size * 2)

    def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
        modulation = self.linear(conditioning)
        scale, shift = modulation.chunk(2, dim=-1)
        normed = self.norm(x)
        return normed * (1 + scale[None]) + shift[None]
class ErnieImageDiT(nn.Module):
    """
    Ernie-Image DiT model for DiffSynth-Studio.
    Architecture: SharedAdaLN + RoPE 3D + Joint Image-Text Attention.
    Internal format: [S, B, H] for transformer blocks, [B, S, H] for attention.

    Image patch tokens and (projected) text tokens are concatenated into one
    sequence (image first). A single AdaLN modulation computed from the timestep
    is shared by all layers. Only the image tokens are decoded back to latents.
    """

    def __init__(
        self,
        hidden_size: int = 4096,
        num_attention_heads: int = 32,
        num_layers: int = 36,
        ffn_hidden_size: int = 12288,
        in_channels: int = 128,
        out_channels: int = 128,
        patch_size: int = 1,
        text_in_dim: int = 3072,
        rope_theta: int = 256,
        rope_axes_dim: Tuple[int, int, int] = (32, 48, 48),
        eps: float = 1e-6,
        qk_layernorm: bool = True,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_attention_heads
        self.head_dim = hidden_size // num_attention_heads
        self.num_layers = num_layers
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.text_in_dim = text_in_dim
        self.x_embedder = ErnieImagePatchEmbedDynamic(in_channels, hidden_size, patch_size)
        self.text_proj = nn.Linear(text_in_dim, hidden_size, bias=False) if text_in_dim != hidden_size else None
        self.time_proj = Timesteps(hidden_size, flip_sin_to_cos=False, downscale_freq_shift=0)
        self.time_embedding = TimestepEmbedding(hidden_size, hidden_size)
        self.pos_embed = ErnieImageEmbedND3(dim=self.head_dim, theta=rope_theta, axes_dim=rope_axes_dim)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size))
        nn.init.zeros_(self.adaLN_modulation[-1].weight)
        nn.init.zeros_(self.adaLN_modulation[-1].bias)
        self.layers = nn.ModuleList([
            ErnieImageSharedAdaLNBlock(hidden_size, num_attention_heads, ffn_hidden_size, eps, qk_layernorm=qk_layernorm)
            for _ in range(num_layers)
        ])
        self.final_norm = ErnieImageAdaLNContinuous(hidden_size, eps)
        self.final_linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels)
        nn.init.zeros_(self.final_linear.weight)
        nn.init.zeros_(self.final_linear.bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        text_bth: torch.Tensor,
        text_lens: torch.Tensor,
        use_gradient_checkpointing: bool = False,
        use_gradient_checkpointing_offload: bool = False,
    ) -> torch.Tensor:
        """Run the DiT on one batch of latents.

        Args:
            hidden_states: Image latents ``(B, in_channels, H, W)``; H and W are
                expected to be multiples of ``patch_size`` (the final ``view``
                relies on it).
            timestep: Diffusion timestep(s); presumably shape ``(B,)`` — confirm
                against ``Timesteps``.
            text_bth: Padded text features ``(B, T, text_in_dim)``; T may be 0.
            text_lens: Valid text length per sample ``(B,)``, expected on the same
                device as ``hidden_states``.
            use_gradient_checkpointing: Recompute each block during backward.
            use_gradient_checkpointing_offload: Forwarded to ``gradient_checkpoint_forward``.

        Returns:
            Tensor of shape ``(B, out_channels, H, W)``.
        """
        device, dtype = hidden_states.device, hidden_states.dtype
        B, _, H, W = hidden_states.shape
        p = self.patch_size
        Hp, Wp = H // p, W // p
        N_img = Hp * Wp

        # Image tokens, sequence-first: (N_img, B, hidden).
        img_sbh = self.x_embedder(hidden_states).transpose(0, 1).contiguous()
        # Always project the text, even with zero tokens: nn.Linear accepts an
        # empty sequence, whereas skipping it would leave a text_in_dim-wide
        # tensor that cannot be concatenated with hidden_size-wide image tokens.
        if self.text_proj is not None:
            text_bth = self.text_proj(text_bth)
        Tmax = text_bth.shape[1]
        text_sbh = text_bth.transpose(0, 1).contiguous()
        # Joint sequence (image first, then text): (N_img + Tmax, B, hidden).
        x = torch.cat([img_sbh, text_sbh], dim=0)

        # 3-axis position ids. Text token t -> (t, 0, 0);
        # image token at grid (y, x) -> (text_len_b, y, x).
        text_pos = torch.arange(Tmax, device=device, dtype=torch.float32)
        text_ids = torch.cat([
            text_pos.view(1, Tmax, 1).expand(B, -1, -1),
            torch.zeros((B, Tmax, 2), device=device),
        ], dim=-1)
        grid_yx = torch.stack(
            torch.meshgrid(torch.arange(Hp, device=device, dtype=torch.float32),
                           torch.arange(Wp, device=device, dtype=torch.float32), indexing="ij"),
            dim=-1
        ).reshape(-1, 2)
        image_ids = torch.cat([
            text_lens.float().view(B, 1, 1).expand(-1, N_img, -1),
            grid_yx.view(1, N_img, 2).expand(B, -1, -1)
        ], dim=-1)
        rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1))

        # Key mask (B, 1, 1, S): every image token is kept; text padding beyond
        # text_lens is masked out.
        valid_text = torch.arange(Tmax, device=device).view(1, Tmax) < text_lens.view(B, 1)
        attention_mask = torch.cat([
            torch.ones((B, N_img), device=device, dtype=torch.bool),
            valid_text
        ], dim=1)[:, None, None, :]

        sample = self.time_proj(timestep.to(dtype))
        sample = sample.to(self.time_embedding.linear_1.weight.dtype)
        c = self.time_embedding(sample)
        # Shared AdaLN: one (shift, scale, gate) x (attn, mlp) set reused by all
        # layers. Each is kept as (1, B, hidden) and broadcast over the sequence
        # axis inside the blocks rather than materialized S times.
        temb = [t.unsqueeze(0) for t in self.adaLN_modulation(c).chunk(6, dim=-1)]

        for layer in self.layers:
            if torch.is_grad_enabled() and use_gradient_checkpointing:
                x = gradient_checkpoint_forward(
                    layer,
                    use_gradient_checkpointing,
                    use_gradient_checkpointing_offload,
                    x,
                    rotary_pos_emb,
                    temb,
                    attention_mask,
                )
            else:
                x = layer(x, rotary_pos_emb, temb, attention_mask)

        # Only image tokens are decoded. The final norm and projection act per
        # token, so dropping the text tokens first gives the same image outputs.
        x = x[:N_img]
        x = self.final_norm(x, c).type_as(x)
        patches = self.final_linear(x).transpose(0, 1).contiguous()
        # NOTE(review): permute(0, 5, 1, 3, 2, 4) uses the second p axis as the
        # row offset and the first as the column offset inside each patch; this
        # only matters for patch_size > 1 (default 1) — confirm against the
        # checkpoint's patch layout before relying on larger patches.
        output = patches.view(B, Hp, Wp, p, p, self.out_channels).permute(0, 5, 1, 3, 2, 4).contiguous().view(B, self.out_channels, H, W)
        return output