mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 10:18:12 +00:00
@@ -429,6 +429,7 @@ flux_series = [
|
|||||||
"extra_kwargs": {"disable_guidance_embedder": True},
|
"extra_kwargs": {"disable_guidance_embedder": True},
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
flux2_series = [
|
flux2_series = [
|
||||||
{
|
{
|
||||||
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors")
|
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors")
|
||||||
@@ -451,4 +452,35 @@ flux2_series = [
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series
|
# Registry entries for the Z-Image model family. Each entry pairs a
# "model_hash" with the dotted path of the class to instantiate
# ("model_class") and the name the model is registered under ("model_name").
z_image_series = [
    {
        # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors")
        "model_hash": "fc3a8a1247fe185ce116ccbe0e426c28",
        "model_name": "z_image_dit",
        "model_class": "diffsynth.models.z_image_dit.ZImageDiT",
    },
    {
        # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors")
        "model_hash": "0f050f62a88876fea6eae0a18dac5a2e",
        "model_name": "z_image_text_encoder",
        "model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
    },
    {
        # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/vae/diffusion_pytorch_model.safetensors")
        # NOTE(review): the "vae/vae/" prefix in the example pattern looks
        # duplicated — confirm against the model repository layout.
        # This entry and the next one share the same model_hash: the same
        # checkpoint is mapped once to the encoder and once to the decoder.
        "model_hash": "1aafa3cc91716fb6b300cc1cd51b85a3",
        "model_name": "flux_vae_encoder",
        "model_class": "diffsynth.models.flux_vae.FluxVAEEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverterDiffusers",
        # Presumably forwarded to the model constructor, where
        # use_conv_attention=False selects the non-convolutional attention
        # variant — TODO confirm how extra_kwargs is consumed by the loader.
        "extra_kwargs": {"use_conv_attention": False},
    },
    {
        # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/vae/diffusion_pytorch_model.safetensors")
        "model_hash": "1aafa3cc91716fb6b300cc1cd51b85a3",
        "model_name": "flux_vae_decoder",
        "model_class": "diffsynth.models.flux_vae.FluxVAEDecoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverterDiffusers",
        "extra_kwargs": {"use_conv_attention": False},
    },
]

# Full registry: the concatenation of every per-family list.
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series
|
||||||
|
|||||||
@@ -150,25 +150,75 @@ class ConvAttention(torch.nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(torch.nn.Module):
    """Multi-head attention built from linear projections.

    Queries are projected from ``hidden_states``; keys and values are
    projected from ``encoder_hidden_states`` (or from ``hidden_states``
    itself when no encoder states are given).

    Args:
        q_dim: Feature size of the query input and of the output.
        num_heads: Number of attention heads.
        head_dim: Feature size of each head.
        kv_dim: Feature size of the key/value input; defaults to ``q_dim``.
        bias_q: Whether the query projection has a bias.
        bias_kv: Whether the key and value projections have a bias.
        bias_out: Whether the output projection has a bias.
    """

    def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
        super().__init__()
        if kv_dim is None:
            kv_dim = q_dim
        inner_dim = num_heads * head_dim
        self.num_heads = num_heads
        self.head_dim = head_dim

        self.to_q = torch.nn.Linear(q_dim, inner_dim, bias=bias_q)
        self.to_k = torch.nn.Linear(kv_dim, inner_dim, bias=bias_kv)
        self.to_v = torch.nn.Linear(kv_dim, inner_dim, bias=bias_kv)
        self.to_out = torch.nn.Linear(inner_dim, q_dim, bias=bias_out)

    def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
        """Attend from ``hidden_states`` to the context and project the result back to ``q_dim``."""
        context = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size = context.shape[0]
        heads_shape = (batch_size, -1, self.num_heads, self.head_dim)

        q = self.to_q(hidden_states)
        k = self.to_k(context)
        v = self.to_v(context)

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        q = q.view(*heads_shape).transpose(1, 2)
        k = k.view(*heads_shape).transpose(1, 2)
        v = v.view(*heads_shape).transpose(1, 2)

        attended = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        # Merge the heads back into a single feature axis.
        attended = attended.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        attended = attended.to(q.dtype)

        return self.to_out(attended)
|
||||||
|
|
||||||
|
|
||||||
class VAEAttentionBlock(torch.nn.Module):
|
class VAEAttentionBlock(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
|
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5, use_conv_attention=True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = num_attention_heads * attention_head_dim
|
inner_dim = num_attention_heads * attention_head_dim
|
||||||
|
|
||||||
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
|
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
|
||||||
|
|
||||||
self.transformer_blocks = torch.nn.ModuleList([
|
if use_conv_attention:
|
||||||
ConvAttention(
|
self.transformer_blocks = torch.nn.ModuleList([
|
||||||
inner_dim,
|
ConvAttention(
|
||||||
num_attention_heads,
|
inner_dim,
|
||||||
attention_head_dim,
|
num_attention_heads,
|
||||||
bias_q=True,
|
attention_head_dim,
|
||||||
bias_kv=True,
|
bias_q=True,
|
||||||
bias_out=True
|
bias_kv=True,
|
||||||
)
|
bias_out=True
|
||||||
for d in range(num_layers)
|
)
|
||||||
])
|
for d in range(num_layers)
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
self.transformer_blocks = torch.nn.ModuleList([
|
||||||
|
Attention(
|
||||||
|
inner_dim,
|
||||||
|
num_attention_heads,
|
||||||
|
attention_head_dim,
|
||||||
|
bias_q=True,
|
||||||
|
bias_kv=True,
|
||||||
|
bias_out=True
|
||||||
|
)
|
||||||
|
for d in range(num_layers)
|
||||||
|
])
|
||||||
|
|
||||||
def forward(self, hidden_states, time_emb, text_emb, res_stack):
|
def forward(self, hidden_states, time_emb, text_emb, res_stack):
|
||||||
batch, _, height, width = hidden_states.shape
|
batch, _, height, width = hidden_states.shape
|
||||||
@@ -244,7 +294,7 @@ class DownSampler(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FluxVAEDecoder(torch.nn.Module):
|
class FluxVAEDecoder(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self, use_conv_attention=True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.scaling_factor = 0.3611
|
self.scaling_factor = 0.3611
|
||||||
self.shift_factor = 0.1159
|
self.shift_factor = 0.1159
|
||||||
@@ -253,7 +303,7 @@ class FluxVAEDecoder(torch.nn.Module):
|
|||||||
self.blocks = torch.nn.ModuleList([
|
self.blocks = torch.nn.ModuleList([
|
||||||
# UNetMidBlock2D
|
# UNetMidBlock2D
|
||||||
ResnetBlock(512, 512, eps=1e-6),
|
ResnetBlock(512, 512, eps=1e-6),
|
||||||
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
|
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, use_conv_attention=use_conv_attention),
|
||||||
ResnetBlock(512, 512, eps=1e-6),
|
ResnetBlock(512, 512, eps=1e-6),
|
||||||
# UpDecoderBlock2D
|
# UpDecoderBlock2D
|
||||||
ResnetBlock(512, 512, eps=1e-6),
|
ResnetBlock(512, 512, eps=1e-6),
|
||||||
@@ -316,7 +366,7 @@ class FluxVAEDecoder(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FluxVAEEncoder(torch.nn.Module):
|
class FluxVAEEncoder(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self, use_conv_attention=True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.scaling_factor = 0.3611
|
self.scaling_factor = 0.3611
|
||||||
self.shift_factor = 0.1159
|
self.shift_factor = 0.1159
|
||||||
@@ -340,7 +390,7 @@ class FluxVAEEncoder(torch.nn.Module):
|
|||||||
ResnetBlock(512, 512, eps=1e-6),
|
ResnetBlock(512, 512, eps=1e-6),
|
||||||
# UNetMidBlock2D
|
# UNetMidBlock2D
|
||||||
ResnetBlock(512, 512, eps=1e-6),
|
ResnetBlock(512, 512, eps=1e-6),
|
||||||
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
|
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, use_conv_attention=use_conv_attention),
|
||||||
ResnetBlock(512, 512, eps=1e-6),
|
ResnetBlock(512, 512, eps=1e-6),
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|||||||
621
diffsynth/models/z_image_dit.py
Normal file
621
diffsynth/models/z_image_dit.py
Normal file
@@ -0,0 +1,621 @@
|
|||||||
|
import math
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
|
from torch.nn import RMSNorm
|
||||||
|
from ..core.attention import attention_forward
|
||||||
|
from ..core.gradient import gradient_checkpoint_forward
|
||||||
|
|
||||||
|
|
||||||
|
ADALN_EMBED_DIM = 256
|
||||||
|
SEQ_MULTI_OF = 32
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEmbedder(nn.Module):
    """Map scalar timesteps to vectors: sinusoidal features followed by a two-layer MLP.

    Args:
        out_size: Size of the returned embedding.
        mid_size: Hidden size of the MLP; defaults to ``out_size``.
        frequency_embedding_size: Size of the sinusoidal feature vector fed to the MLP.
    """

    def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):
        super().__init__()
        hidden_size = out_size if mid_size is None else mid_size
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, out_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Return ``[cos(t * f), sin(t * f)]`` features of width ``dim`` in float32.

        ``t`` is indexed as ``t[:, None]``, i.e. it is a 1-D tensor with one
        timestep per row of the result. When ``dim`` is odd a zero column is
        appended so the output width is exactly ``dim``.
        """
        # Autocast is disabled so the frequencies and angles stay in float32.
        with torch.amp.autocast("cuda", enabled=False):
            half = dim // 2
            scaled = -math.log(max_period) * torch.arange(
                start=0, end=half, dtype=torch.float32, device=t.device
            )
            freqs = torch.exp(scaled / half)
            angles = t[:, None].float() * freqs[None]
            parts = [torch.cos(angles), torch.sin(angles)]
            if dim % 2:
                parts.append(torch.zeros_like(parts[0][:, :1]))
            embedding = torch.cat(parts, dim=-1)
        return embedding

    def forward(self, t):
        """Embed timesteps ``t`` and run them through the MLP in the MLP's weight dtype."""
        features = self.timestep_embedding(t, self.frequency_embedding_size)
        weight_dtype = self.mlp[0].weight.dtype
        return self.mlp(features.to(weight_dtype))
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
    """Bias-free gated MLP: ``w2(silu(w1(x)) * w3(x))``.

    Args:
        dim: Input and output feature size.
        hidden_dim: Feature size of the gated hidden representation.
    """

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def _forward_silu_gating(self, x1, x3):
        """Gate ``x3`` with the SiLU activation of ``x1``."""
        gate = F.silu(x1)
        return gate * x3

    def forward(self, x):
        """Apply the gated MLP to ``x``."""
        hidden = self._forward_silu_gating(self.w1(x), self.w3(x))
        return self.w2(hidden)
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
|
||||||
|
super().__init__()
|
||||||
|
dim_inner = head_dim * num_heads
|
||||||
|
kv_dim = kv_dim if kv_dim is not None else q_dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = head_dim
|
||||||
|
|
||||||
|
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
|
||||||
|
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||||
|
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||||
|
self.to_out = torch.nn.ModuleList([torch.nn.Linear(dim_inner, q_dim, bias=bias_out)])
|
||||||
|
|
||||||
|
self.norm_q = RMSNorm(head_dim, eps=1e-5)
|
||||||
|
self.norm_k = RMSNorm(head_dim, eps=1e-5)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, freqs_cis):
|
||||||
|
query = self.to_q(hidden_states)
|
||||||
|
key = self.to_k(hidden_states)
|
||||||
|
value = self.to_v(hidden_states)
|
||||||
|
|
||||||
|
query = query.unflatten(-1, (self.num_heads, -1))
|
||||||
|
key = key.unflatten(-1, (self.num_heads, -1))
|
||||||
|
value = value.unflatten(-1, (self.num_heads, -1))
|
||||||
|
|
||||||
|
# Apply Norms
|
||||||
|
if self.norm_q is not None:
|
||||||
|
query = self.norm_q(query)
|
||||||
|
if self.norm_k is not None:
|
||||||
|
key = self.norm_k(key)
|
||||||
|
|
||||||
|
# Apply RoPE
|
||||||
|
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||||
|
with torch.amp.autocast("cuda", enabled=False):
|
||||||
|
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
|
||||||
|
freqs_cis = freqs_cis.unsqueeze(2)
|
||||||
|
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
|
||||||
|
return x_out.type_as(x_in) # todo
|
||||||
|
|
||||||
|
if freqs_cis is not None:
|
||||||
|
query = apply_rotary_emb(query, freqs_cis)
|
||||||
|
key = apply_rotary_emb(key, freqs_cis)
|
||||||
|
|
||||||
|
# Cast to correct dtype
|
||||||
|
dtype = query.dtype
|
||||||
|
query, key = query.to(dtype), key.to(dtype)
|
||||||
|
|
||||||
|
# Compute joint attention
|
||||||
|
hidden_states = attention_forward(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reshape back
|
||||||
|
hidden_states = hidden_states.flatten(2, 3)
|
||||||
|
hidden_states = hidden_states.to(dtype)
|
||||||
|
|
||||||
|
output = self.to_out[0](hidden_states)
|
||||||
|
if len(self.to_out) > 1: # dropout
|
||||||
|
output = self.to_out[1](output)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageTransformerBlock(nn.Module):
    """Transformer block: self-attention and a gated MLP, each wrapped in pre- and post-RMSNorm.

    With ``modulation=True`` a linear layer turns ``adaln_input`` into four
    chunks (two scales, two gates) that scale the normalised inputs and gate
    the residual branches.

    Args:
        layer_id: Identifier stored on the block as ``self.layer_id``.
        dim: Feature size of the token stream.
        n_heads: Number of attention heads; the head size is ``dim // n_heads``.
        n_kv_heads: Accepted but not used by this implementation.
        norm_eps: Epsilon of the four RMSNorm layers.
        qk_norm: Accepted but not used by this implementation.
        modulation: Whether to build and apply the adaLN modulation.
    """

    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        norm_eps: float,
        qk_norm: bool,
        modulation=True,
    ):
        super().__init__()
        head_dim = dim // n_heads
        self.dim = dim
        self.head_dim = head_dim

        self.attention = Attention(
            q_dim=dim,
            num_heads=n_heads,
            head_dim=head_dim,
        )

        self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
        self.layer_id = layer_id

        # "*1" norms run before a sub-layer, "*2" norms on its output.
        self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)

        self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)

        self.modulation = modulation
        if modulation:
            self.adaLN_modulation = nn.Sequential(
                nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True),
            )

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        adaln_input: Optional[torch.Tensor] = None,
    ):
        """Apply the attention and feed-forward residual branches to ``x``.

        NOTE(review): ``attn_mask`` is part of the signature but is never
        passed to the attention layer, so it has no effect here.
        ``adaln_input`` is required (asserted) when the block was built with
        ``modulation=True`` and ignored otherwise.
        """
        if not self.modulation:
            attn_out = self.attention(self.attention_norm1(x), freqs_cis=freqs_cis)
            x = x + self.attention_norm2(attn_out)

            ffn_out = self.feed_forward(self.ffn_norm1(x))
            return x + self.ffn_norm2(ffn_out)

        assert adaln_input is not None
        # One projection yields all four modulation tensors, broadcast over
        # the sequence axis through the inserted dimension.
        modulation = self.adaLN_modulation(adaln_input).unsqueeze(1)
        scale_msa, gate_msa, scale_mlp, gate_mlp = modulation.chunk(4, dim=2)
        gate_msa = gate_msa.tanh()
        gate_mlp = gate_mlp.tanh()
        scale_msa = 1.0 + scale_msa
        scale_mlp = 1.0 + scale_mlp

        attn_out = self.attention(self.attention_norm1(x) * scale_msa, freqs_cis=freqs_cis)
        x = x + gate_msa * self.attention_norm2(attn_out)

        ffn_out = self.feed_forward(self.ffn_norm1(x) * scale_mlp)
        x = x + gate_mlp * self.ffn_norm2(ffn_out)
        return x
|
||||||
|
|
||||||
|
|
||||||
|
class FinalLayer(nn.Module):
    """Output head: non-affine LayerNorm, conditioning-dependent scale, then a linear projection.

    Args:
        hidden_size: Feature size of the incoming tokens.
        out_channels: Feature size of the projected output.
    """

    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)

        cond_dim = min(hidden_size, ADALN_EMBED_DIM)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cond_dim, hidden_size, bias=True),
        )

    def forward(self, x, c):
        """Normalise ``x``, scale it by ``1 + adaLN(c)`` broadcast over axis 1, and project."""
        scale = 1.0 + self.adaLN_modulation(c)
        normed = self.norm_final(x)
        return self.linear(normed * scale.unsqueeze(1))
|
||||||
|
|
||||||
|
|
||||||
|
class RopeEmbedder:
    """Look up complex rotary factors for multi-axis integer position ids.

    One table of unit-magnitude complex numbers is built per axis on first
    use; a call gathers one row per axis for every position id and
    concatenates them along the last dimension.

    Args:
        theta: Base of the geometric frequency progression.
        axes_dims: Per-axis feature width; each axis yields ``d // 2`` complex entries.
        axes_lens: Per-axis number of positions in the table.
    """

    def __init__(
        self,
        theta: float = 256.0,
        axes_dims: List[int] = (16, 56, 56),
        axes_lens: List[int] = (64, 128, 128),
    ):
        self.theta = theta
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens
        assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length"
        # Built lazily on the first call.
        self.freqs_cis = None

    @staticmethod
    def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0):
        """Return one complex64 table of shape ``(end[i], dim[i] // 2)`` per axis, built on CPU."""
        tables = []
        with torch.device("cpu"):
            for axis_dim, axis_len in zip(dim, end):
                # Frequencies and positions are computed in float64 before the
                # angles are reduced to float32.
                exponents = torch.arange(0, axis_dim, 2, dtype=torch.float64, device="cpu") / axis_dim
                inv_freq = 1.0 / (theta ** exponents)
                positions = torch.arange(axis_len, device=inv_freq.device, dtype=torch.float64)
                angles = torch.outer(positions, inv_freq).float()
                table = torch.polar(torch.ones_like(angles), angles).to(torch.complex64)  # complex64
                tables.append(table)
        return tables

    def __call__(self, ids: torch.Tensor):
        """Gather the rotary factors for ``ids`` of shape ``(n, len(axes_dims))``.

        The cached tables are moved to ``ids.device`` on every call.
        """
        assert ids.ndim == 2
        assert ids.shape[-1] == len(self.axes_dims)

        if self.freqs_cis is None:
            tables = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
        else:
            tables = self.freqs_cis
        self.freqs_cis = [table.to(ids.device) for table in tables]

        gathered = [self.freqs_cis[axis][ids[:, axis]] for axis in range(len(self.axes_dims))]
        return torch.cat(gathered, dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageDiT(nn.Module):
    """Transformer that denoises image latents conditioned on caption features.

    The forward pass patchifies each image, embeds image patches and caption
    features separately, refines each stream with its own blocks
    (``noise_refiner`` with timestep modulation, ``context_refiner`` without),
    concatenates them per sample as ``[image tokens, caption tokens]``, runs
    the joint ``layers``, projects with the final layer and folds the image
    tokens back into ``(C, F, H, W)`` tensors.

    Args:
        all_patch_size: Supported spatial patch sizes.
        all_f_patch_size: Supported frame patch sizes, paired element-wise with ``all_patch_size``.
        in_channels: Channel count of the input (and output) latents.
        dim: Token feature size.
        n_layers: Number of joint transformer blocks.
        n_refiner_layers: Number of blocks in each refiner stack.
        n_heads: Attention heads per block.
        n_kv_heads: Forwarded to the blocks.
        norm_eps: RMSNorm epsilon.
        qk_norm: Forwarded to the blocks.
        cap_feat_dim: Feature size of the incoming caption features.
        rope_theta: Base of the rotary frequencies.
        t_scale: Factor applied to the timestep before embedding.
        axes_dims: Rotary feature width per position axis; must sum to ``dim // n_heads``.
        axes_lens: Rotary table length per position axis.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["ZImageTransformerBlock"]

    def __init__(
        self,
        all_patch_size=(2,),
        all_f_patch_size=(1,),
        in_channels=16,
        dim=3840,
        n_layers=30,
        n_refiner_layers=2,
        n_heads=30,
        n_kv_heads=30,
        norm_eps=1e-5,
        qk_norm=True,
        cap_feat_dim=2560,
        rope_theta=256.0,
        t_scale=1000.0,
        axes_dims=(32, 48, 48),
        axes_lens=(1024, 512, 512),
    ) -> None:
        super().__init__()
        # Defaults are immutable tuples; each instance keeps its own list copy
        # so instances never share (or leak mutations into) a default object.
        axes_dims = list(axes_dims)
        axes_lens = list(axes_lens)

        assert len(all_patch_size) == len(all_f_patch_size)
        # Checked before any layer is allocated so a bad config fails fast.
        head_dim = dim // n_heads
        assert head_dim == sum(axes_dims)

        self.in_channels = in_channels
        self.out_channels = in_channels
        self.all_patch_size = all_patch_size
        self.all_f_patch_size = all_f_patch_size
        self.dim = dim
        self.n_heads = n_heads

        self.rope_theta = rope_theta
        self.t_scale = t_scale
        self.gradient_checkpointing = False

        # One input embedder and one output head per (patch, frame-patch)
        # pair, keyed as "<patch_size>-<f_patch_size>".
        all_x_embedder = {}
        all_final_layer = {}
        for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size):
            patch_key = f"{patch_size}-{f_patch_size}"
            all_x_embedder[patch_key] = nn.Linear(
                f_patch_size * patch_size * patch_size * in_channels, dim, bias=True
            )
            all_final_layer[patch_key] = FinalLayer(
                dim, patch_size * patch_size * f_patch_size * self.out_channels
            )

        self.all_x_embedder = nn.ModuleDict(all_x_embedder)
        self.all_final_layer = nn.ModuleDict(all_final_layer)
        self.noise_refiner = nn.ModuleList(
            [
                ZImageTransformerBlock(
                    1000 + layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    norm_eps,
                    qk_norm,
                    modulation=True,
                )
                for layer_id in range(n_refiner_layers)
            ]
        )
        self.context_refiner = nn.ModuleList(
            [
                ZImageTransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    norm_eps,
                    qk_norm,
                    modulation=False,
                )
                for layer_id in range(n_refiner_layers)
            ]
        )
        self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)
        self.cap_embedder = nn.Sequential(
            RMSNorm(cap_feat_dim, eps=norm_eps),
            nn.Linear(cap_feat_dim, dim, bias=True),
        )

        # Learned replacements for padded positions. They are created
        # uninitialised (torch.empty), so they presumably get their values
        # from a loaded checkpoint — TODO confirm.
        self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
        self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))

        self.layers = nn.ModuleList(
            [
                ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm)
                for layer_id in range(n_layers)
            ]
        )
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens

        self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)

    def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]:
        """Fold token sequences back into ``(out_channels, F, H, W)`` tensors.

        Only the first ``(F // pF) * (H // pH) * (W // pW)`` tokens of each
        sequence are used; anything after them is dropped. The list ``x`` is
        updated in place and also returned.
        """
        pH = pW = patch_size
        pF = f_patch_size
        bsz = len(x)
        assert len(size) == bsz
        for i in range(bsz):
            # Named so they do not shadow the module-level ``F`` alias.
            n_frames, height, width = size[i]
            f_tokens, h_tokens, w_tokens = n_frames // pF, height // pH, width // pW
            ori_len = f_tokens * h_tokens * w_tokens
            # "f h w pf ph pw c -> c (f pf) (h ph) (w pw)"
            x[i] = (
                x[i][:ori_len]
                .view(f_tokens, h_tokens, w_tokens, pF, pH, pW, self.out_channels)
                .permute(6, 0, 3, 1, 4, 2, 5)
                .reshape(self.out_channels, n_frames, height, width)
            )
        return x

    @staticmethod
    def create_coordinate_grid(size, start=None, device=None):
        """Return an int32 grid of shape ``(*size, len(size))`` holding each cell's coordinates.

        Axis ``k`` counts from ``start[k]``; ``start`` defaults to all zeros.
        """
        if start is None:
            start = (0,) * len(size)

        axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
        grids = torch.meshgrid(axes, indexing="ij")
        return torch.stack(grids, dim=-1)

    @staticmethod
    def _pad_mask(ori_len, padding_len, device):
        """Boolean vector that is False for the original entries and True for the padding."""
        return torch.cat(
            [
                torch.zeros((ori_len,), dtype=torch.bool, device=device),
                torch.ones((padding_len,), dtype=torch.bool, device=device),
            ],
            dim=0,
        )

    @staticmethod
    def _build_attn_mask(seqlens, device):
        """Boolean ``(batch, max_len)`` mask that is True on each sample's valid prefix."""
        mask = torch.zeros((len(seqlens), max(seqlens)), dtype=torch.bool, device=device)
        for row, seq_len in enumerate(seqlens):
            mask[row, :seq_len] = 1
        return mask

    def patchify_and_embed(
        self,
        all_image: List[torch.Tensor],
        all_cap_feats: List[torch.Tensor],
        patch_size: int,
        f_patch_size: int,
    ):
        """Turn images into patch tokens and pad both streams to a multiple of ``SEQ_MULTI_OF``.

        Each image is read as ``(C, F, H, W)``. Padding rows repeat the last
        real row; the returned pad masks are True on them.

        Returns:
            Tuple of per-sample lists: image tokens, padded caption features,
            image sizes ``(F, H, W)``, image position ids, caption position
            ids, image pad masks, caption pad masks.
        """
        pH = pW = patch_size
        pF = f_patch_size
        device = all_image[0].device

        all_image_out = []
        all_image_size = []
        all_image_pos_ids = []
        all_image_pad_mask = []
        all_cap_pos_ids = []
        all_cap_pad_mask = []
        all_cap_feats_out = []

        for image, cap_feat in zip(all_image, all_cap_feats):
            ### Process Caption
            cap_ori_len = len(cap_feat)
            cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF
            cap_total_len = cap_ori_len + cap_padding_len
            # Caption positions run along the first axis only, starting at 1.
            cap_padded_pos_ids = self.create_coordinate_grid(
                size=(cap_total_len, 1, 1),
                start=(1, 0, 0),
                device=device,
            ).flatten(0, 2)
            all_cap_pos_ids.append(cap_padded_pos_ids)
            all_cap_pad_mask.append(self._pad_mask(cap_ori_len, cap_padding_len, device))
            cap_padded_feat = torch.cat(
                [cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)],
                dim=0,
            )
            all_cap_feats_out.append(cap_padded_feat)

            ### Process Image
            # Named so they do not shadow the module-level ``F`` alias.
            channels, n_frames, height, width = image.size()
            all_image_size.append((n_frames, height, width))
            f_tokens, h_tokens, w_tokens = n_frames // pF, height // pH, width // pW

            image = image.view(channels, f_tokens, pF, h_tokens, pH, w_tokens, pW)
            # "c f pf h ph w pw -> (f h w) (pf ph pw c)"
            image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(
                f_tokens * h_tokens * w_tokens, pF * pH * pW * channels
            )

            image_ori_len = len(image)
            image_padding_len = (-image_ori_len) % SEQ_MULTI_OF

            # Image positions continue on the first axis right after the
            # (padded) caption; padding tokens all sit at position (0, 0, 0).
            image_ori_pos_ids = self.create_coordinate_grid(
                size=(f_tokens, h_tokens, w_tokens),
                start=(cap_total_len + 1, 0, 0),
                device=device,
            ).flatten(0, 2)
            image_padding_pos_ids = (
                self.create_coordinate_grid(
                    size=(1, 1, 1),
                    start=(0, 0, 0),
                    device=device,
                )
                .flatten(0, 2)
                .repeat(image_padding_len, 1)
            )
            image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)
            all_image_pos_ids.append(image_padded_pos_ids)
            all_image_pad_mask.append(self._pad_mask(image_ori_len, image_padding_len, device))
            image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
            all_image_out.append(image_padded_feat)

        return (
            all_image_out,
            all_cap_feats_out,
            all_image_size,
            all_image_pos_ids,
            all_cap_pos_ids,
            all_image_pad_mask,
            all_cap_pad_mask,
        )

    def forward(
        self,
        x: List[torch.Tensor],
        t,
        cap_feats: List[torch.Tensor],
        patch_size=2,
        f_patch_size=1,
        use_gradient_checkpointing=False,
        use_gradient_checkpointing_offload=False,
    ):
        """Run the model on a batch given as per-sample lists.

        Args:
            x: One ``(C, F, H, W)`` latent per sample.
            t: Timesteps; multiplied by ``t_scale`` before embedding.
            cap_feats: One caption feature tensor per sample.
            patch_size: Spatial patch size; must be in ``all_patch_size``.
            f_patch_size: Frame patch size; must be in ``all_f_patch_size``.
            use_gradient_checkpointing: Forwarded to ``gradient_checkpoint_forward``.
            use_gradient_checkpointing_offload: Forwarded to ``gradient_checkpoint_forward``.

        Returns:
            ``(outputs, {})`` where ``outputs`` is a list with one
            ``(C, F, H, W)`` tensor per sample.
        """
        assert patch_size in self.all_patch_size
        assert f_patch_size in self.all_f_patch_size

        bsz = len(x)
        device = x[0].device
        patch_key = f"{patch_size}-{f_patch_size}"
        adaln_input = self.t_embedder(t * self.t_scale)

        (
            x,
            cap_feats,
            x_size,
            x_pos_ids,
            cap_pos_ids,
            x_inner_pad_mask,
            cap_inner_pad_mask,
        ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)

        # x embed & refine
        x_item_seqlens = [len(item) for item in x]
        assert all(seq_len % SEQ_MULTI_OF == 0 for seq_len in x_item_seqlens)

        x = torch.cat(x, dim=0)
        x = self.all_x_embedder[patch_key](x)
        # Padded positions are overwritten with the learned pad token.
        x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
        x = list(x.split(x_item_seqlens, dim=0))
        x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))

        x = pad_sequence(x, batch_first=True, padding_value=0.0)
        x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
        x_attn_mask = self._build_attn_mask(x_item_seqlens, device)

        for layer in self.noise_refiner:
            x = gradient_checkpoint_forward(
                layer,
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
                x=x,
                attn_mask=x_attn_mask,
                freqs_cis=x_freqs_cis,
                adaln_input=adaln_input,
            )

        # cap embed & refine
        cap_item_seqlens = [len(item) for item in cap_feats]
        assert all(seq_len % SEQ_MULTI_OF == 0 for seq_len in cap_item_seqlens)

        cap_feats = torch.cat(cap_feats, dim=0)
        cap_feats = self.cap_embedder(cap_feats)
        cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
        cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
        cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))

        cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
        cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
        cap_attn_mask = self._build_attn_mask(cap_item_seqlens, device)

        # The context refiner receives no adaln_input.
        for layer in self.context_refiner:
            cap_feats = gradient_checkpoint_forward(
                layer,
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
                x=cap_feats,
                attn_mask=cap_attn_mask,
                freqs_cis=cap_freqs_cis,
            )

        # unified: per sample, image tokens first, caption tokens after, so
        # unpatchify can later take the image tokens from the front.
        unified = []
        unified_freqs_cis = []
        for i in range(bsz):
            x_len = x_item_seqlens[i]
            cap_len = cap_item_seqlens[i]
            unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
            unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
        unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
        assert unified_item_seqlens == [len(item) for item in unified]

        unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
        unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
        unified_attn_mask = self._build_attn_mask(unified_item_seqlens, device)

        for layer in self.layers:
            unified = gradient_checkpoint_forward(
                layer,
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
                x=unified,
                attn_mask=unified_attn_mask,
                freqs_cis=unified_freqs_cis,
                adaln_input=adaln_input,
            )

        unified = self.all_final_layer[patch_key](unified, adaln_input)
        unified = list(unified.unbind(dim=0))
        x = self.unpatchify(unified, x_size, patch_size, f_patch_size)

        return x, {}
|
||||||
41
diffsynth/models/z_image_text_encoder.py
Normal file
41
diffsynth/models/z_image_text_encoder.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
from transformers import Qwen3Model, Qwen3Config
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageTextEncoder(torch.nn.Module):
    """Text encoder used by the Z-Image pipeline.

    Wraps a ``Qwen3Model`` that is constructed from a fixed, in-code
    configuration. The constructor only builds the architecture from that
    configuration; it does not load any weights itself.
    """

    def __init__(self):
        super().__init__()
        # Hard-coded Qwen3 settings for this encoder.
        qwen3_settings = {
            "architectures": [
                "Qwen3ForCausalLM"
            ],
            "attention_bias": False,
            "attention_dropout": 0.0,
            "bos_token_id": 151643,
            "eos_token_id": 151645,
            "head_dim": 128,
            "hidden_act": "silu",
            "hidden_size": 2560,
            "initializer_range": 0.02,
            "intermediate_size": 9728,
            "max_position_embeddings": 40960,
            "max_window_layers": 36,
            "model_type": "qwen3",
            "num_attention_heads": 32,
            "num_hidden_layers": 36,
            "num_key_value_heads": 8,
            "rms_norm_eps": 1e-06,
            "rope_scaling": None,
            "rope_theta": 1000000,
            "sliding_window": None,
            "tie_word_embeddings": True,
            "torch_dtype": "bfloat16",
            "transformers_version": "4.51.0",
            "use_cache": True,
            "use_sliding_window": False,
            "vocab_size": 151936
        }
        encoder_config = Qwen3Config(**qwen3_settings)
        self.model = Qwen3Model(encoder_config)

    def forward(self, *args, **kwargs):
        """Forward all positional and keyword arguments to the wrapped Qwen3 model."""
        return self.model(*args, **kwargs)
|
||||||
257
diffsynth/pipelines/z_image.py
Normal file
257
diffsynth/pipelines/z_image.py
Normal file
@@ -0,0 +1,257 @@
|
|||||||
|
import torch, math
|
||||||
|
from PIL import Image
|
||||||
|
from typing import Union
|
||||||
|
from tqdm import tqdm
|
||||||
|
from einops import rearrange
|
||||||
|
import numpy as np
|
||||||
|
from typing import Union, List, Optional, Tuple
|
||||||
|
|
||||||
|
from ..diffusion import FlowMatchScheduler
|
||||||
|
from ..core import ModelConfig, gradient_checkpoint_forward
|
||||||
|
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
from ..models.z_image_text_encoder import ZImageTextEncoder
|
||||||
|
from ..models.z_image_dit import ZImageDiT
|
||||||
|
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
||||||
|
|
||||||
|
|
||||||
|
class ZImagePipeline(BasePipeline):
    """Image generation pipeline for Z-Image models.

    Holds a text encoder, the Z-Image DiT, a VAE encoder/decoder pair and a
    tokenizer, and drives them with a flow-matching scheduler. Inputs are
    prepared by the ``units`` list; denoising is done by ``model_fn``.
    """

    def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
        super().__init__(
            device=device, torch_dtype=torch_dtype,
            height_division_factor=16, width_division_factor=16,
        )
        self.scheduler = FlowMatchScheduler()
        # Model slots; they stay None until `from_pretrained` fills them in.
        self.text_encoder: ZImageTextEncoder = None
        self.dit: ZImageDiT = None
        self.vae_encoder: FluxVAEEncoder = None
        self.vae_decoder: FluxVAEDecoder = None
        self.tokenizer: AutoTokenizer = None
        # Models that must be resident during the denoising loop.
        self.in_iteration_models = ("dit",)
        # Pre-processing units, executed in this order before denoising.
        self.units = [
            ZImageUnit_ShapeChecker(),
            ZImageUnit_PromptEmbedder(),
            ZImageUnit_NoiseInitializer(),
            ZImageUnit_InputImageEmbedder(),
        ]
        self.model_fn = model_fn_z_image

    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = "cuda",
        model_configs: list[ModelConfig] = [],
        tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
        vram_limit: float = None,
    ):
        """Build a pipeline and populate it from the given model configs.

        Args:
            torch_dtype: Dtype handed to the pipeline constructor.
            device: Device handed to the pipeline constructor.
            model_configs: Model files to download/load into the model pool.
            tokenizer_config: Location of the tokenizer; pass ``None`` to
                skip tokenizer loading.
            vram_limit: Forwarded to ``download_and_load_models``.

        Returns:
            The initialized ``ZImagePipeline``.
        """
        # Initialize pipeline
        pipe = ZImagePipeline(device=device, torch_dtype=torch_dtype)
        model_pool = pipe.download_and_load_models(model_configs, vram_limit)

        # Fetch models
        pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
        pipe.dit = model_pool.fetch_model("z_image_dit")
        pipe.vae_encoder = model_pool.fetch_model("flux_vae_encoder")
        pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder")
        if tokenizer_config is not None:
            tokenizer_config.download_if_necessary()
            pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)

        # VRAM Management
        pipe.vram_management_enabled = pipe.check_vram_management_state()
        return pipe

    @torch.no_grad()
    def __call__(
        self,
        # Prompt
        prompt: str,
        negative_prompt: str = "",
        cfg_scale: float = 1.0,
        # Image
        input_image: Image.Image = None,
        denoising_strength: float = 1.0,
        # Shape
        height: int = 1024,
        width: int = 1024,
        # Randomness
        seed: int = None,
        rand_device: str = "cpu",
        # Steps
        num_inference_steps: int = 8,
        # Progress bar
        progress_bar_cmd = tqdm,
    ):
        """Generate an image from a prompt (optionally starting from an image).

        Args:
            prompt: Positive text prompt.
            negative_prompt: Negative text prompt for the CFG branch.
            cfg_scale: Guidance scale passed to ``cfg_guided_model_fn``.
            input_image: Optional starting image.
            denoising_strength: Passed to the scheduler's ``set_timesteps``.
            height: Requested output height.
            width: Requested output width.
            seed: Random seed for the initial noise.
            rand_device: Device on which the initial noise is generated.
            num_inference_steps: Number of scheduler steps.
            progress_bar_cmd: Callable wrapping the timestep iterable.

        Returns:
            The result of ``vae_output_to_image`` applied to the decoded
            latents (presumably a PIL image; defined in the base pipeline).
        """
        # Scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)

        # Parameters
        inputs_posi = {
            "prompt": prompt,
        }
        inputs_nega = {
            "negative_prompt": negative_prompt,
        }
        inputs_shared = {
            "cfg_scale": cfg_scale,
            "input_image": input_image, "denoising_strength": denoising_strength,
            "height": height, "width": width,
            "seed": seed, "rand_device": rand_device,
            "num_inference_steps": num_inference_steps,
        }
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)

        # Denoise
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
            noise_pred = self.cfg_guided_model_fn(
                self.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)

        # Decode
        # The decoder is stored as `vae_decoder`; this class has no attribute
        # named `vae`, so request the decoder by its actual name.
        self.load_models_to_device(['vae_decoder'])
        image = self.vae_decoder(inputs_shared["latents"])
        image = self.vae_output_to_image(image)
        self.load_models_to_device([])

        return image
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageUnit_ShapeChecker(PipelineUnit):
    """Pipeline unit that validates/adjusts the requested output resolution."""

    def __init__(self):
        super().__init__(
            input_params=("height", "width"),
            output_params=("height", "width"),
        )

    def process(self, pipe: ZImagePipeline, height, width):
        """Return the height/width produced by ``pipe.check_resize_height_width``."""
        checked_height, checked_width = pipe.check_resize_height_width(height, width)
        return {"height": checked_height, "width": checked_width}
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageUnit_PromptEmbedder(PipelineUnit):
    """Pipeline unit that converts text prompts into text-encoder embeddings.

    Runs separately for the positive and negative branch: ``prompt`` feeds the
    positive branch and ``negative_prompt`` feeds the negative one.
    """

    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params_posi={"prompt": "prompt"},
            input_params_nega={"prompt": "negative_prompt"},
            output_params=("prompt_embeds",),
            onload_model_names=("text_encoder",)
        )

    def encode_prompt(
        self,
        pipe,
        prompt: Union[str, List[str]],
        device: Optional[torch.device] = None,
        max_sequence_length: int = 512,
    ) -> List[torch.FloatTensor]:
        """Encode one or more prompts with the pipeline's tokenizer and text encoder.

        Args:
            pipe: Pipeline providing ``tokenizer`` and ``text_encoder``.
            prompt: A single prompt or a list of prompts. The list is not
                modified.
            device: Device the token ids and attention mask are moved to.
            max_sequence_length: Padding/truncation length for tokenization.

        Returns:
            One tensor per prompt, containing only the positions where the
            attention mask is set (i.e. padding removed).
        """
        prompts = [prompt] if isinstance(prompt, str) else prompt

        # Build a new list instead of overwriting entries of the caller's list.
        templated_prompts = [
            pipe.tokenizer.apply_chat_template(
                [
                    {"role": "user", "content": prompt_item},
                ],
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=True,
            )
            for prompt_item in prompts
        ]

        text_inputs = pipe.tokenizer(
            templated_prompts,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids.to(device)
        prompt_masks = text_inputs.attention_mask.to(device).bool()

        # Use the second-to-last hidden state as the prompt representation.
        prompt_embeds = pipe.text_encoder(
            input_ids=text_input_ids,
            attention_mask=prompt_masks,
            output_hidden_states=True,
        ).hidden_states[-2]

        # Drop padded positions independently for every prompt.
        return [embeds[mask] for embeds, mask in zip(prompt_embeds, prompt_masks)]

    def process(self, pipe: ZImagePipeline, prompt):
        """Load the text encoder and return ``{"prompt_embeds": ...}`` for ``prompt``."""
        pipe.load_models_to_device(self.onload_model_names)
        prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device)
        return {"prompt_embeds": prompt_embeds}
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageUnit_NoiseInitializer(PipelineUnit):
    """Pipeline unit that creates the initial latent noise."""

    def __init__(self):
        super().__init__(
            input_params=("height", "width", "seed", "rand_device"),
            output_params=("noise",),
        )

    def process(self, pipe: ZImagePipeline, height, width, seed, rand_device):
        """Generate noise of shape ``(1, 16, height//8, width//8)`` via the pipeline."""
        latent_shape = (1, 16, height//8, width//8)
        noise = pipe.generate_noise(
            latent_shape,
            seed=seed,
            rand_device=rand_device,
            rand_torch_dtype=pipe.torch_dtype,
        )
        return {"noise": noise}
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageUnit_InputImageEmbedder(PipelineUnit):
    """Pipeline unit that encodes an optional input image into latents."""

    def __init__(self):
        super().__init__(
            input_params=("input_image", "noise"),
            output_params=("latents", "input_latents"),
            onload_model_names=("vae_encoder",)
        )

    def process(self, pipe: ZImagePipeline, input_image, noise):
        """Produce the starting latents for denoising.

        Without an input image the noise itself is the starting point. With an
        image, the image is encoded; in scheduler training mode the noise is
        returned as ``latents``, otherwise the encoded image is noised to the
        first scheduler timestep.

        Returns:
            Dict with ``latents`` and ``input_latents`` (``None`` if no image).
        """
        if input_image is None:
            return {"latents": noise, "input_latents": None}
        # Load the model this unit declares (`vae_encoder`) rather than a
        # hard-coded `vae` name that the pipeline does not define.
        pipe.load_models_to_device(self.onload_model_names)
        image = pipe.preprocess_image(input_image)
        input_latents = pipe.vae_encoder(image)
        if pipe.scheduler.training:
            return {"latents": noise, "input_latents": input_latents}
        latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
        return {"latents": latents, "input_latents": input_latents}
|
||||||
|
|
||||||
|
|
||||||
|
def model_fn_z_image(
    dit: ZImageDiT,
    latents=None,
    timestep=None,
    prompt_embeds=None,
    use_gradient_checkpointing=False,
    use_gradient_checkpointing_offload=False,
    **kwargs,
):
    """Run one Z-Image DiT forward pass and return the prediction as (B, C, H, W).

    The latents are rearranged to (C, B, H, W) and wrapped in a one-element
    list, the timestep is remapped to ``(1000 - timestep) / 1000``, and the
    sign of the DiT output is flipped before it is rearranged back.
    Extra keyword arguments are accepted and ignored.
    """
    dit_latents = [rearrange(latents, "B C H W -> C B H W")]
    dit_timestep = (1000 - timestep) / 1000
    dit_result = dit(
        dit_latents,
        dit_timestep,
        prompt_embeds,
        use_gradient_checkpointing=use_gradient_checkpointing,
        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
    )
    # First element of the return value, then the first (only) item in it.
    prediction = dit_result[0][0]
    return rearrange(-prediction, "C B H W -> B C H W")
|
||||||
@@ -262,3 +262,121 @@ def FluxVAEDecoderStateDictConverter(state_dict):
|
|||||||
param = state_dict[name]
|
param = state_dict[name]
|
||||||
state_dict_[rename_dict[name]] = param
|
state_dict_[rename_dict[name]] = param
|
||||||
return state_dict_
|
return state_dict_
|
||||||
|
|
||||||
|
|
||||||
|
def FluxVAEEncoderStateDictConverterDiffusers(state_dict):
    """Rename Diffusers-style VAE encoder parameters to the flat ``blocks.N`` layout.

    Parameters whose prefix is listed in the fixed table are renamed directly.
    Parameters under ``encoder.down_blocks`` are assigned a ``blocks`` index by
    walking the target layout in sorted name order. Every other entry (for
    example decoder weights) is dropped from the result.

    Args:
        state_dict: Mapping from Diffusers parameter names to values.

    Returns:
        A new dict with converted names, in the iteration order of
        ``state_dict``.
    """
    # Target architecture: position in this list == index inside `blocks`.
    layout = (
        ['ResnetBlock', 'ResnetBlock', 'DownSampler'] * 3
        + ['ResnetBlock', 'ResnetBlock']
        + ['ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock']
    )

    # Prefixes with a fixed destination (everything but the down blocks).
    fixed_prefixes = {
        "quant_conv": "quant_conv",
        "encoder.conv_in": "conv_in",
        "encoder.mid_block.attentions.0.group_norm": "blocks.12.norm",
        "encoder.mid_block.attentions.0.to_q": "blocks.12.transformer_blocks.0.to_q",
        "encoder.mid_block.attentions.0.to_k": "blocks.12.transformer_blocks.0.to_k",
        "encoder.mid_block.attentions.0.to_v": "blocks.12.transformer_blocks.0.to_v",
        "encoder.mid_block.attentions.0.to_out.0": "blocks.12.transformer_blocks.0.to_out",
        "encoder.mid_block.resnets.0.norm1": "blocks.11.norm1",
        "encoder.mid_block.resnets.0.conv1": "blocks.11.conv1",
        "encoder.mid_block.resnets.0.norm2": "blocks.11.norm2",
        "encoder.mid_block.resnets.0.conv2": "blocks.11.conv2",
        "encoder.mid_block.resnets.1.norm1": "blocks.13.norm1",
        "encoder.mid_block.resnets.1.conv1": "blocks.13.conv1",
        "encoder.mid_block.resnets.1.norm2": "blocks.13.norm2",
        "encoder.mid_block.resnets.1.conv2": "blocks.13.conv2",
        "encoder.conv_norm_out": "conv_norm_out",
        "encoder.conv_out": "conv_out",
    }

    kind_of = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}
    # Per block kind: last assigned layout index and the source block it was assigned to.
    cursor = dict.fromkeys(kind_of.values(), -1)
    previous_source = dict.fromkeys(kind_of.values(), "")

    converted_names = {}
    for key in sorted(state_dict):
        parts = key.split(".")
        prefix, leaf = ".".join(parts[:-1]), parts[-1]
        if prefix in fixed_prefixes:
            converted_names[key] = fixed_prefixes[prefix] + "." + leaf
            continue
        if not key.startswith("encoder.down_blocks"):
            continue
        kind = kind_of[parts[3]]
        source_block = ".".join(parts[:5])
        if source_block != previous_source[kind]:
            # New source block: advance to the next layout slot of this kind.
            previous_source[kind] = source_block
            position = cursor[kind] + 1
            while position < len(layout) and layout[position] != kind:
                position += 1
            cursor[kind] = position
        converted_names[key] = ".".join(["blocks", str(cursor[kind])] + parts[5:])

    # Keep only the parameters that received a new name.
    return {
        converted_names[key]: value
        for key, value in state_dict.items()
        if key in converted_names
    }
