# Source: DiffSynth-Studio/diffsynth/models/z_image_dit.py
# Snapshot: 2026-01-07 11:42:19 +08:00 (626 lines, 22 KiB, Python)

import math
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.nn import RMSNorm
from ..core.attention import attention_forward
from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE
from ..core.gradient import gradient_checkpoint_forward
ADALN_EMBED_DIM = 256
SEQ_MULTI_OF = 32
class TimestepEmbedder(nn.Module):
    """Embed scalar timesteps with sinusoidal features followed by a small MLP.

    Args:
        out_size: Width of the embedding returned by ``forward``.
        mid_size: Hidden width of the MLP; falls back to ``out_size`` when None.
        frequency_embedding_size: Width of the sinusoidal feature vector that is
            fed into the MLP.
    """

    # Dtypes the sinusoidal features may be cast to when they match the MLP
    # weights. Any other weight dtype falls back to bfloat16, which is what
    # this module always used before.
    _COMPUTE_DTYPES = (torch.float16, torch.bfloat16, torch.float32, torch.float64)

    def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):
        super().__init__()
        if mid_size is None:
            mid_size = out_size
        self.mlp = nn.Sequential(
            nn.Linear(
                frequency_embedding_size,
                mid_size,
                bias=True,
            ),
            nn.SiLU(),
            nn.Linear(
                mid_size,
                out_size,
                bias=True,
            ),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Return float32 sinusoidal features of shape ``(len(t), dim)``.

        Args:
            t: Tensor of timesteps; it is indexed as ``t[:, None]``, so a 1-D
                tensor is expected.
            dim: Number of features per timestep. The first ``dim // 2`` columns
                are cosines, the next ``dim // 2`` are sines, and an odd ``dim``
                gets one trailing zero column.
            max_period: Controls the lowest frequency of the geometric
                frequency ladder.
        """
        # The features are computed in float32 with CUDA autocast disabled.
        with torch.amp.autocast("cuda", enabled=False):
            half = dim // 2
            freqs = torch.exp(
                -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
            )
            args = t[:, None].float() * freqs[None]
            embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
            if dim % 2:
                embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
            return embedding

    def _mlp_input_dtype(self):
        """Pick the dtype the sinusoidal features are cast to before the MLP."""
        weight = getattr(self.mlp[0], "weight", None)
        if weight is not None and weight.dtype in self._COMPUTE_DTYPES:
            return weight.dtype
        # Unknown / non-standard weight storage: keep the historical behavior.
        return torch.bfloat16

    def forward(self, t):
        """Embed timesteps ``t`` and return the MLP output.

        The features are cast to the dtype of the first linear layer's weight
        (when it is float16/bfloat16/float32/float64) instead of an
        unconditional bfloat16, so the module also runs when it is not held in
        bfloat16. For bfloat16 weights the result is unchanged.
        """
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq.to(self._mlp_input_dtype()))
        return t_emb
class FeedForward(nn.Module):
    """SiLU-gated feed-forward layer: ``w2(silu(w1(x)) * w3(x))``.

    Args:
        dim: Input and output feature width.
        hidden_dim: Width of the gated hidden representation.
    """

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        # All three projections are bias-free. They are created in the order
        # w1, w2, w3 so parameter ordering stays the same.
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def _forward_silu_gating(self, x1, x3):
        """Gate ``x3`` with the SiLU activation of ``x1``."""
        gate = F.silu(x1)
        return gate * x3

    def forward(self, x):
        """Apply the gated MLP to ``x``."""
        gate_branch = self.w1(x)
        value_branch = self.w3(x)
        hidden = self._forward_silu_gating(gate_branch, value_branch)
        return self.w2(hidden)
class Attention(torch.nn.Module):
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
super().__init__()
dim_inner = head_dim * num_heads
kv_dim = kv_dim if kv_dim is not None else q_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
self.to_out = torch.nn.ModuleList([torch.nn.Linear(dim_inner, q_dim, bias=bias_out)])
self.norm_q = RMSNorm(head_dim, eps=1e-5)
self.norm_k = RMSNorm(head_dim, eps=1e-5)
def forward(self, hidden_states, freqs_cis):
query = self.to_q(hidden_states)
key = self.to_k(hidden_states)
value = self.to_v(hidden_states)
query = query.unflatten(-1, (self.num_heads, -1))
key = key.unflatten(-1, (self.num_heads, -1))
value = value.unflatten(-1, (self.num_heads, -1))
# Apply Norms
if self.norm_q is not None:
query = self.norm_q(query)
if self.norm_k is not None:
key = self.norm_k(key)
# Apply RoPE
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
with torch.amp.autocast("cuda", enabled=False):
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
return x_out.type_as(x_in) # todo
if freqs_cis is not None:
query = apply_rotary_emb(query, freqs_cis)
key = apply_rotary_emb(key, freqs_cis)
# Cast to correct dtype
dtype = query.dtype
query, key = query.to(dtype), key.to(dtype)
# Compute joint attention
hidden_states = attention_forward(
query,
key,
value,
q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
)
# Reshape back
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(dtype)
output = self.to_out[0](hidden_states)
if len(self.to_out) > 1: # dropout
output = self.to_out[1](output)
return output
class ZImageTransformerBlock(nn.Module):
    """Transformer block with sandwich RMSNorm and optional adaLN modulation.

    Args:
        layer_id: Identifier stored on the block as ``self.layer_id``.
        dim: Model width.
        n_heads: Number of attention heads; the head width is ``dim // n_heads``.
        n_kv_heads: Accepted for signature compatibility; not used here.
        norm_eps: Epsilon of the four RMSNorm layers.
        qk_norm: Accepted for signature compatibility; not used here.
        modulation: When True, the block owns an ``adaLN_modulation`` projection
            and ``forward`` requires ``adaln_input``.
    """

    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        norm_eps: float,
        qk_norm: bool,
        modulation=True,
    ):
        super().__init__()
        head_dim = dim // n_heads
        self.dim = dim
        self.head_dim = head_dim
        self.layer_id = layer_id
        self.modulation = modulation

        # Sub-modules are registered in the same order as before so parameter
        # ordering is unchanged.
        self.attention = Attention(
            q_dim=dim,
            num_heads=n_heads,
            head_dim=head_dim,
        )
        self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))

        self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
        self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)

        if modulation:
            # One linear layer producing four dim-wide chunks:
            # scale_msa, gate_msa, scale_mlp, gate_mlp.
            self.adaLN_modulation = nn.Sequential(
                nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True),
            )

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        adaln_input: Optional[torch.Tensor] = None,
    ):
        """Apply the attention and feed-forward sub-blocks with residuals.

        Args:
            x: Hidden states.
            attn_mask: Accepted but not used by this implementation.
            freqs_cis: Rotary factors forwarded to the attention layer.
            adaln_input: Conditioning vector; required when ``self.modulation``
                is True, ignored otherwise.

        Returns:
            The updated hidden states.
        """
        if not self.modulation:
            # Plain pre-norm / post-norm residual block without conditioning.
            attn_out = self.attention(
                self.attention_norm1(x),
                freqs_cis=freqs_cis,
            )
            x = x + self.attention_norm2(attn_out)
            ffn_out = self.feed_forward(self.ffn_norm1(x))
            return x + self.ffn_norm2(ffn_out)

        assert adaln_input is not None
        # unsqueeze(1) adds an axis so the modulation broadcasts over dim 1 of x.
        mod = self.adaLN_modulation(adaln_input).unsqueeze(1)
        scale_msa, gate_msa, scale_mlp, gate_mlp = mod.chunk(4, dim=2)
        gate_msa = gate_msa.tanh()
        gate_mlp = gate_mlp.tanh()
        scale_msa = 1.0 + scale_msa
        scale_mlp = 1.0 + scale_mlp

        # Attention sub-block: scaled input, tanh-gated residual.
        attn_out = self.attention(
            self.attention_norm1(x) * scale_msa,
            freqs_cis=freqs_cis,
        )
        x = x + gate_msa * self.attention_norm2(attn_out)

        # Feed-forward sub-block: scaled input, tanh-gated residual.
        ffn_out = self.feed_forward(self.ffn_norm1(x) * scale_mlp)
        x = x + gate_mlp * self.ffn_norm2(ffn_out)
        return x
class FinalLayer(nn.Module):
    """Output head: non-affine LayerNorm, conditioning-dependent scale, linear projection.

    Args:
        hidden_size: Width of the incoming hidden states.
        out_channels: Width of the projected output.
    """

    def __init__(self, hidden_size, out_channels):
        super().__init__()
        cond_dim = min(hidden_size, ADALN_EMBED_DIM)
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cond_dim, hidden_size, bias=True),
        )

    def forward(self, x, c):
        """Normalize ``x``, scale it by ``1 + adaLN_modulation(c)`` and project.

        Args:
            x: Hidden states.
            c: Conditioning input of the modulation MLP.
        """
        scale = self.adaLN_modulation(c) + 1.0
        # unsqueeze(1) lets the per-sample scale broadcast over dim 1 of x.
        modulated = self.norm_final(x) * scale.unsqueeze(1)
        return self.linear(modulated)
class RopeEmbedder:
    """Look up per-axis complex rotary factors for integer position ids.

    Args:
        theta: Base of the rotary frequency ladder.
        axes_dims: Feature width assigned to each position axis.
        axes_lens: Number of positions precomputed for each axis.
    """

    def __init__(
        self,
        theta: float = 256.0,
        axes_dims: List[int] = (16, 56, 56),
        axes_lens: List[int] = (64, 128, 128),
    ):
        assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length"
        self.theta = theta
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens
        # Per-axis tables are built lazily on the first call.
        self.freqs_cis = None

    @staticmethod
    def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0):
        """Build one complex64 table of shape ``(end[i], dim[i] // 2)`` per axis.

        Angles are computed in float64 on the CPU and converted to float32
        before being turned into unit-magnitude complex numbers.
        """
        tables = []
        with torch.device("cpu"):
            for axis_dim, axis_len in zip(dim, end):
                exponents = torch.arange(0, axis_dim, 2, dtype=torch.float64, device="cpu") / axis_dim
                inv_freq = 1.0 / (theta ** exponents)
                positions = torch.arange(axis_len, device=inv_freq.device, dtype=torch.float64)
                angles = torch.outer(positions, inv_freq).float()
                unit = torch.ones_like(angles)
                tables.append(torch.polar(unit, angles).to(torch.complex64))
        return tables

    def __call__(self, ids: torch.Tensor):
        """Return the concatenated rotary factors for ``ids``.

        Args:
            ids: 2-D integer tensor with one column per axis in ``axes_dims``.

        Returns:
            Complex tensor with one row per row of ``ids``; the per-axis lookups
            are concatenated along the last dim.
        """
        assert ids.ndim == 2
        assert ids.shape[-1] == len(self.axes_dims)

        if self.freqs_cis is None:
            self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
        # Move the tables to the device of the ids on every call.
        target = ids.device
        self.freqs_cis = [table.to(target) for table in self.freqs_cis]

        def gather(table, index):
            # NPU builds use index_select; others use plain advanced indexing.
            if IS_NPU_AVAILABLE:
                return torch.index_select(table, 0, index)
            return table[index]

        parts = [gather(self.freqs_cis[axis], ids[:, axis]) for axis in range(len(self.axes_dims))]
        return torch.cat(parts, dim=-1)
class ZImageDiT(nn.Module):
_supports_gradient_checkpointing = True
_no_split_modules = ["ZImageTransformerBlock"]
def __init__(
    self,
    all_patch_size=(2,),
    all_f_patch_size=(1,),
    in_channels=16,
    dim=3840,
    n_layers=30,
    n_refiner_layers=2,
    n_heads=30,
    n_kv_heads=30,
    norm_eps=1e-5,
    qk_norm=True,
    cap_feat_dim=2560,
    rope_theta=256.0,
    t_scale=1000.0,
    axes_dims=[32, 48, 48],
    axes_lens=[1024, 512, 512],
) -> None:
    """Build the embedders, refiner stacks, main layer stack and output heads.

    Args:
        all_patch_size: Supported spatial patch sizes.
        all_f_patch_size: Supported frame patch sizes, paired element-wise
            with ``all_patch_size``.
        in_channels: Channels of the input; also used as the output channels.
        dim: Model width.
        n_layers: Number of blocks in the main (unified) stack.
        n_refiner_layers: Number of blocks in each of the noise and context
            refiner stacks.
        n_heads: Attention heads per block.
        n_kv_heads: Forwarded to every block.
        norm_eps: Epsilon for the RMSNorm layers.
        qk_norm: Forwarded to every block.
        cap_feat_dim: Width of the incoming caption features.
        rope_theta: Base of the rotary frequency ladder.
        t_scale: Factor the timestep is multiplied by in ``forward``.
        axes_dims: Rotary feature width per position axis; must sum to
            ``dim // n_heads``.
        axes_lens: Number of rotary positions precomputed per axis.

    Raises:
        AssertionError: If the patch-size tuples differ in length or the
            rotary axes do not add up to the head width.
    """
    super().__init__()

    # The list defaults in the signature are created once and shared by every
    # call. Copy list arguments so an instance never aliases them.
    if isinstance(axes_dims, list):
        axes_dims = list(axes_dims)
    if isinstance(axes_lens, list):
        axes_lens = list(axes_lens)

    # Validate the configuration before allocating any layers.
    assert len(all_patch_size) == len(all_f_patch_size), (
        "all_patch_size and all_f_patch_size must have the same length"
    )
    head_dim = dim // n_heads
    assert head_dim == sum(axes_dims), (
        f"sum(axes_dims)={sum(axes_dims)} must equal dim // n_heads={head_dim}"
    )

    self.in_channels = in_channels
    self.out_channels = in_channels
    self.all_patch_size = all_patch_size
    self.all_f_patch_size = all_f_patch_size
    self.dim = dim
    self.n_heads = n_heads
    self.rope_theta = rope_theta
    self.t_scale = t_scale
    self.gradient_checkpointing = False

    # One input embedder and one output head per "<patch>-<f_patch>" key.
    all_x_embedder = {}
    all_final_layer = {}
    for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size):
        key = f"{patch_size}-{f_patch_size}"
        patch_volume = f_patch_size * patch_size * patch_size
        all_x_embedder[key] = nn.Linear(patch_volume * in_channels, dim, bias=True)
        all_final_layer[key] = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels)
    self.all_x_embedder = nn.ModuleDict(all_x_embedder)
    self.all_final_layer = nn.ModuleDict(all_final_layer)

    # Image-token refiner: modulated blocks, layer ids offset by 1000.
    self.noise_refiner = nn.ModuleList(
        [
            ZImageTransformerBlock(
                1000 + layer_id,
                dim,
                n_heads,
                n_kv_heads,
                norm_eps,
                qk_norm,
                modulation=True,
            )
            for layer_id in range(n_refiner_layers)
        ]
    )
    # Caption-token refiner: blocks without modulation.
    self.context_refiner = nn.ModuleList(
        [
            ZImageTransformerBlock(
                layer_id,
                dim,
                n_heads,
                n_kv_heads,
                norm_eps,
                qk_norm,
                modulation=False,
            )
            for layer_id in range(n_refiner_layers)
        ]
    )

    self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)
    self.cap_embedder = nn.Sequential(
        RMSNorm(cap_feat_dim, eps=norm_eps),
        nn.Linear(cap_feat_dim, dim, bias=True),
    )

    # Created with torch.empty, i.e. uninitialized memory; presumably filled
    # from a checkpoint -- TODO confirm.
    self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
    self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))

    self.layers = nn.ModuleList(
        [
            ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm)
            for layer_id in range(n_layers)
        ]
    )

    self.axes_dims = axes_dims
    self.axes_lens = axes_lens
    self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)
def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]:
    """Fold each token sequence back into a ``(out_channels, F, H, W)`` tensor.

    Args:
        x: One token tensor per sample; only the first
            ``(F // pF) * (H // pH) * (W // pW)`` rows of each are used.
        size: One ``(F, H, W)`` tuple per sample.
        patch_size: Spatial patch size (used for both height and width).
        f_patch_size: Frame patch size.

    Returns:
        The same list object ``x``, with every entry replaced in place.
    """
    ph = pw = patch_size
    pf = f_patch_size
    channels = self.out_channels
    bsz = len(x)
    assert len(size) == bsz
    for idx, (frames, height, width) in enumerate(size):
        f_tok, h_tok, w_tok = frames // pf, height // ph, width // pw
        n_tokens = f_tok * h_tok * w_tok
        # Drop trailing padding tokens, then split each token into its patch.
        patches = x[idx][:n_tokens].view(f_tok, h_tok, w_tok, pf, ph, pw, channels)
        # "f h w pf ph pw c -> c (f pf) (h ph) (w pw)"
        x[idx] = patches.permute(6, 0, 3, 1, 4, 2, 5).reshape(channels, frames, height, width)
    return x
@staticmethod
def create_coordinate_grid(size, start=None, device=None):
if start is None:
start = (0 for _ in size)
axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
grids = torch.meshgrid(axes, indexing="ij")
return torch.stack(grids, dim=-1)
def patchify_and_embed(
    self,
    all_image: List[torch.Tensor],
    all_cap_feats: List[torch.Tensor],
    patch_size: int,
    f_patch_size: int,
):
    """Patchify images and pad image/caption sequences to a multiple of SEQ_MULTI_OF.

    Despite the name, no embedding layer is applied here; the method only
    reshapes, pads and builds position ids and padding masks.

    Args:
        all_image: One ``(C, F, H, W)`` tensor per sample.
        all_cap_feats: One caption feature tensor per sample (rows are tokens).
        patch_size: Spatial patch size (used for both height and width).
        f_patch_size: Frame patch size.

    Returns:
        A 7-tuple of per-sample lists: padded image tokens, padded caption
        features, original ``(F, H, W)`` sizes, image position ids, caption
        position ids, image padding masks, caption padding masks. Padding
        masks are True at padded positions.
    """
    ph = pw = patch_size
    pf = f_patch_size
    device = all_image[0].device

    def make_pad_mask(valid_len, pad_len):
        # False for real tokens, True for the padding appended after them.
        real = torch.zeros((valid_len,), dtype=torch.bool, device=device)
        padded = torch.ones((pad_len,), dtype=torch.bool, device=device)
        return torch.cat([real, padded], dim=0)

    images_out, image_sizes, image_pos_ids, image_pad_masks = [], [], [], []
    caps_out, cap_pos_ids, cap_pad_masks = [], [], []

    for image, cap_feat in zip(all_image, all_cap_feats):
        # ---- Caption ----
        cap_len = len(cap_feat)
        cap_pad = (-cap_len) % SEQ_MULTI_OF
        cap_total = cap_len + cap_pad

        # Caption positions run 1..cap_total on the first axis, 0 on the others.
        cap_grid = self.create_coordinate_grid(
            size=(cap_total, 1, 1),
            start=(1, 0, 0),
            device=device,
        )
        cap_pos_ids.append(cap_grid.flatten(0, 2))
        cap_pad_masks.append(make_pad_mask(cap_len, cap_pad))

        # Pad by repeating the last caption row.
        cap_filler = cap_feat[-1:].repeat(cap_pad, 1)
        caps_out.append(torch.cat([cap_feat, cap_filler], dim=0))

        # ---- Image ----
        channels, frames, height, width = image.size()
        image_sizes.append((frames, height, width))
        f_tok, h_tok, w_tok = frames // pf, height // ph, width // pw

        # "c f pf h ph w pw -> (f h w) (pf ph pw c)"
        patches = image.view(channels, f_tok, pf, h_tok, ph, w_tok, pw)
        patches = patches.permute(1, 3, 5, 2, 4, 6, 0).reshape(f_tok * h_tok * w_tok, pf * ph * pw * channels)

        n_tokens = len(patches)
        n_pad = (-n_tokens) % SEQ_MULTI_OF

        # Real image tokens start right after the (padded) caption on axis 0.
        real_ids = self.create_coordinate_grid(
            size=(f_tok, h_tok, w_tok),
            start=(cap_total + 1, 0, 0),
            device=device,
        ).flatten(0, 2)
        # Padding tokens all share position (0, 0, 0).
        origin_id = self.create_coordinate_grid(
            size=(1, 1, 1),
            start=(0, 0, 0),
            device=device,
        ).flatten(0, 2)
        pad_ids = origin_id.repeat(n_pad, 1)
        image_pos_ids.append(torch.cat([real_ids, pad_ids], dim=0))
        image_pad_masks.append(make_pad_mask(n_tokens, n_pad))

        # Pad by repeating the last image token.
        image_filler = patches[-1:].repeat(n_pad, 1)
        images_out.append(torch.cat([patches, image_filler], dim=0))

    return (
        images_out,
        caps_out,
        image_sizes,
        image_pos_ids,
        cap_pos_ids,
        image_pad_masks,
        cap_pad_masks,
    )
def forward(
    self,
    x: List[torch.Tensor],
    t,
    cap_feats: List[torch.Tensor],
    patch_size=2,
    f_patch_size=1,
    use_gradient_checkpointing=False,
    use_gradient_checkpointing_offload=False,
):
    """Run the full model on a batch of images and captions.

    Args:
        x: One ``(C, F, H, W)`` tensor per sample.
        t: Timesteps; multiplied by ``self.t_scale`` before being embedded.
        cap_feats: One caption feature tensor per sample.
        patch_size: Spatial patch size; must be in ``self.all_patch_size``.
        f_patch_size: Frame patch size; must be in ``self.all_f_patch_size``.
        use_gradient_checkpointing: Forwarded to ``gradient_checkpoint_forward``.
        use_gradient_checkpointing_offload: Forwarded to
            ``gradient_checkpoint_forward``.

    Returns:
        A tuple ``(outputs, {})`` where ``outputs`` is a list with one
        ``(out_channels, F, H, W)`` tensor per sample.
    """
    assert patch_size in self.all_patch_size
    assert f_patch_size in self.all_f_patch_size

    bsz = len(x)
    device = x[0].device
    patch_key = f"{patch_size}-{f_patch_size}"

    def valid_mask(lengths, max_len):
        # Boolean (bsz, max_len) mask, True on the first `length` slots of each row.
        mask = torch.zeros((bsz, max_len), dtype=torch.bool, device=device)
        for row, length in enumerate(lengths):
            mask[row, :length] = 1
        return mask

    def run_blocks(blocks, hidden, **block_kwargs):
        # Feed `hidden` through every block via the checkpointing helper.
        for block in blocks:
            hidden = gradient_checkpoint_forward(
                block,
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
                x=hidden,
                **block_kwargs,
            )
        return hidden

    # Timestep conditioning shared by all modulated blocks and the final layer.
    adaln_input = self.t_embedder(t * self.t_scale)

    (
        x,
        cap_feats,
        x_size,
        x_pos_ids,
        cap_pos_ids,
        x_inner_pad_mask,
        cap_inner_pad_mask,
    ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)

    # ---- Image tokens: embed & refine ----
    x_item_seqlens = [len(item) for item in x]
    assert all(n % SEQ_MULTI_OF == 0 for n in x_item_seqlens)
    x_max_item_seqlen = max(x_item_seqlens)

    x_flat = self.all_x_embedder[patch_key](torch.cat(x, dim=0))
    # Overwrite padded rows with the learned image pad token.
    x_flat[torch.cat(x_inner_pad_mask)] = self.x_pad_token.to(dtype=x_flat.dtype, device=x_flat.device)
    x_rope = self.rope_embedder(torch.cat(x_pos_ids, dim=0))

    x = pad_sequence(list(x_flat.split(x_item_seqlens, dim=0)), batch_first=True, padding_value=0.0)
    x_freqs_cis = pad_sequence(list(x_rope.split(x_item_seqlens, dim=0)), batch_first=True, padding_value=0.0)
    x_attn_mask = valid_mask(x_item_seqlens, x_max_item_seqlen)

    x = run_blocks(
        self.noise_refiner,
        x,
        attn_mask=x_attn_mask,
        freqs_cis=x_freqs_cis,
        adaln_input=adaln_input,
    )

    # ---- Caption tokens: embed & refine ----
    cap_item_seqlens = [len(item) for item in cap_feats]
    assert all(n % SEQ_MULTI_OF == 0 for n in cap_item_seqlens)
    cap_max_item_seqlen = max(cap_item_seqlens)

    cap_flat = self.cap_embedder(torch.cat(cap_feats, dim=0))
    # The pad token is cast using the dtype/device of the refined image stream.
    cap_flat[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token.to(dtype=x.dtype, device=x.device)
    cap_rope = self.rope_embedder(torch.cat(cap_pos_ids, dim=0))

    cap_feats = pad_sequence(list(cap_flat.split(cap_item_seqlens, dim=0)), batch_first=True, padding_value=0.0)
    cap_freqs_cis = pad_sequence(list(cap_rope.split(cap_item_seqlens, dim=0)), batch_first=True, padding_value=0.0)
    cap_attn_mask = valid_mask(cap_item_seqlens, cap_max_item_seqlen)

    # The context refiner is called without adaln_input.
    cap_feats = run_blocks(
        self.context_refiner,
        cap_feats,
        attn_mask=cap_attn_mask,
        freqs_cis=cap_freqs_cis,
    )

    # ---- Unified sequence: image tokens first, then caption tokens ----
    unified = [
        torch.cat([x[i][: x_item_seqlens[i]], cap_feats[i][: cap_item_seqlens[i]]])
        for i in range(bsz)
    ]
    unified_freqs_cis = [
        torch.cat([x_freqs_cis[i][: x_item_seqlens[i]], cap_freqs_cis[i][: cap_item_seqlens[i]]])
        for i in range(bsz)
    ]
    unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
    assert unified_item_seqlens == [len(item) for item in unified]
    unified_max_item_seqlen = max(unified_item_seqlens)

    unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
    unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
    unified_attn_mask = valid_mask(unified_item_seqlens, unified_max_item_seqlen)

    unified = run_blocks(
        self.layers,
        unified,
        attn_mask=unified_attn_mask,
        freqs_cis=unified_freqs_cis,
        adaln_input=adaln_input,
    )

    # Project to patch space and fold the leading image tokens back into images.
    unified = self.all_final_layer[patch_key](unified, adaln_input)
    x = self.unpatchify(list(unified.unbind(dim=0)), x_size, patch_size, f_patch_size)
    return x, {}