mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
@@ -429,6 +429,7 @@ flux_series = [
|
||||
"extra_kwargs": {"disable_guidance_embedder": True},
|
||||
},
|
||||
]
|
||||
|
||||
flux2_series = [
|
||||
{
|
||||
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors")
|
||||
@@ -451,4 +452,35 @@ flux2_series = [
|
||||
},
|
||||
]
|
||||
|
||||
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series
|
||||
z_image_series = [
|
||||
{
|
||||
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors")
|
||||
"model_hash": "fc3a8a1247fe185ce116ccbe0e426c28",
|
||||
"model_name": "z_image_dit",
|
||||
"model_class": "diffsynth.models.z_image_dit.ZImageDiT",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors")
|
||||
"model_hash": "0f050f62a88876fea6eae0a18dac5a2e",
|
||||
"model_name": "z_image_text_encoder",
|
||||
"model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/vae/diffusion_pytorch_model.safetensors")
|
||||
"model_hash": "1aafa3cc91716fb6b300cc1cd51b85a3",
|
||||
"model_name": "flux_vae_encoder",
|
||||
"model_class": "diffsynth.models.flux_vae.FluxVAEEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverterDiffusers",
|
||||
"extra_kwargs": {"use_conv_attention": False},
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/vae/diffusion_pytorch_model.safetensors")
|
||||
"model_hash": "1aafa3cc91716fb6b300cc1cd51b85a3",
|
||||
"model_name": "flux_vae_decoder",
|
||||
"model_class": "diffsynth.models.flux_vae.FluxVAEDecoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverterDiffusers",
|
||||
"extra_kwargs": {"use_conv_attention": False},
|
||||
},
|
||||
]
|
||||
|
||||
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series
|
||||
|
||||
@@ -150,25 +150,75 @@ class ConvAttention(torch.nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Attention(torch.nn.Module):
|
||||
|
||||
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
|
||||
super().__init__()
|
||||
dim_inner = head_dim * num_heads
|
||||
kv_dim = kv_dim if kv_dim is not None else q_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
|
||||
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
|
||||
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
batch_size = encoder_hidden_states.shape[0]
|
||||
|
||||
q = self.to_q(hidden_states)
|
||||
k = self.to_k(encoder_hidden_states)
|
||||
v = self.to_v(encoder_hidden_states)
|
||||
|
||||
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
|
||||
hidden_states = self.to_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class VAEAttentionBlock(torch.nn.Module):
|
||||
|
||||
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
|
||||
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5, use_conv_attention=True):
|
||||
super().__init__()
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
|
||||
|
||||
self.transformer_blocks = torch.nn.ModuleList([
|
||||
ConvAttention(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
bias_q=True,
|
||||
bias_kv=True,
|
||||
bias_out=True
|
||||
)
|
||||
for d in range(num_layers)
|
||||
])
|
||||
if use_conv_attention:
|
||||
self.transformer_blocks = torch.nn.ModuleList([
|
||||
ConvAttention(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
bias_q=True,
|
||||
bias_kv=True,
|
||||
bias_out=True
|
||||
)
|
||||
for d in range(num_layers)
|
||||
])
|
||||
else:
|
||||
self.transformer_blocks = torch.nn.ModuleList([
|
||||
Attention(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
bias_q=True,
|
||||
bias_kv=True,
|
||||
bias_out=True
|
||||
)
|
||||
for d in range(num_layers)
|
||||
])
|
||||
|
||||
def forward(self, hidden_states, time_emb, text_emb, res_stack):
|
||||
batch, _, height, width = hidden_states.shape
|
||||
@@ -244,7 +294,7 @@ class DownSampler(torch.nn.Module):
|
||||
|
||||
|
||||
class FluxVAEDecoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self, use_conv_attention=True):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.3611
|
||||
self.shift_factor = 0.1159
|
||||
@@ -253,7 +303,7 @@ class FluxVAEDecoder(torch.nn.Module):
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
# UNetMidBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
|
||||
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, use_conv_attention=use_conv_attention),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
# UpDecoderBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
@@ -316,7 +366,7 @@ class FluxVAEDecoder(torch.nn.Module):
|
||||
|
||||
|
||||
class FluxVAEEncoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self, use_conv_attention=True):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.3611
|
||||
self.shift_factor = 0.1159
|
||||
@@ -340,7 +390,7 @@ class FluxVAEEncoder(torch.nn.Module):
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
# UNetMidBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
|
||||
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, use_conv_attention=use_conv_attention),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
])
|
||||
|
||||
|
||||
621
diffsynth/models/z_image_dit.py
Normal file
621
diffsynth/models/z_image_dit.py
Normal file
@@ -0,0 +1,621 @@
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from torch.nn import RMSNorm
|
||||
from ..core.attention import attention_forward
|
||||
from ..core.gradient import gradient_checkpoint_forward
|
||||
|
||||
|
||||
ADALN_EMBED_DIM = 256
|
||||
SEQ_MULTI_OF = 32
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
if mid_size is None:
|
||||
mid_size = out_size
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(
|
||||
frequency_embedding_size,
|
||||
mid_size,
|
||||
bias=True,
|
||||
),
|
||||
nn.SiLU(),
|
||||
nn.Linear(
|
||||
mid_size,
|
||||
out_size,
|
||||
bias=True,
|
||||
),
|
||||
)
|
||||
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
with torch.amp.autocast("cuda", enabled=False):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
|
||||
)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
|
||||
return t_emb
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim: int, hidden_dim: int):
|
||||
super().__init__()
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def _forward_silu_gating(self, x1, x3):
|
||||
return F.silu(x1) * x3
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
|
||||
|
||||
|
||||
class Attention(torch.nn.Module):
|
||||
|
||||
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
|
||||
super().__init__()
|
||||
dim_inner = head_dim * num_heads
|
||||
kv_dim = kv_dim if kv_dim is not None else q_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
|
||||
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
|
||||
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||
self.to_out = torch.nn.ModuleList([torch.nn.Linear(dim_inner, q_dim, bias=bias_out)])
|
||||
|
||||
self.norm_q = RMSNorm(head_dim, eps=1e-5)
|
||||
self.norm_k = RMSNorm(head_dim, eps=1e-5)
|
||||
|
||||
def forward(self, hidden_states, freqs_cis):
|
||||
query = self.to_q(hidden_states)
|
||||
key = self.to_k(hidden_states)
|
||||
value = self.to_v(hidden_states)
|
||||
|
||||
query = query.unflatten(-1, (self.num_heads, -1))
|
||||
key = key.unflatten(-1, (self.num_heads, -1))
|
||||
value = value.unflatten(-1, (self.num_heads, -1))
|
||||
|
||||
# Apply Norms
|
||||
if self.norm_q is not None:
|
||||
query = self.norm_q(query)
|
||||
if self.norm_k is not None:
|
||||
key = self.norm_k(key)
|
||||
|
||||
# Apply RoPE
|
||||
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
with torch.amp.autocast("cuda", enabled=False):
|
||||
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
|
||||
freqs_cis = freqs_cis.unsqueeze(2)
|
||||
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
|
||||
return x_out.type_as(x_in) # todo
|
||||
|
||||
if freqs_cis is not None:
|
||||
query = apply_rotary_emb(query, freqs_cis)
|
||||
key = apply_rotary_emb(key, freqs_cis)
|
||||
|
||||
# Cast to correct dtype
|
||||
dtype = query.dtype
|
||||
query, key = query.to(dtype), key.to(dtype)
|
||||
|
||||
# Compute joint attention
|
||||
hidden_states = attention_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
|
||||
)
|
||||
|
||||
# Reshape back
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
|
||||
output = self.to_out[0](hidden_states)
|
||||
if len(self.to_out) > 1: # dropout
|
||||
output = self.to_out[1](output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class ZImageTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
layer_id: int,
|
||||
dim: int,
|
||||
n_heads: int,
|
||||
n_kv_heads: int,
|
||||
norm_eps: float,
|
||||
qk_norm: bool,
|
||||
modulation=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.head_dim = dim // n_heads
|
||||
|
||||
# Refactored to use diffusers Attention with custom processor
|
||||
# Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm
|
||||
self.attention = Attention(
|
||||
q_dim=dim,
|
||||
num_heads=n_heads,
|
||||
head_dim=dim // n_heads,
|
||||
)
|
||||
|
||||
self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
|
||||
self.layer_id = layer_id
|
||||
|
||||
self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
|
||||
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
|
||||
|
||||
self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
|
||||
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
|
||||
|
||||
self.modulation = modulation
|
||||
if modulation:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attn_mask: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
adaln_input: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if self.modulation:
|
||||
assert adaln_input is not None
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
|
||||
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
|
||||
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
|
||||
|
||||
# Attention block
|
||||
attn_out = self.attention(
|
||||
self.attention_norm1(x) * scale_msa,
|
||||
freqs_cis=freqs_cis,
|
||||
)
|
||||
x = x + gate_msa * self.attention_norm2(attn_out)
|
||||
|
||||
# FFN block
|
||||
x = x + gate_mlp * self.ffn_norm2(
|
||||
self.feed_forward(
|
||||
self.ffn_norm1(x) * scale_mlp,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Attention block
|
||||
attn_out = self.attention(
|
||||
self.attention_norm1(x),
|
||||
freqs_cis=freqs_cis,
|
||||
)
|
||||
x = x + self.attention_norm2(attn_out)
|
||||
|
||||
# FFN block
|
||||
x = x + self.ffn_norm2(
|
||||
self.feed_forward(
|
||||
self.ffn_norm1(x),
|
||||
)
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
|
||||
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True),
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
scale = 1.0 + self.adaLN_modulation(c)
|
||||
x = self.norm_final(x) * scale.unsqueeze(1)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class RopeEmbedder:
|
||||
def __init__(
|
||||
self,
|
||||
theta: float = 256.0,
|
||||
axes_dims: List[int] = (16, 56, 56),
|
||||
axes_lens: List[int] = (64, 128, 128),
|
||||
):
|
||||
self.theta = theta
|
||||
self.axes_dims = axes_dims
|
||||
self.axes_lens = axes_lens
|
||||
assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length"
|
||||
self.freqs_cis = None
|
||||
|
||||
@staticmethod
|
||||
def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0):
|
||||
with torch.device("cpu"):
|
||||
freqs_cis = []
|
||||
for i, (d, e) in enumerate(zip(dim, end)):
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
|
||||
timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
|
||||
freqs = torch.outer(timestep, freqs).float()
|
||||
freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64
|
||||
freqs_cis.append(freqs_cis_i)
|
||||
|
||||
return freqs_cis
|
||||
|
||||
def __call__(self, ids: torch.Tensor):
|
||||
assert ids.ndim == 2
|
||||
assert ids.shape[-1] == len(self.axes_dims)
|
||||
device = ids.device
|
||||
|
||||
if self.freqs_cis is None:
|
||||
self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
|
||||
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
|
||||
|
||||
result = []
|
||||
for i in range(len(self.axes_dims)):
|
||||
index = ids[:, i]
|
||||
result.append(self.freqs_cis[i][index])
|
||||
return torch.cat(result, dim=-1)
|
||||
|
||||
|
||||
class ZImageDiT(nn.Module):
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["ZImageTransformerBlock"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
all_patch_size=(2,),
|
||||
all_f_patch_size=(1,),
|
||||
in_channels=16,
|
||||
dim=3840,
|
||||
n_layers=30,
|
||||
n_refiner_layers=2,
|
||||
n_heads=30,
|
||||
n_kv_heads=30,
|
||||
norm_eps=1e-5,
|
||||
qk_norm=True,
|
||||
cap_feat_dim=2560,
|
||||
rope_theta=256.0,
|
||||
t_scale=1000.0,
|
||||
axes_dims=[32, 48, 48],
|
||||
axes_lens=[1024, 512, 512],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.all_patch_size = all_patch_size
|
||||
self.all_f_patch_size = all_f_patch_size
|
||||
self.dim = dim
|
||||
self.n_heads = n_heads
|
||||
|
||||
self.rope_theta = rope_theta
|
||||
self.t_scale = t_scale
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
assert len(all_patch_size) == len(all_f_patch_size)
|
||||
|
||||
all_x_embedder = {}
|
||||
all_final_layer = {}
|
||||
for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)):
|
||||
x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True)
|
||||
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
|
||||
|
||||
final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels)
|
||||
all_final_layer[f"{patch_size}-{f_patch_size}"] = final_layer
|
||||
|
||||
self.all_x_embedder = nn.ModuleDict(all_x_embedder)
|
||||
self.all_final_layer = nn.ModuleDict(all_final_layer)
|
||||
self.noise_refiner = nn.ModuleList(
|
||||
[
|
||||
ZImageTransformerBlock(
|
||||
1000 + layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=True,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
self.context_refiner = nn.ModuleList(
|
||||
[
|
||||
ZImageTransformerBlock(
|
||||
layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=False,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)
|
||||
self.cap_embedder = nn.Sequential(
|
||||
RMSNorm(cap_feat_dim, eps=norm_eps),
|
||||
nn.Linear(cap_feat_dim, dim, bias=True),
|
||||
)
|
||||
|
||||
self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
|
||||
self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm)
|
||||
for layer_id in range(n_layers)
|
||||
]
|
||||
)
|
||||
head_dim = dim // n_heads
|
||||
assert head_dim == sum(axes_dims)
|
||||
self.axes_dims = axes_dims
|
||||
self.axes_lens = axes_lens
|
||||
|
||||
self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)
|
||||
|
||||
def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]:
|
||||
pH = pW = patch_size
|
||||
pF = f_patch_size
|
||||
bsz = len(x)
|
||||
assert len(size) == bsz
|
||||
for i in range(bsz):
|
||||
F, H, W = size[i]
|
||||
ori_len = (F // pF) * (H // pH) * (W // pW)
|
||||
# "f h w pf ph pw c -> c (f pf) (h ph) (w pw)"
|
||||
x[i] = (
|
||||
x[i][:ori_len]
|
||||
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
|
||||
.permute(6, 0, 3, 1, 4, 2, 5)
|
||||
.reshape(self.out_channels, F, H, W)
|
||||
)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def create_coordinate_grid(size, start=None, device=None):
|
||||
if start is None:
|
||||
start = (0 for _ in size)
|
||||
|
||||
axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
|
||||
grids = torch.meshgrid(axes, indexing="ij")
|
||||
return torch.stack(grids, dim=-1)
|
||||
|
||||
def patchify_and_embed(
|
||||
self,
|
||||
all_image: List[torch.Tensor],
|
||||
all_cap_feats: List[torch.Tensor],
|
||||
patch_size: int,
|
||||
f_patch_size: int,
|
||||
):
|
||||
pH = pW = patch_size
|
||||
pF = f_patch_size
|
||||
device = all_image[0].device
|
||||
|
||||
all_image_out = []
|
||||
all_image_size = []
|
||||
all_image_pos_ids = []
|
||||
all_image_pad_mask = []
|
||||
all_cap_pos_ids = []
|
||||
all_cap_pad_mask = []
|
||||
all_cap_feats_out = []
|
||||
|
||||
for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)):
|
||||
### Process Caption
|
||||
cap_ori_len = len(cap_feat)
|
||||
cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF
|
||||
# padded position ids
|
||||
cap_padded_pos_ids = self.create_coordinate_grid(
|
||||
size=(cap_ori_len + cap_padding_len, 1, 1),
|
||||
start=(1, 0, 0),
|
||||
device=device,
|
||||
).flatten(0, 2)
|
||||
all_cap_pos_ids.append(cap_padded_pos_ids)
|
||||
# pad mask
|
||||
all_cap_pad_mask.append(
|
||||
torch.cat(
|
||||
[
|
||||
torch.zeros((cap_ori_len,), dtype=torch.bool, device=device),
|
||||
torch.ones((cap_padding_len,), dtype=torch.bool, device=device),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
)
|
||||
# padded feature
|
||||
cap_padded_feat = torch.cat(
|
||||
[cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)],
|
||||
dim=0,
|
||||
)
|
||||
all_cap_feats_out.append(cap_padded_feat)
|
||||
|
||||
### Process Image
|
||||
C, F, H, W = image.size()
|
||||
all_image_size.append((F, H, W))
|
||||
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
|
||||
|
||||
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
|
||||
# "c f pf h ph w pw -> (f h w) (pf ph pw c)"
|
||||
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
|
||||
|
||||
image_ori_len = len(image)
|
||||
image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
|
||||
|
||||
image_ori_pos_ids = self.create_coordinate_grid(
|
||||
size=(F_tokens, H_tokens, W_tokens),
|
||||
start=(cap_ori_len + cap_padding_len + 1, 0, 0),
|
||||
device=device,
|
||||
).flatten(0, 2)
|
||||
image_padding_pos_ids = (
|
||||
self.create_coordinate_grid(
|
||||
size=(1, 1, 1),
|
||||
start=(0, 0, 0),
|
||||
device=device,
|
||||
)
|
||||
.flatten(0, 2)
|
||||
.repeat(image_padding_len, 1)
|
||||
)
|
||||
image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)
|
||||
all_image_pos_ids.append(image_padded_pos_ids)
|
||||
# pad mask
|
||||
all_image_pad_mask.append(
|
||||
torch.cat(
|
||||
[
|
||||
torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
|
||||
torch.ones((image_padding_len,), dtype=torch.bool, device=device),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
)
|
||||
# padded feature
|
||||
image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
|
||||
all_image_out.append(image_padded_feat)
|
||||
|
||||
return (
|
||||
all_image_out,
|
||||
all_cap_feats_out,
|
||||
all_image_size,
|
||||
all_image_pos_ids,
|
||||
all_cap_pos_ids,
|
||||
all_image_pad_mask,
|
||||
all_cap_pad_mask,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: List[torch.Tensor],
|
||||
t,
|
||||
cap_feats: List[torch.Tensor],
|
||||
patch_size=2,
|
||||
f_patch_size=1,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
):
|
||||
assert patch_size in self.all_patch_size
|
||||
assert f_patch_size in self.all_f_patch_size
|
||||
|
||||
bsz = len(x)
|
||||
device = x[0].device
|
||||
t = t * self.t_scale
|
||||
t = self.t_embedder(t)
|
||||
|
||||
adaln_input = t
|
||||
|
||||
(
|
||||
x,
|
||||
cap_feats,
|
||||
x_size,
|
||||
x_pos_ids,
|
||||
cap_pos_ids,
|
||||
x_inner_pad_mask,
|
||||
cap_inner_pad_mask,
|
||||
) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
|
||||
|
||||
# x embed & refine
|
||||
x_item_seqlens = [len(_) for _ in x]
|
||||
assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
|
||||
x_max_item_seqlen = max(x_item_seqlens)
|
||||
|
||||
x = torch.cat(x, dim=0)
|
||||
x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
|
||||
x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
|
||||
x = list(x.split(x_item_seqlens, dim=0))
|
||||
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
|
||||
|
||||
x = pad_sequence(x, batch_first=True, padding_value=0.0)
|
||||
x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(x_item_seqlens):
|
||||
x_attn_mask[i, :seq_len] = 1
|
||||
|
||||
for layer in self.noise_refiner:
|
||||
x = gradient_checkpoint_forward(
|
||||
layer,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
x=x,
|
||||
attn_mask=x_attn_mask,
|
||||
freqs_cis=x_freqs_cis,
|
||||
adaln_input=adaln_input,
|
||||
)
|
||||
|
||||
# cap embed & refine
|
||||
cap_item_seqlens = [len(_) for _ in cap_feats]
|
||||
assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens)
|
||||
cap_max_item_seqlen = max(cap_item_seqlens)
|
||||
|
||||
cap_feats = torch.cat(cap_feats, dim=0)
|
||||
cap_feats = self.cap_embedder(cap_feats)
|
||||
cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
|
||||
cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
|
||||
cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
|
||||
|
||||
cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
|
||||
cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(cap_item_seqlens):
|
||||
cap_attn_mask[i, :seq_len] = 1
|
||||
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = gradient_checkpoint_forward(
|
||||
layer,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
x=cap_feats,
|
||||
attn_mask=cap_attn_mask,
|
||||
freqs_cis=cap_freqs_cis,
|
||||
)
|
||||
|
||||
# unified
|
||||
unified = []
|
||||
unified_freqs_cis = []
|
||||
for i in range(bsz):
|
||||
x_len = x_item_seqlens[i]
|
||||
cap_len = cap_item_seqlens[i]
|
||||
unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
|
||||
unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
|
||||
unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
|
||||
assert unified_item_seqlens == [len(_) for _ in unified]
|
||||
unified_max_item_seqlen = max(unified_item_seqlens)
|
||||
|
||||
unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
|
||||
unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(unified_item_seqlens):
|
||||
unified_attn_mask[i, :seq_len] = 1
|
||||
|
||||
for layer in self.layers:
|
||||
unified = gradient_checkpoint_forward(
|
||||
layer,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
x=unified,
|
||||
attn_mask=unified_attn_mask,
|
||||
freqs_cis=unified_freqs_cis,
|
||||
adaln_input=adaln_input,
|
||||
)
|
||||
|
||||
unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
|
||||
unified = list(unified.unbind(dim=0))
|
||||
x = self.unpatchify(unified, x_size, patch_size, f_patch_size)
|
||||
|
||||
return x, {}
|
||||
41
diffsynth/models/z_image_text_encoder.py
Normal file
41
diffsynth/models/z_image_text_encoder.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from transformers import Qwen3Model, Qwen3Config
|
||||
import torch
|
||||
|
||||
|
||||
class ZImageTextEncoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
config = Qwen3Config(**{
|
||||
"architectures": [
|
||||
"Qwen3ForCausalLM"
|
||||
],
|
||||
"attention_bias": False,
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 151643,
|
||||
"eos_token_id": 151645,
|
||||
"head_dim": 128,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 2560,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 9728,
|
||||
"max_position_embeddings": 40960,
|
||||
"max_window_layers": 36,
|
||||
"model_type": "qwen3",
|
||||
"num_attention_heads": 32,
|
||||
"num_hidden_layers": 36,
|
||||
"num_key_value_heads": 8,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_scaling": None,
|
||||
"rope_theta": 1000000,
|
||||
"sliding_window": None,
|
||||
"tie_word_embeddings": True,
|
||||
"torch_dtype": "bfloat16",
|
||||
"transformers_version": "4.51.0",
|
||||
"use_cache": True,
|
||||
"use_sliding_window": False,
|
||||
"vocab_size": 151936
|
||||
})
|
||||
self.model = Qwen3Model(config)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.model(*args, **kwargs)
|
||||
257
diffsynth/pipelines/z_image.py
Normal file
257
diffsynth/pipelines/z_image.py
Normal file
@@ -0,0 +1,257 @@
|
||||
import torch, math
|
||||
from PIL import Image
|
||||
from typing import Union
|
||||
from tqdm import tqdm
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
from typing import Union, List, Optional, Tuple
|
||||
|
||||
from ..diffusion import FlowMatchScheduler
|
||||
from ..core import ModelConfig, gradient_checkpoint_forward
|
||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
from ..models.z_image_text_encoder import ZImageTextEncoder
|
||||
from ..models.z_image_dit import ZImageDiT
|
||||
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
||||
|
||||
|
||||
class ZImagePipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||
super().__init__(
|
||||
device=device, torch_dtype=torch_dtype,
|
||||
height_division_factor=16, width_division_factor=16,
|
||||
)
|
||||
self.scheduler = FlowMatchScheduler()
|
||||
self.text_encoder: ZImageTextEncoder = None
|
||||
self.dit: ZImageDiT = None
|
||||
self.vae_encoder: FluxVAEEncoder = None
|
||||
self.vae_decoder: FluxVAEDecoder = None
|
||||
self.tokenizer: AutoTokenizer = None
|
||||
self.in_iteration_models = ("dit",)
|
||||
self.units = [
|
||||
ZImageUnit_ShapeChecker(),
|
||||
ZImageUnit_PromptEmbedder(),
|
||||
ZImageUnit_NoiseInitializer(),
|
||||
ZImageUnit_InputImageEmbedder(),
|
||||
]
|
||||
self.model_fn = model_fn_z_image
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
model_configs: list[ModelConfig] = [],
|
||||
tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
||||
vram_limit: float = None,
|
||||
):
|
||||
# Initialize pipeline
|
||||
pipe = ZImagePipeline(device=device, torch_dtype=torch_dtype)
|
||||
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
||||
|
||||
# Fetch models
|
||||
pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
|
||||
pipe.dit = model_pool.fetch_model("z_image_dit")
|
||||
pipe.vae_encoder = model_pool.fetch_model("flux_vae_encoder")
|
||||
pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder")
|
||||
if tokenizer_config is not None:
|
||||
tokenizer_config.download_if_necessary()
|
||||
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
|
||||
|
||||
# VRAM Management
|
||||
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||
return pipe
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
# Prompt
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
cfg_scale: float = 1.0,
|
||||
# Image
|
||||
input_image: Image.Image = None,
|
||||
denoising_strength: float = 1.0,
|
||||
# Shape
|
||||
height: int = 1024,
|
||||
width: int = 1024,
|
||||
# Randomness
|
||||
seed: int = None,
|
||||
rand_device: str = "cpu",
|
||||
# Steps
|
||||
num_inference_steps: int = 8,
|
||||
# Progress bar
|
||||
progress_bar_cmd = tqdm,
|
||||
):
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)
|
||||
|
||||
# Parameters
|
||||
inputs_posi = {
|
||||
"prompt": prompt,
|
||||
}
|
||||
inputs_nega = {
|
||||
"negative_prompt": negative_prompt,
|
||||
}
|
||||
inputs_shared = {
|
||||
"cfg_scale": cfg_scale,
|
||||
"input_image": input_image, "denoising_strength": denoising_strength,
|
||||
"height": height, "width": width,
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
noise_pred = self.cfg_guided_model_fn(
|
||||
self.model_fn, cfg_scale,
|
||||
inputs_shared, inputs_posi, inputs_nega,
|
||||
**models, timestep=timestep, progress_id=progress_id
|
||||
)
|
||||
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
image = self.vae_decoder(inputs_shared["latents"])
|
||||
image = self.vae_output_to_image(image)
|
||||
self.load_models_to_device([])
|
||||
|
||||
return image
|
||||
|
||||
|
||||
class ZImageUnit_ShapeChecker(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("height", "width"),
|
||||
output_params=("height", "width"),
|
||||
)
|
||||
|
||||
def process(self, pipe: ZImagePipeline, height, width):
|
||||
height, width = pipe.check_resize_height_width(height, width)
|
||||
return {"height": height, "width": width}
|
||||
|
||||
|
||||
class ZImageUnit_PromptEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params_posi={"prompt": "prompt"},
|
||||
input_params_nega={"prompt": "negative_prompt"},
|
||||
output_params=("prompt_embeds",),
|
||||
onload_model_names=("text_encoder",)
|
||||
)
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
pipe,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
max_sequence_length: int = 512,
|
||||
) -> List[torch.FloatTensor]:
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
|
||||
for i, prompt_item in enumerate(prompt):
|
||||
messages = [
|
||||
{"role": "user", "content": prompt_item},
|
||||
]
|
||||
prompt_item = pipe.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=True,
|
||||
)
|
||||
prompt[i] = prompt_item
|
||||
|
||||
text_inputs = pipe.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids.to(device)
|
||||
prompt_masks = text_inputs.attention_mask.to(device).bool()
|
||||
|
||||
prompt_embeds = pipe.text_encoder(
|
||||
input_ids=text_input_ids,
|
||||
attention_mask=prompt_masks,
|
||||
output_hidden_states=True,
|
||||
).hidden_states[-2]
|
||||
|
||||
embeddings_list = []
|
||||
|
||||
for i in range(len(prompt_embeds)):
|
||||
embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
|
||||
|
||||
return embeddings_list
|
||||
|
||||
def process(self, pipe: ZImagePipeline, prompt):
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device)
|
||||
return {"prompt_embeds": prompt_embeds}
|
||||
|
||||
|
||||
class ZImageUnit_NoiseInitializer(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("height", "width", "seed", "rand_device"),
|
||||
output_params=("noise",),
|
||||
)
|
||||
|
||||
def process(self, pipe: ZImagePipeline, height, width, seed, rand_device):
|
||||
noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
|
||||
return {"noise": noise}
|
||||
|
||||
|
||||
class ZImageUnit_InputImageEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_image", "noise"),
|
||||
output_params=("latents", "input_latents"),
|
||||
onload_model_names=("vae_encoder",)
|
||||
)
|
||||
|
||||
def process(self, pipe: ZImagePipeline, input_image, noise):
|
||||
if input_image is None:
|
||||
return {"latents": noise, "input_latents": None}
|
||||
pipe.load_models_to_device(['vae'])
|
||||
image = pipe.preprocess_image(input_image)
|
||||
input_latents = pipe.vae_encoder(image)
|
||||
if pipe.scheduler.training:
|
||||
return {"latents": noise, "input_latents": input_latents}
|
||||
else:
|
||||
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
|
||||
return {"latents": latents, "input_latents": input_latents}
|
||||
|
||||
|
||||
def model_fn_z_image(
|
||||
dit: ZImageDiT,
|
||||
latents=None,
|
||||
timestep=None,
|
||||
prompt_embeds=None,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
**kwargs,
|
||||
):
|
||||
latents = [rearrange(latents, "B C H W -> C B H W")]
|
||||
timestep = (1000 - timestep) / 1000
|
||||
model_output = dit(
|
||||
latents,
|
||||
timestep,
|
||||
prompt_embeds,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)[0][0]
|
||||
model_output = -model_output
|
||||
model_output = rearrange(model_output, "C B H W -> B C H W")
|
||||
return model_output
|
||||
@@ -262,3 +262,121 @@ def FluxVAEDecoderStateDictConverter(state_dict):
|
||||
param = state_dict[name]
|
||||
state_dict_[rename_dict[name]] = param
|
||||
return state_dict_
|
||||
|
||||
|
||||
def FluxVAEEncoderStateDictConverterDiffusers(state_dict):
|
||||
# architecture
|
||||
block_types = [
|
||||
'ResnetBlock', 'ResnetBlock', 'DownSampler',
|
||||
'ResnetBlock', 'ResnetBlock', 'DownSampler',
|
||||
'ResnetBlock', 'ResnetBlock', 'DownSampler',
|
||||
'ResnetBlock', 'ResnetBlock',
|
||||
'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock'
|
||||
]
|
||||
|
||||
# Rename each parameter
|
||||
local_rename_dict = {
|
||||
"quant_conv": "quant_conv",
|
||||
"encoder.conv_in": "conv_in",
|
||||
"encoder.mid_block.attentions.0.group_norm": "blocks.12.norm",
|
||||
"encoder.mid_block.attentions.0.to_q": "blocks.12.transformer_blocks.0.to_q",
|
||||
"encoder.mid_block.attentions.0.to_k": "blocks.12.transformer_blocks.0.to_k",
|
||||
"encoder.mid_block.attentions.0.to_v": "blocks.12.transformer_blocks.0.to_v",
|
||||
"encoder.mid_block.attentions.0.to_out.0": "blocks.12.transformer_blocks.0.to_out",
|
||||
"encoder.mid_block.resnets.0.norm1": "blocks.11.norm1",
|
||||
"encoder.mid_block.resnets.0.conv1": "blocks.11.conv1",
|
||||
"encoder.mid_block.resnets.0.norm2": "blocks.11.norm2",
|
||||
"encoder.mid_block.resnets.0.conv2": "blocks.11.conv2",
|
||||
"encoder.mid_block.resnets.1.norm1": "blocks.13.norm1",
|
||||
"encoder.mid_block.resnets.1.conv1": "blocks.13.conv1",
|
||||
"encoder.mid_block.resnets.1.norm2": "blocks.13.norm2",
|
||||
"encoder.mid_block.resnets.1.conv2": "blocks.13.conv2",
|
||||
"encoder.conv_norm_out": "conv_norm_out",
|
||||
"encoder.conv_out": "conv_out",
|
||||
}
|
||||
name_list = sorted([name for name in state_dict])
|
||||
rename_dict = {}
|
||||
block_id = {"ResnetBlock": -1, "DownSampler": -1, "UpSampler": -1}
|
||||
last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""}
|
||||
for name in name_list:
|
||||
names = name.split(".")
|
||||
name_prefix = ".".join(names[:-1])
|
||||
if name_prefix in local_rename_dict:
|
||||
rename_dict[name] = local_rename_dict[name_prefix] + "." + names[-1]
|
||||
elif name.startswith("encoder.down_blocks"):
|
||||
block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]]
|
||||
block_type_with_id = ".".join(names[:5])
|
||||
if block_type_with_id != last_block_type_with_id[block_type]:
|
||||
block_id[block_type] += 1
|
||||
last_block_type_with_id[block_type] = block_type_with_id
|
||||
while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
|
||||
block_id[block_type] += 1
|
||||
block_type_with_id = ".".join(names[:5])
|
||||
names = ["blocks", str(block_id[block_type])] + names[5:]
|
||||
rename_dict[name] = ".".join(names)
|
||||
|
||||
# Convert state_dict
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
state_dict_[rename_dict[name]] = state_dict[name]
|
||||
return state_dict_
|
||||
|
||||
|
||||
def FluxVAEDecoderStateDictConverterDiffusers(state_dict):
|
||||
# architecture
|
||||
block_types = [
|
||||
'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock',
|
||||
'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
|
||||
'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
|
||||
'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
|
||||
'ResnetBlock', 'ResnetBlock', 'ResnetBlock'
|
||||
]
|
||||
|
||||
# Rename each parameter
|
||||
local_rename_dict = {
|
||||
"post_quant_conv": "post_quant_conv",
|
||||
"decoder.conv_in": "conv_in",
|
||||
"decoder.mid_block.attentions.0.group_norm": "blocks.1.norm",
|
||||
"decoder.mid_block.attentions.0.to_q": "blocks.1.transformer_blocks.0.to_q",
|
||||
"decoder.mid_block.attentions.0.to_k": "blocks.1.transformer_blocks.0.to_k",
|
||||
"decoder.mid_block.attentions.0.to_v": "blocks.1.transformer_blocks.0.to_v",
|
||||
"decoder.mid_block.attentions.0.to_out.0": "blocks.1.transformer_blocks.0.to_out",
|
||||
"decoder.mid_block.resnets.0.norm1": "blocks.0.norm1",
|
||||
"decoder.mid_block.resnets.0.conv1": "blocks.0.conv1",
|
||||
"decoder.mid_block.resnets.0.norm2": "blocks.0.norm2",
|
||||
"decoder.mid_block.resnets.0.conv2": "blocks.0.conv2",
|
||||
"decoder.mid_block.resnets.1.norm1": "blocks.2.norm1",
|
||||
"decoder.mid_block.resnets.1.conv1": "blocks.2.conv1",
|
||||
"decoder.mid_block.resnets.1.norm2": "blocks.2.norm2",
|
||||
"decoder.mid_block.resnets.1.conv2": "blocks.2.conv2",
|
||||
"decoder.conv_norm_out": "conv_norm_out",
|
||||
"decoder.conv_out": "conv_out",
|
||||
}
|
||||
name_list = sorted([name for name in state_dict])
|
||||
rename_dict = {}
|
||||
block_id = {"ResnetBlock": 2, "DownSampler": 2, "UpSampler": 2}
|
||||
last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""}
|
||||
for name in name_list:
|
||||
names = name.split(".")
|
||||
name_prefix = ".".join(names[:-1])
|
||||
if name_prefix in local_rename_dict:
|
||||
rename_dict[name] = local_rename_dict[name_prefix] + "." + names[-1]
|
||||
elif name.startswith("decoder.up_blocks"):
|
||||
block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]]
|
||||
block_type_with_id = ".".join(names[:5])
|
||||
if block_type_with_id != last_block_type_with_id[block_type]:
|
||||
block_id[block_type] += 1
|
||||
last_block_type_with_id[block_type] = block_type_with_id
|
||||
while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
|
||||
block_id[block_type] += 1
|
||||
block_type_with_id = ".".join(names[:5])
|
||||
names = ["blocks", str(block_id[block_type])] + names[5:]
|
||||
rename_dict[name] = ".".join(names)
|
||||
|
||||
# Convert state_dict
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
state_dict_[rename_dict[name]] = state_dict[name]
|
||||
return state_dict_
|
||||
17
examples/z_image/model_inference/Z-Image-Turbo.py
Normal file
17
examples/z_image/model_inference/Z-Image-Turbo.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
|
||||
pipe = ZImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"),
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
|
||||
image = pipe(prompt=prompt, seed=42, rand_device="cuda")
|
||||
image.save("image.jpg")
|
||||
15
examples/z_image/model_training/lora/Z-Image-Turbo.sh
Normal file
15
examples/z_image/model_training/lora/Z-Image-Turbo.sh
Normal file
@@ -0,0 +1,15 @@
|
||||
accelerate launch examples/z_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Z-Image-Turbo_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj" \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing \
|
||||
--dataset_num_workers 8
|
||||
143
examples/z_image/model_training/train.py
Normal file
143
examples/z_image/model_training/train.py
Normal file
@@ -0,0 +1,143 @@
|
||||
import torch, os, argparse, accelerate
|
||||
from diffsynth.core import UnifiedDataset
|
||||
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
|
||||
from diffsynth.diffusion import *
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
class ZImageTrainingModule(DiffusionTrainingModule):
|
||||
def __init__(
|
||||
self,
|
||||
model_paths=None, model_id_with_origin_paths=None,
|
||||
tokenizer_path=None,
|
||||
trainable_models=None,
|
||||
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
|
||||
preset_lora_path=None, preset_lora_model=None,
|
||||
use_gradient_checkpointing=True,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
extra_inputs=None,
|
||||
fp8_models=None,
|
||||
offload_models=None,
|
||||
device="cpu",
|
||||
task="sft",
|
||||
):
|
||||
super().__init__()
|
||||
# Load models
|
||||
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
|
||||
tokenizer_config = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
|
||||
self.pipe = ZImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
|
||||
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
|
||||
|
||||
# Training mode
|
||||
self.switch_pipe_to_training_mode(
|
||||
self.pipe, trainable_models,
|
||||
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
|
||||
preset_lora_path, preset_lora_model,
|
||||
task=task,
|
||||
)
|
||||
|
||||
# Other configs
|
||||
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
||||
self.fp8_models = fp8_models
|
||||
self.task = task
|
||||
self.task_to_loss = {
|
||||
"sft:data_process": lambda pipe, *args: args,
|
||||
"direct_distill:data_process": lambda pipe, *args: args,
|
||||
"sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
|
||||
"sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
|
||||
"direct_distill": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
|
||||
"direct_distill:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
|
||||
}
|
||||
|
||||
def get_pipeline_inputs(self, data):
|
||||
inputs_posi = {"prompt": data["prompt"]}
|
||||
inputs_nega = {"negative_prompt": ""}
|
||||
inputs_shared = {
|
||||
# Assume you are using this pipeline for inference,
|
||||
# please fill in the input parameters.
|
||||
"input_image": data["image"],
|
||||
"height": data["image"].size[1],
|
||||
"width": data["image"].size[0],
|
||||
# Please do not modify the following parameters
|
||||
# unless you clearly know what this will cause.
|
||||
"embedded_guidance": 1.0,
|
||||
"cfg_scale": 1,
|
||||
"rand_device": self.pipe.device,
|
||||
"use_gradient_checkpointing": self.use_gradient_checkpointing,
|
||||
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
|
||||
}
|
||||
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
def forward(self, data, inputs=None):
|
||||
if inputs is None: inputs = self.get_pipeline_inputs(data)
|
||||
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
|
||||
for unit in self.pipe.units:
|
||||
inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)
|
||||
loss = self.task_to_loss[self.task](self.pipe, *inputs)
|
||||
return loss
|
||||
|
||||
|
||||
def qwen_image_parser():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser = add_general_config(parser)
|
||||
parser = add_image_size_config(parser)
|
||||
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = qwen_image_parser()
|
||||
args = parser.parse_args()
|
||||
accelerator = accelerate.Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
|
||||
)
|
||||
dataset = UnifiedDataset(
|
||||
base_path=args.dataset_base_path,
|
||||
metadata_path=args.dataset_metadata_path,
|
||||
repeat=args.dataset_repeat,
|
||||
data_file_keys=args.data_file_keys.split(","),
|
||||
main_data_operator=UnifiedDataset.default_image_operator(
|
||||
base_path=args.dataset_base_path,
|
||||
max_pixels=args.max_pixels,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
height_division_factor=16,
|
||||
width_division_factor=16,
|
||||
)
|
||||
)
|
||||
model = ZImageTrainingModule(
|
||||
model_paths=args.model_paths,
|
||||
model_id_with_origin_paths=args.model_id_with_origin_paths,
|
||||
tokenizer_path=args.tokenizer_path,
|
||||
trainable_models=args.trainable_models,
|
||||
lora_base_model=args.lora_base_model,
|
||||
lora_target_modules=args.lora_target_modules,
|
||||
lora_rank=args.lora_rank,
|
||||
lora_checkpoint=args.lora_checkpoint,
|
||||
preset_lora_path=args.preset_lora_path,
|
||||
preset_lora_model=args.preset_lora_model,
|
||||
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||
extra_inputs=args.extra_inputs,
|
||||
fp8_models=args.fp8_models,
|
||||
offload_models=args.offload_models,
|
||||
task=args.task,
|
||||
device=accelerator.device,
|
||||
)
|
||||
model_logger = ModelLogger(
|
||||
args.output_path,
|
||||
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
|
||||
)
|
||||
launcher_map = {
|
||||
"sft:data_process": launch_data_process_task,
|
||||
"direct_distill:data_process": launch_data_process_task,
|
||||
"sft": launch_training_task,
|
||||
"sft:train": launch_training_task,
|
||||
"direct_distill": launch_training_task,
|
||||
"direct_distill:train": launch_training_task,
|
||||
}
|
||||
launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)
|
||||
@@ -0,0 +1,18 @@
|
||||
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
|
||||
pipe = ZImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"),
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "/models/train/Z-Image-Turbo_lora/epoch-4.safetensors")
|
||||
prompt = "a dog"
|
||||
image = pipe(prompt=prompt, seed=42, rand_device="cuda")
|
||||
image.save("image.jpg")
|
||||
Reference in New Issue
Block a user