|
||||||
|
|
||||||
|
|
||||||
|
def FluxVAEDecoderStateDictConverterDiffusers(state_dict):
    """Rename Diffusers-style VAE decoder parameters to the flat ``blocks.N`` layout.

    Parameters whose prefix is listed in the fixed table are renamed directly.
    Parameters under ``decoder.up_blocks`` are assigned a ``blocks`` index by
    walking the target layout (starting after the three mid-block slots) in
    sorted name order. Every other entry (for example encoder weights) is
    dropped from the result.

    Args:
        state_dict: Mapping from Diffusers parameter names to values.

    Returns:
        A new dict with converted names, in the iteration order of
        ``state_dict``.
    """
    # Target architecture: position in this list == index inside `blocks`.
    layout = (
        ['ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock']
        + ['ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler'] * 3
        + ['ResnetBlock', 'ResnetBlock', 'ResnetBlock']
    )

    # Prefixes with a fixed destination (everything but the up blocks).
    fixed_prefixes = {
        "post_quant_conv": "post_quant_conv",
        "decoder.conv_in": "conv_in",
        "decoder.mid_block.attentions.0.group_norm": "blocks.1.norm",
        "decoder.mid_block.attentions.0.to_q": "blocks.1.transformer_blocks.0.to_q",
        "decoder.mid_block.attentions.0.to_k": "blocks.1.transformer_blocks.0.to_k",
        "decoder.mid_block.attentions.0.to_v": "blocks.1.transformer_blocks.0.to_v",
        "decoder.mid_block.attentions.0.to_out.0": "blocks.1.transformer_blocks.0.to_out",
        "decoder.mid_block.resnets.0.norm1": "blocks.0.norm1",
        "decoder.mid_block.resnets.0.conv1": "blocks.0.conv1",
        "decoder.mid_block.resnets.0.norm2": "blocks.0.norm2",
        "decoder.mid_block.resnets.0.conv2": "blocks.0.conv2",
        "decoder.mid_block.resnets.1.norm1": "blocks.2.norm1",
        "decoder.mid_block.resnets.1.conv1": "blocks.2.conv1",
        "decoder.mid_block.resnets.1.norm2": "blocks.2.norm2",
        "decoder.mid_block.resnets.1.conv2": "blocks.2.conv2",
        "decoder.conv_norm_out": "conv_norm_out",
        "decoder.conv_out": "conv_out",
    }

    kind_of = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}
    # Per block kind: last assigned layout index (slots 0-2 belong to the mid
    # block) and the source block it was assigned to.
    cursor = dict.fromkeys(kind_of.values(), 2)
    previous_source = dict.fromkeys(kind_of.values(), "")

    converted_names = {}
    for key in sorted(state_dict):
        parts = key.split(".")
        prefix, leaf = ".".join(parts[:-1]), parts[-1]
        if prefix in fixed_prefixes:
            converted_names[key] = fixed_prefixes[prefix] + "." + leaf
            continue
        if not key.startswith("decoder.up_blocks"):
            continue
        kind = kind_of[parts[3]]
        source_block = ".".join(parts[:5])
        if source_block != previous_source[kind]:
            # New source block: advance to the next layout slot of this kind.
            previous_source[kind] = source_block
            position = cursor[kind] + 1
            while position < len(layout) and layout[position] != kind:
                position += 1
            cursor[kind] = position
        converted_names[key] = ".".join(["blocks", str(cursor[kind])] + parts[5:])

    # Keep only the parameters that received a new name.
    return {
        converted_names[key]: value
        for key, value in state_dict.items()
        if key in converted_names
    }
|
||||||
17
examples/z_image/model_inference/Z-Image-Turbo.py
Normal file
17
examples/z_image/model_inference/Z-Image-Turbo.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
# Load the DiT, text encoder and VAE weights plus the tokenizer from the same model repository.
model_id = "Tongyi-MAI/Z-Image-Turbo"
weight_patterns = (
    "transformer/*.safetensors",
    "text_encoder/*.safetensors",
    "vae/diffusion_pytorch_model.safetensors",
)
pipe = ZImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id=model_id, origin_file_pattern=pattern)
        for pattern in weight_patterns
    ],
    tokenizer_config=ModelConfig(model_id=model_id, origin_file_pattern="tokenizer/"),
)

# Generate one image with a fixed seed and write it to disk.
prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
image = pipe(prompt=prompt, seed=42, rand_device="cuda")
image.save("image.jpg")
|
||||||
15
examples/z_image/model_training/lora/Z-Image-Turbo.sh
Normal file
15
examples/z_image/model_training/lora/Z-Image-Turbo.sh
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# LoRA fine-tuning of the Z-Image-Turbo DiT on the example image dataset.
# All paths are relative to the current working directory; checkpoints are
# written to ./models/train/Z-Image-Turbo_lora.
accelerate launch examples/z_image/model_training/train.py \
  --dataset_base_path data/example_image_dataset \
  --dataset_metadata_path data/example_image_dataset/metadata.csv \
  --max_pixels 1048576 \
  --dataset_repeat 50 \
  --model_id_with_origin_paths "Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Z-Image-Turbo_lora" \
  --lora_base_model "dit" \
  --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj" \
  --lora_rank 32 \
  --use_gradient_checkpointing \
  --dataset_num_workers 8
|
||||||
143
examples/z_image/model_training/train.py
Normal file
143
examples/z_image/model_training/train.py
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
import torch, os, argparse, accelerate
|
||||||
|
from diffsynth.core import UnifiedDataset
|
||||||
|
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
|
||||||
|
from diffsynth.diffusion import *
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageTrainingModule(DiffusionTrainingModule):
    """Training wrapper around ``ZImagePipeline``.

    Loads the pipeline, switches it to training mode (optionally with LoRA),
    and maps each supported task name to the function that computes its
    loss or data-processing output.
    """

    def __init__(
        self,
        model_paths=None, model_id_with_origin_paths=None,
        tokenizer_path=None,
        trainable_models=None,
        lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
        preset_lora_path=None, preset_lora_model=None,
        use_gradient_checkpointing=True,
        use_gradient_checkpointing_offload=False,
        extra_inputs=None,
        fp8_models=None,
        offload_models=None,
        device="cpu",
        task="sft",
    ):
        super().__init__()

        # Load models
        model_configs = self.parse_model_configs(
            model_paths, model_id_with_origin_paths,
            fp8_models=fp8_models, offload_models=offload_models, device=device,
        )
        if tokenizer_path is None:
            # Fall back to the tokenizer shipped with the released model.
            tokenizer_config = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/")
        else:
            tokenizer_config = ModelConfig(tokenizer_path)
        self.pipe = ZImagePipeline.from_pretrained(
            torch_dtype=torch.bfloat16, device=device,
            model_configs=model_configs, tokenizer_config=tokenizer_config,
        )
        self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)

        # Training mode
        self.switch_pipe_to_training_mode(
            self.pipe, trainable_models,
            lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
            preset_lora_path, preset_lora_model,
            task=task,
        )

        # Other configs
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
        self.extra_inputs = [] if extra_inputs is None else extra_inputs.split(",")
        self.fp8_models = fp8_models
        self.task = task

        def passthrough(pipe, *args):
            # Data-processing tasks hand the prepared inputs back unchanged.
            return args

        def sft_loss(pipe, inputs_shared, inputs_posi, inputs_nega):
            return FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi)

        def direct_distill_loss(pipe, inputs_shared, inputs_posi, inputs_nega):
            return DirectDistillLoss(pipe, **inputs_shared, **inputs_posi)

        self.task_to_loss = {
            "sft:data_process": passthrough,
            "direct_distill:data_process": passthrough,
            "sft": sft_loss,
            "sft:train": sft_loss,
            "direct_distill": direct_distill_loss,
            "direct_distill:train": direct_distill_loss,
        }

    def get_pipeline_inputs(self, data):
        """Build ``(inputs_shared, inputs_posi, inputs_nega)`` from one dataset sample."""
        image = data["image"]
        inputs_posi = {"prompt": data["prompt"]}
        inputs_nega = {"negative_prompt": ""}
        inputs_shared = {
            # Assume you are using this pipeline for inference,
            # please fill in the input parameters.
            "input_image": image,
            "height": image.size[1],
            "width": image.size[0],
            # Please do not modify the following parameters
            # unless you clearly know what this will cause.
            "embedded_guidance": 1.0,
            "cfg_scale": 1,
            "rand_device": self.pipe.device,
            "use_gradient_checkpointing": self.use_gradient_checkpointing,
            "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
        }
        inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
        return inputs_shared, inputs_posi, inputs_nega

    def forward(self, data, inputs=None):
        """Run all pipeline units on the inputs and return the task's loss/output.

        If ``inputs`` is not given it is derived from ``data``.
        """
        if inputs is None:
            inputs = self.get_pipeline_inputs(data)
        inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
        for unit in self.pipe.units:
            inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)
        compute = self.task_to_loss[self.task]
        return compute(self.pipe, *inputs)
|
||||||
|
|
||||||
|
|
||||||
|
def qwen_image_parser():
    """Create the command-line parser for this training script.

    Starts from the shared general and image-size options and adds
    ``--tokenizer_path``. NOTE(review): the function name refers to Qwen-Image
    although this script trains Z-Image; the name is kept for callers.
    """
    arg_parser = argparse.ArgumentParser(description="Simple example of a training script.")
    arg_parser = add_image_size_config(add_general_config(arg_parser))
    arg_parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
    return arg_parser
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Parse CLI options and set up the accelerate runtime.
    parser = qwen_image_parser()
    args = parser.parse_args()
    ddp_kwargs = accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)
    accelerator = accelerate.Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        kwargs_handlers=[ddp_kwargs],
    )

    # Dataset: images are resized so both sides are multiples of 16.
    data_file_keys = args.data_file_keys.split(",")
    image_operator = UnifiedDataset.default_image_operator(
        base_path=args.dataset_base_path,
        max_pixels=args.max_pixels,
        height=args.height,
        width=args.width,
        height_division_factor=16,
        width_division_factor=16,
    )
    dataset = UnifiedDataset(
        base_path=args.dataset_base_path,
        metadata_path=args.dataset_metadata_path,
        repeat=args.dataset_repeat,
        data_file_keys=data_file_keys,
        main_data_operator=image_operator,
    )

    # Model and checkpoint logger.
    model = ZImageTrainingModule(
        model_paths=args.model_paths,
        model_id_with_origin_paths=args.model_id_with_origin_paths,
        tokenizer_path=args.tokenizer_path,
        trainable_models=args.trainable_models,
        lora_base_model=args.lora_base_model,
        lora_target_modules=args.lora_target_modules,
        lora_rank=args.lora_rank,
        lora_checkpoint=args.lora_checkpoint,
        preset_lora_path=args.preset_lora_path,
        preset_lora_model=args.preset_lora_model,
        use_gradient_checkpointing=args.use_gradient_checkpointing,
        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
        extra_inputs=args.extra_inputs,
        fp8_models=args.fp8_models,
        offload_models=args.offload_models,
        task=args.task,
        device=accelerator.device,
    )
    model_logger = ModelLogger(
        args.output_path,
        remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
    )

    # Pick the launcher for the requested task and run it.
    data_process_tasks = ("sft:data_process", "direct_distill:data_process")
    training_tasks = ("sft", "sft:train", "direct_distill", "direct_distill:train")
    launcher_map = {
        **dict.fromkeys(data_process_tasks, launch_data_process_task),
        **dict.fromkeys(training_tasks, launch_training_task),
    }
    launch = launcher_map[args.task]
    launch(accelerator, dataset, model, model_logger, args=args)
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
# Load the base Z-Image-Turbo models and tokenizer.
pipe = ZImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"),
        ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
        ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
)
# Apply the trained LoRA to the DiT. The path is relative so that it matches
# the training script's `--output_path "./models/train/Z-Image-Turbo_lora"`
# (the previous value pointed at the filesystem root, `/models/...`).
pipe.load_lora(pipe.dit, "./models/train/Z-Image-Turbo_lora/epoch-4.safetensors")
prompt = "a dog"
image = pipe(prompt=prompt, seed=42, rand_device="cuda")
image.save("image.jpg")
|
||||||
Reference in New Issue
Block a user