mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support z-image-omni-base
This commit is contained in:
@@ -527,6 +527,19 @@ z_image_series = [
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverterDiffusers",
|
||||
"extra_kwargs": {"use_conv_attention": False},
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors")
|
||||
"model_hash": "aa3563718e5c3ecde3dfbb020ca61180",
|
||||
"model_name": "z_image_dit",
|
||||
"model_class": "diffsynth.models.z_image_dit.ZImageDiT",
|
||||
"extra_kwargs": {"siglip_feat_dim": 1152},
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors")
|
||||
"model_hash": "89d48e420f45cff95115a9f3e698d44a",
|
||||
"model_name": "siglip_vision_model_428m",
|
||||
"model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M",
|
||||
},
|
||||
]
|
||||
|
||||
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer, SiglipVisionConfig
|
||||
from transformers import SiglipImageProcessor
|
||||
from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast
|
||||
import torch
|
||||
|
||||
|
||||
@@ -68,3 +68,68 @@ class Siglip2ImageEncoder(SiglipVisionTransformer):
|
||||
pooler_output = self.head(last_hidden_state) if self.use_head else None
|
||||
|
||||
return pooler_output
|
||||
|
||||
|
||||
class Siglip2ImageEncoder428M(Siglip2VisionModel):
|
||||
def __init__(self):
|
||||
config = Siglip2VisionConfig(
|
||||
attention_dropout = 0.0,
|
||||
dtype = "bfloat16",
|
||||
hidden_act = "gelu_pytorch_tanh",
|
||||
hidden_size = 1152,
|
||||
intermediate_size = 4304,
|
||||
layer_norm_eps = 1e-06,
|
||||
model_type = "siglip2_vision_model",
|
||||
num_attention_heads = 16,
|
||||
num_channels = 3,
|
||||
num_hidden_layers = 27,
|
||||
num_patches = 256,
|
||||
patch_size = 16,
|
||||
transformers_version = "4.57.1"
|
||||
)
|
||||
super().__init__(config)
|
||||
self.processor = Siglip2ImageProcessorFast(
|
||||
**{
|
||||
"crop_size": None,
|
||||
"data_format": "channels_first",
|
||||
"default_to_square": True,
|
||||
"device": None,
|
||||
"disable_grouping": None,
|
||||
"do_center_crop": None,
|
||||
"do_convert_rgb": None,
|
||||
"do_normalize": True,
|
||||
"do_pad": None,
|
||||
"do_rescale": True,
|
||||
"do_resize": True,
|
||||
"image_mean": [
|
||||
0.5,
|
||||
0.5,
|
||||
0.5
|
||||
],
|
||||
"image_processor_type": "Siglip2ImageProcessorFast",
|
||||
"image_std": [
|
||||
0.5,
|
||||
0.5,
|
||||
0.5
|
||||
],
|
||||
"input_data_format": None,
|
||||
"max_num_patches": 256,
|
||||
"pad_size": None,
|
||||
"patch_size": 16,
|
||||
"processor_class": "Siglip2Processor",
|
||||
"resample": 2,
|
||||
"rescale_factor": 0.00392156862745098,
|
||||
"return_tensors": None,
|
||||
"size": None
|
||||
}
|
||||
)
|
||||
|
||||
def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
|
||||
siglip_inputs = self.processor(images=[image], return_tensors="pt").to(device)
|
||||
shape = siglip_inputs.spatial_shapes[0]
|
||||
hidden_state = super().forward(**siglip_inputs).last_hidden_state
|
||||
B, N, C = hidden_state.shape
|
||||
hidden_state = hidden_state[:, : shape[0] * shape[1]]
|
||||
hidden_state = hidden_state.view(shape[0], shape[1], C)
|
||||
hidden_state = hidden_state.to(torch_dtype)
|
||||
return hidden_state
|
||||
|
||||
@@ -13,6 +13,7 @@ from ..core.gradient import gradient_checkpoint_forward
|
||||
|
||||
ADALN_EMBED_DIM = 256
|
||||
SEQ_MULTI_OF = 32
|
||||
X_PAD_DIM = 64
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
@@ -86,7 +87,7 @@ class Attention(torch.nn.Module):
|
||||
self.norm_q = RMSNorm(head_dim, eps=1e-5)
|
||||
self.norm_k = RMSNorm(head_dim, eps=1e-5)
|
||||
|
||||
def forward(self, hidden_states, freqs_cis):
|
||||
def forward(self, hidden_states, freqs_cis, attention_mask):
|
||||
query = self.to_q(hidden_states)
|
||||
key = self.to_k(hidden_states)
|
||||
value = self.to_v(hidden_states)
|
||||
@@ -123,6 +124,7 @@ class Attention(torch.nn.Module):
|
||||
key,
|
||||
value,
|
||||
q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
|
||||
attn_mask=attention_mask,
|
||||
)
|
||||
|
||||
# Reshape back
|
||||
@@ -136,6 +138,20 @@ class Attention(torch.nn.Module):
|
||||
return output
|
||||
|
||||
|
||||
def select_per_token(
|
||||
value_noisy: torch.Tensor,
|
||||
value_clean: torch.Tensor,
|
||||
noise_mask: torch.Tensor,
|
||||
seq_len: int,
|
||||
) -> torch.Tensor:
|
||||
noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1)
|
||||
return torch.where(
|
||||
noise_mask_expanded == 1,
|
||||
value_noisy.unsqueeze(1).expand(-1, seq_len, -1),
|
||||
value_clean.unsqueeze(1).expand(-1, seq_len, -1),
|
||||
)
|
||||
|
||||
|
||||
class ZImageTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -180,40 +196,53 @@ class ZImageTransformerBlock(nn.Module):
|
||||
attn_mask: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
adaln_input: Optional[torch.Tensor] = None,
|
||||
noise_mask: Optional[torch.Tensor] = None,
|
||||
adaln_noisy: Optional[torch.Tensor] = None,
|
||||
adaln_clean: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if self.modulation:
|
||||
assert adaln_input is not None
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
|
||||
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
|
||||
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
|
||||
seq_len = x.shape[1]
|
||||
|
||||
if noise_mask is not None:
|
||||
# Per-token modulation: different modulation for noisy/clean tokens
|
||||
mod_noisy = self.adaLN_modulation(adaln_noisy)
|
||||
mod_clean = self.adaLN_modulation(adaln_clean)
|
||||
|
||||
scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1)
|
||||
scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1)
|
||||
|
||||
gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh()
|
||||
gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh()
|
||||
|
||||
scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy
|
||||
scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean
|
||||
|
||||
scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len)
|
||||
scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len)
|
||||
gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len)
|
||||
gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len)
|
||||
else:
|
||||
# Global modulation: same modulation for all tokens (avoid double select)
|
||||
mod = self.adaLN_modulation(adaln_input)
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2)
|
||||
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
|
||||
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
|
||||
|
||||
# Attention block
|
||||
attn_out = self.attention(
|
||||
self.attention_norm1(x) * scale_msa,
|
||||
freqs_cis=freqs_cis,
|
||||
self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis
|
||||
)
|
||||
x = x + gate_msa * self.attention_norm2(attn_out)
|
||||
|
||||
# FFN block
|
||||
x = x + gate_mlp * self.ffn_norm2(
|
||||
self.feed_forward(
|
||||
self.ffn_norm1(x) * scale_mlp,
|
||||
)
|
||||
)
|
||||
x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp))
|
||||
else:
|
||||
# Attention block
|
||||
attn_out = self.attention(
|
||||
self.attention_norm1(x),
|
||||
freqs_cis=freqs_cis,
|
||||
)
|
||||
attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis)
|
||||
x = x + self.attention_norm2(attn_out)
|
||||
|
||||
# FFN block
|
||||
x = x + self.ffn_norm2(
|
||||
self.feed_forward(
|
||||
self.ffn_norm1(x),
|
||||
)
|
||||
)
|
||||
x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))
|
||||
|
||||
return x
|
||||
|
||||
@@ -229,9 +258,21 @@ class FinalLayer(nn.Module):
|
||||
nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True),
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
scale = 1.0 + self.adaLN_modulation(c)
|
||||
x = self.norm_final(x) * scale.unsqueeze(1)
|
||||
def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None):
|
||||
seq_len = x.shape[1]
|
||||
|
||||
if noise_mask is not None:
|
||||
# Per-token modulation
|
||||
scale_noisy = 1.0 + self.adaLN_modulation(c_noisy)
|
||||
scale_clean = 1.0 + self.adaLN_modulation(c_clean)
|
||||
scale = select_per_token(scale_noisy, scale_clean, noise_mask, seq_len)
|
||||
else:
|
||||
# Original global modulation
|
||||
assert c is not None, "Either c or (c_noisy, c_clean) must be provided"
|
||||
scale = 1.0 + self.adaLN_modulation(c)
|
||||
scale = scale.unsqueeze(1)
|
||||
|
||||
x = self.norm_final(x) * scale
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
@@ -299,6 +340,7 @@ class ZImageDiT(nn.Module):
|
||||
t_scale=1000.0,
|
||||
axes_dims=[32, 48, 48],
|
||||
axes_lens=[1024, 512, 512],
|
||||
siglip_feat_dim=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
@@ -359,6 +401,32 @@ class ZImageDiT(nn.Module):
|
||||
nn.Linear(cap_feat_dim, dim, bias=True),
|
||||
)
|
||||
|
||||
# Optional SigLIP components (for Omni variant)
|
||||
self.siglip_feat_dim = siglip_feat_dim
|
||||
if siglip_feat_dim is not None:
|
||||
self.siglip_embedder = nn.Sequential(
|
||||
RMSNorm(siglip_feat_dim, eps=norm_eps), nn.Linear(siglip_feat_dim, dim, bias=True)
|
||||
)
|
||||
self.siglip_refiner = nn.ModuleList(
|
||||
[
|
||||
ZImageTransformerBlock(
|
||||
2000 + layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=False,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
self.siglip_pad_token = nn.Parameter(torch.empty((1, dim)))
|
||||
else:
|
||||
self.siglip_embedder = None
|
||||
self.siglip_refiner = None
|
||||
self.siglip_pad_token = None
|
||||
|
||||
self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
|
||||
self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
|
||||
|
||||
@@ -375,22 +443,57 @@ class ZImageDiT(nn.Module):
|
||||
|
||||
self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)
|
||||
|
||||
def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]:
|
||||
def unpatchify(
|
||||
self,
|
||||
x: List[torch.Tensor],
|
||||
size: List[Tuple],
|
||||
patch_size = 2,
|
||||
f_patch_size = 1,
|
||||
x_pos_offsets: Optional[List[Tuple[int, int]]] = None,
|
||||
) -> List[torch.Tensor]:
|
||||
pH = pW = patch_size
|
||||
pF = f_patch_size
|
||||
bsz = len(x)
|
||||
assert len(size) == bsz
|
||||
for i in range(bsz):
|
||||
F, H, W = size[i]
|
||||
ori_len = (F // pF) * (H // pH) * (W // pW)
|
||||
# "f h w pf ph pw c -> c (f pf) (h ph) (w pw)"
|
||||
x[i] = (
|
||||
x[i][:ori_len]
|
||||
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
|
||||
.permute(6, 0, 3, 1, 4, 2, 5)
|
||||
.reshape(self.out_channels, F, H, W)
|
||||
)
|
||||
return x
|
||||
|
||||
if x_pos_offsets is not None:
|
||||
# Omni: extract target image from unified sequence (cond_images + target)
|
||||
result = []
|
||||
for i in range(bsz):
|
||||
unified_x = x[i][x_pos_offsets[i][0] : x_pos_offsets[i][1]]
|
||||
cu_len = 0
|
||||
x_item = None
|
||||
for j in range(len(size[i])):
|
||||
if size[i][j] is None:
|
||||
ori_len = 0
|
||||
pad_len = SEQ_MULTI_OF
|
||||
cu_len += pad_len + ori_len
|
||||
else:
|
||||
F, H, W = size[i][j]
|
||||
ori_len = (F // pF) * (H // pH) * (W // pW)
|
||||
pad_len = (-ori_len) % SEQ_MULTI_OF
|
||||
x_item = (
|
||||
unified_x[cu_len : cu_len + ori_len]
|
||||
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
|
||||
.permute(6, 0, 3, 1, 4, 2, 5)
|
||||
.reshape(self.out_channels, F, H, W)
|
||||
)
|
||||
cu_len += ori_len + pad_len
|
||||
result.append(x_item) # Return only the last (target) image
|
||||
return result
|
||||
else:
|
||||
# Original mode: simple unpatchify
|
||||
for i in range(bsz):
|
||||
F, H, W = size[i]
|
||||
ori_len = (F // pF) * (H // pH) * (W // pW)
|
||||
# "f h w pf ph pw c -> c (f pf) (h ph) (w pw)"
|
||||
x[i] = (
|
||||
x[i][:ori_len]
|
||||
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
|
||||
.permute(6, 0, 3, 1, 4, 2, 5)
|
||||
.reshape(self.out_channels, F, H, W)
|
||||
)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def create_coordinate_grid(size, start=None, device=None):
|
||||
@@ -405,8 +508,8 @@ class ZImageDiT(nn.Module):
|
||||
self,
|
||||
all_image: List[torch.Tensor],
|
||||
all_cap_feats: List[torch.Tensor],
|
||||
patch_size: int,
|
||||
f_patch_size: int,
|
||||
patch_size: int = 2,
|
||||
f_patch_size: int = 1,
|
||||
):
|
||||
pH = pW = patch_size
|
||||
pF = f_patch_size
|
||||
@@ -490,90 +593,421 @@ class ZImageDiT(nn.Module):
|
||||
image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
|
||||
all_image_out.append(image_padded_feat)
|
||||
|
||||
return all_image_out, all_cap_feats_out, {
|
||||
"x_size": all_image_size,
|
||||
"x_pos_ids": all_image_pos_ids,
|
||||
"cap_pos_ids": all_cap_pos_ids,
|
||||
"x_pad_mask": all_image_pad_mask,
|
||||
"cap_pad_mask": all_cap_pad_mask
|
||||
}
|
||||
# (
|
||||
# all_img_out,
|
||||
# all_cap_out,
|
||||
# all_img_size,
|
||||
# all_img_pos_ids,
|
||||
# all_cap_pos_ids,
|
||||
# all_img_pad_mask,
|
||||
# all_cap_pad_mask,
|
||||
# )
|
||||
|
||||
def _prepare_sequence(
|
||||
self,
|
||||
feats: List[torch.Tensor],
|
||||
pos_ids: List[torch.Tensor],
|
||||
inner_pad_mask: List[torch.Tensor],
|
||||
pad_token: torch.nn.Parameter,
|
||||
noise_mask: Optional[List[List[int]]] = None,
|
||||
device: torch.device = None,
|
||||
):
|
||||
"""Prepare sequence: apply pad token, RoPE embed, pad to batch, create attention mask."""
|
||||
item_seqlens = [len(f) for f in feats]
|
||||
max_seqlen = max(item_seqlens)
|
||||
bsz = len(feats)
|
||||
|
||||
# Pad token
|
||||
feats_cat = torch.cat(feats, dim=0)
|
||||
feats_cat[torch.cat(inner_pad_mask)] = pad_token
|
||||
feats = list(feats_cat.split(item_seqlens, dim=0))
|
||||
|
||||
# RoPE
|
||||
freqs_cis = list(self.rope_embedder(torch.cat(pos_ids, dim=0)).split([len(p) for p in pos_ids], dim=0))
|
||||
|
||||
# Pad to batch
|
||||
feats = pad_sequence(feats, batch_first=True, padding_value=0.0)
|
||||
freqs_cis = pad_sequence(freqs_cis, batch_first=True, padding_value=0.0)[:, : feats.shape[1]]
|
||||
|
||||
# Attention mask
|
||||
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(item_seqlens):
|
||||
attn_mask[i, :seq_len] = 1
|
||||
|
||||
# Noise mask
|
||||
noise_mask_tensor = None
|
||||
if noise_mask is not None:
|
||||
noise_mask_tensor = pad_sequence(
|
||||
[torch.tensor(m, dtype=torch.long, device=device) for m in noise_mask],
|
||||
batch_first=True,
|
||||
padding_value=0,
|
||||
)[:, : feats.shape[1]]
|
||||
|
||||
return feats, freqs_cis, attn_mask, item_seqlens, noise_mask_tensor
|
||||
|
||||
def _build_unified_sequence(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_freqs: torch.Tensor,
|
||||
x_seqlens: List[int],
|
||||
x_noise_mask: Optional[List[List[int]]],
|
||||
cap: torch.Tensor,
|
||||
cap_freqs: torch.Tensor,
|
||||
cap_seqlens: List[int],
|
||||
cap_noise_mask: Optional[List[List[int]]],
|
||||
siglip: Optional[torch.Tensor],
|
||||
siglip_freqs: Optional[torch.Tensor],
|
||||
siglip_seqlens: Optional[List[int]],
|
||||
siglip_noise_mask: Optional[List[List[int]]],
|
||||
omni_mode: bool,
|
||||
device: torch.device,
|
||||
):
|
||||
"""Build unified sequence: x, cap, and optionally siglip.
|
||||
Basic mode order: [x, cap]; Omni mode order: [cap, x, siglip]
|
||||
"""
|
||||
bsz = len(x_seqlens)
|
||||
unified = []
|
||||
unified_freqs = []
|
||||
unified_noise_mask = []
|
||||
|
||||
for i in range(bsz):
|
||||
x_len, cap_len = x_seqlens[i], cap_seqlens[i]
|
||||
|
||||
if omni_mode:
|
||||
# Omni: [cap, x, siglip]
|
||||
if siglip is not None and siglip_seqlens is not None:
|
||||
sig_len = siglip_seqlens[i]
|
||||
unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len], siglip[i][:sig_len]]))
|
||||
unified_freqs.append(
|
||||
torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len], siglip_freqs[i][:sig_len]])
|
||||
)
|
||||
unified_noise_mask.append(
|
||||
torch.tensor(
|
||||
cap_noise_mask[i] + x_noise_mask[i] + siglip_noise_mask[i], dtype=torch.long, device=device
|
||||
)
|
||||
)
|
||||
else:
|
||||
unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len]]))
|
||||
unified_freqs.append(torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len]]))
|
||||
unified_noise_mask.append(
|
||||
torch.tensor(cap_noise_mask[i] + x_noise_mask[i], dtype=torch.long, device=device)
|
||||
)
|
||||
else:
|
||||
# Basic: [x, cap]
|
||||
unified.append(torch.cat([x[i][:x_len], cap[i][:cap_len]]))
|
||||
unified_freqs.append(torch.cat([x_freqs[i][:x_len], cap_freqs[i][:cap_len]]))
|
||||
|
||||
# Compute unified seqlens
|
||||
if omni_mode:
|
||||
if siglip is not None and siglip_seqlens is not None:
|
||||
unified_seqlens = [a + b + c for a, b, c in zip(cap_seqlens, x_seqlens, siglip_seqlens)]
|
||||
else:
|
||||
unified_seqlens = [a + b for a, b in zip(cap_seqlens, x_seqlens)]
|
||||
else:
|
||||
unified_seqlens = [a + b for a, b in zip(x_seqlens, cap_seqlens)]
|
||||
|
||||
max_seqlen = max(unified_seqlens)
|
||||
|
||||
# Pad to batch
|
||||
unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
|
||||
unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0)
|
||||
|
||||
# Attention mask
|
||||
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(unified_seqlens):
|
||||
attn_mask[i, :seq_len] = 1
|
||||
|
||||
# Noise mask
|
||||
noise_mask_tensor = None
|
||||
if omni_mode:
|
||||
noise_mask_tensor = pad_sequence(unified_noise_mask, batch_first=True, padding_value=0)[
|
||||
:, : unified.shape[1]
|
||||
]
|
||||
|
||||
return unified, unified_freqs, attn_mask, noise_mask_tensor
|
||||
|
||||
def _pad_with_ids(
|
||||
self,
|
||||
feat: torch.Tensor,
|
||||
pos_grid_size: Tuple,
|
||||
pos_start: Tuple,
|
||||
device: torch.device,
|
||||
noise_mask_val: Optional[int] = None,
|
||||
):
|
||||
"""Pad feature to SEQ_MULTI_OF, create position IDs and pad mask."""
|
||||
ori_len = len(feat)
|
||||
pad_len = (-ori_len) % SEQ_MULTI_OF
|
||||
total_len = ori_len + pad_len
|
||||
|
||||
# Pos IDs
|
||||
ori_pos_ids = self.create_coordinate_grid(size=pos_grid_size, start=pos_start, device=device).flatten(0, 2)
|
||||
if pad_len > 0:
|
||||
pad_pos_ids = (
|
||||
self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
|
||||
.flatten(0, 2)
|
||||
.repeat(pad_len, 1)
|
||||
)
|
||||
pos_ids = torch.cat([ori_pos_ids, pad_pos_ids], dim=0)
|
||||
padded_feat = torch.cat([feat, feat[-1:].repeat(pad_len, 1)], dim=0)
|
||||
pad_mask = torch.cat(
|
||||
[
|
||||
torch.zeros(ori_len, dtype=torch.bool, device=device),
|
||||
torch.ones(pad_len, dtype=torch.bool, device=device),
|
||||
]
|
||||
)
|
||||
else:
|
||||
pos_ids = ori_pos_ids
|
||||
padded_feat = feat
|
||||
pad_mask = torch.zeros(ori_len, dtype=torch.bool, device=device)
|
||||
|
||||
noise_mask = [noise_mask_val] * total_len if noise_mask_val is not None else None # token level
|
||||
return padded_feat, pos_ids, pad_mask, total_len, noise_mask
|
||||
|
||||
def _patchify_image(self, image: torch.Tensor, patch_size: int, f_patch_size: int):
|
||||
"""Patchify a single image tensor: (C, F, H, W) -> (num_patches, patch_dim)."""
|
||||
pH, pW, pF = patch_size, patch_size, f_patch_size
|
||||
C, F, H, W = image.size()
|
||||
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
|
||||
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
|
||||
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
|
||||
return image, (F, H, W), (F_tokens, H_tokens, W_tokens)
|
||||
|
||||
def patchify_and_embed_omni(
|
||||
self,
|
||||
all_x: List[List[torch.Tensor]],
|
||||
all_cap_feats: List[List[torch.Tensor]],
|
||||
all_siglip_feats: List[List[torch.Tensor]],
|
||||
patch_size: int = 2,
|
||||
f_patch_size: int = 1,
|
||||
images_noise_mask: List[List[int]] = None,
|
||||
):
|
||||
"""Patchify for omni mode: multiple images per batch item with noise masks."""
|
||||
bsz = len(all_x)
|
||||
device = all_x[0][-1].device
|
||||
dtype = all_x[0][-1].dtype
|
||||
|
||||
all_x_out, all_x_size, all_x_pos_ids, all_x_pad_mask, all_x_len, all_x_noise_mask = [], [], [], [], [], []
|
||||
all_cap_out, all_cap_pos_ids, all_cap_pad_mask, all_cap_len, all_cap_noise_mask = [], [], [], [], []
|
||||
all_sig_out, all_sig_pos_ids, all_sig_pad_mask, all_sig_len, all_sig_noise_mask = [], [], [], [], []
|
||||
|
||||
for i in range(bsz):
|
||||
num_images = len(all_x[i])
|
||||
cap_feats_list, cap_pos_list, cap_mask_list, cap_lens, cap_noise = [], [], [], [], []
|
||||
cap_end_pos = []
|
||||
cap_cu_len = 1
|
||||
|
||||
# Process captions
|
||||
for j, cap_item in enumerate(all_cap_feats[i]):
|
||||
noise_val = images_noise_mask[i][j] if j < len(images_noise_mask[i]) else 1
|
||||
cap_out, cap_pos, cap_mask, cap_len, cap_nm = self._pad_with_ids(
|
||||
cap_item,
|
||||
(len(cap_item) + (-len(cap_item)) % SEQ_MULTI_OF, 1, 1),
|
||||
(cap_cu_len, 0, 0),
|
||||
device,
|
||||
noise_val,
|
||||
)
|
||||
cap_feats_list.append(cap_out)
|
||||
cap_pos_list.append(cap_pos)
|
||||
cap_mask_list.append(cap_mask)
|
||||
cap_lens.append(cap_len)
|
||||
cap_noise.extend(cap_nm)
|
||||
cap_cu_len += len(cap_item)
|
||||
cap_end_pos.append(cap_cu_len)
|
||||
cap_cu_len += 2 # for image vae and siglip tokens
|
||||
|
||||
all_cap_out.append(torch.cat(cap_feats_list, dim=0))
|
||||
all_cap_pos_ids.append(torch.cat(cap_pos_list, dim=0))
|
||||
all_cap_pad_mask.append(torch.cat(cap_mask_list, dim=0))
|
||||
all_cap_len.append(cap_lens)
|
||||
all_cap_noise_mask.append(cap_noise)
|
||||
|
||||
# Process images
|
||||
x_feats_list, x_pos_list, x_mask_list, x_lens, x_size, x_noise = [], [], [], [], [], []
|
||||
for j, x_item in enumerate(all_x[i]):
|
||||
noise_val = images_noise_mask[i][j]
|
||||
if x_item is not None:
|
||||
x_patches, size, (F_t, H_t, W_t) = self._patchify_image(x_item, patch_size, f_patch_size)
|
||||
x_out, x_pos, x_mask, x_len, x_nm = self._pad_with_ids(
|
||||
x_patches, (F_t, H_t, W_t), (cap_end_pos[j], 0, 0), device, noise_val
|
||||
)
|
||||
x_size.append(size)
|
||||
else:
|
||||
x_len = SEQ_MULTI_OF
|
||||
x_out = torch.zeros((x_len, X_PAD_DIM), dtype=dtype, device=device)
|
||||
x_pos = self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(x_len, 1)
|
||||
x_mask = torch.ones(x_len, dtype=torch.bool, device=device)
|
||||
x_nm = [noise_val] * x_len
|
||||
x_size.append(None)
|
||||
x_feats_list.append(x_out)
|
||||
x_pos_list.append(x_pos)
|
||||
x_mask_list.append(x_mask)
|
||||
x_lens.append(x_len)
|
||||
x_noise.extend(x_nm)
|
||||
|
||||
all_x_out.append(torch.cat(x_feats_list, dim=0))
|
||||
all_x_pos_ids.append(torch.cat(x_pos_list, dim=0))
|
||||
all_x_pad_mask.append(torch.cat(x_mask_list, dim=0))
|
||||
all_x_size.append(x_size)
|
||||
all_x_len.append(x_lens)
|
||||
all_x_noise_mask.append(x_noise)
|
||||
|
||||
# Process siglip
|
||||
if all_siglip_feats[i] is None:
|
||||
all_sig_len.append([0] * num_images)
|
||||
all_sig_out.append(None)
|
||||
else:
|
||||
sig_feats_list, sig_pos_list, sig_mask_list, sig_lens, sig_noise = [], [], [], [], []
|
||||
for j, sig_item in enumerate(all_siglip_feats[i]):
|
||||
noise_val = images_noise_mask[i][j]
|
||||
if sig_item is not None:
|
||||
sig_H, sig_W, sig_C = sig_item.size()
|
||||
sig_flat = sig_item.permute(2, 0, 1).reshape(sig_H * sig_W, sig_C)
|
||||
sig_out, sig_pos, sig_mask, sig_len, sig_nm = self._pad_with_ids(
|
||||
sig_flat, (1, sig_H, sig_W), (cap_end_pos[j] + 1, 0, 0), device, noise_val
|
||||
)
|
||||
# Scale position IDs to match x resolution
|
||||
if x_size[j] is not None:
|
||||
sig_pos = sig_pos.float()
|
||||
sig_pos[..., 1] = sig_pos[..., 1] / max(sig_H - 1, 1) * (x_size[j][1] - 1)
|
||||
sig_pos[..., 2] = sig_pos[..., 2] / max(sig_W - 1, 1) * (x_size[j][2] - 1)
|
||||
sig_pos = sig_pos.to(torch.int32)
|
||||
else:
|
||||
sig_len = SEQ_MULTI_OF
|
||||
sig_out = torch.zeros((sig_len, self.siglip_feat_dim), dtype=dtype, device=device)
|
||||
sig_pos = (
|
||||
self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(sig_len, 1)
|
||||
)
|
||||
sig_mask = torch.ones(sig_len, dtype=torch.bool, device=device)
|
||||
sig_nm = [noise_val] * sig_len
|
||||
sig_feats_list.append(sig_out)
|
||||
sig_pos_list.append(sig_pos)
|
||||
sig_mask_list.append(sig_mask)
|
||||
sig_lens.append(sig_len)
|
||||
sig_noise.extend(sig_nm)
|
||||
|
||||
all_sig_out.append(torch.cat(sig_feats_list, dim=0))
|
||||
all_sig_pos_ids.append(torch.cat(sig_pos_list, dim=0))
|
||||
all_sig_pad_mask.append(torch.cat(sig_mask_list, dim=0))
|
||||
all_sig_len.append(sig_lens)
|
||||
all_sig_noise_mask.append(sig_noise)
|
||||
|
||||
# Compute x position offsets
|
||||
all_x_pos_offsets = [(sum(all_cap_len[i]), sum(all_cap_len[i]) + sum(all_x_len[i])) for i in range(bsz)]
|
||||
|
||||
return (
|
||||
all_image_out,
|
||||
all_cap_feats_out,
|
||||
all_image_size,
|
||||
all_image_pos_ids,
|
||||
all_x_out,
|
||||
all_cap_out,
|
||||
all_sig_out,
|
||||
all_x_size,
|
||||
all_x_pos_ids,
|
||||
all_cap_pos_ids,
|
||||
all_image_pad_mask,
|
||||
all_sig_pos_ids,
|
||||
all_x_pad_mask,
|
||||
all_cap_pad_mask,
|
||||
all_sig_pad_mask,
|
||||
all_x_pos_offsets,
|
||||
all_x_noise_mask,
|
||||
all_cap_noise_mask,
|
||||
all_sig_noise_mask,
|
||||
)
|
||||
return all_x_out, all_cap_out, all_sig_out, {
|
||||
"x_size": x_size,
|
||||
"x_pos_ids": all_x_pos_ids,
|
||||
"cap_pos_ids": all_cap_pos_ids,
|
||||
"sig_pos_ids": all_sig_pos_ids,
|
||||
"x_pad_mask": all_x_pad_mask,
|
||||
"cap_pad_mask": all_cap_pad_mask,
|
||||
"sig_pad_mask": all_sig_pad_mask,
|
||||
"x_pos_offsets": all_x_pos_offsets,
|
||||
"x_noise_mask": all_x_noise_mask,
|
||||
"cap_noise_mask": all_cap_noise_mask,
|
||||
"sig_noise_mask": all_sig_noise_mask,
|
||||
}
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: List[torch.Tensor],
|
||||
t,
|
||||
cap_feats: List[torch.Tensor],
|
||||
siglip_feats = None,
|
||||
image_noise_mask = None,
|
||||
patch_size=2,
|
||||
f_patch_size=1,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
):
|
||||
assert patch_size in self.all_patch_size
|
||||
assert f_patch_size in self.all_f_patch_size
|
||||
assert patch_size in self.all_patch_size and f_patch_size in self.all_f_patch_size
|
||||
omni_mode = isinstance(x[0], list)
|
||||
device = x[0][-1].device if omni_mode else x[0].device
|
||||
|
||||
bsz = len(x)
|
||||
device = x[0].device
|
||||
t = t * self.t_scale
|
||||
t = self.t_embedder(t)
|
||||
if omni_mode:
|
||||
# Dual embeddings: noisy (t) and clean (t=1)
|
||||
t_noisy = self.t_embedder(t * self.t_scale).type_as(x[0][-1])
|
||||
t_clean = self.t_embedder(torch.ones_like(t) * self.t_scale).type_as(x[0][-1])
|
||||
adaln_input = None
|
||||
else:
|
||||
# Single embedding for all tokens
|
||||
adaln_input = self.t_embedder(t * self.t_scale).type_as(x[0])
|
||||
t_noisy = t_clean = None
|
||||
|
||||
adaln_input = t
|
||||
|
||||
(
|
||||
x,
|
||||
cap_feats,
|
||||
x_size,
|
||||
x_pos_ids,
|
||||
cap_pos_ids,
|
||||
x_inner_pad_mask,
|
||||
cap_inner_pad_mask,
|
||||
) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
|
||||
# Patchify
|
||||
if omni_mode:
|
||||
(
|
||||
x,
|
||||
cap_feats,
|
||||
siglip_feats,
|
||||
x_size,
|
||||
x_pos_ids,
|
||||
cap_pos_ids,
|
||||
siglip_pos_ids,
|
||||
x_pad_mask,
|
||||
cap_pad_mask,
|
||||
siglip_pad_mask,
|
||||
x_pos_offsets,
|
||||
x_noise_mask,
|
||||
cap_noise_mask,
|
||||
siglip_noise_mask,
|
||||
) = self.patchify_and_embed_omni(x, cap_feats, siglip_feats, patch_size, f_patch_size, image_noise_mask)
|
||||
else:
|
||||
(
|
||||
x,
|
||||
cap_feats,
|
||||
x_size,
|
||||
x_pos_ids,
|
||||
cap_pos_ids,
|
||||
x_pad_mask,
|
||||
cap_pad_mask,
|
||||
) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
|
||||
x_pos_offsets = x_noise_mask = cap_noise_mask = siglip_noise_mask = None
|
||||
|
||||
# x embed & refine
|
||||
x_item_seqlens = [len(_) for _ in x]
|
||||
assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
|
||||
x_max_item_seqlen = max(x_item_seqlens)
|
||||
|
||||
x = torch.cat(x, dim=0)
|
||||
x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
|
||||
x[torch.cat(x_inner_pad_mask)] = self.x_pad_token.to(dtype=x.dtype, device=x.device)
|
||||
x = list(x.split(x_item_seqlens, dim=0))
|
||||
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
|
||||
|
||||
x = pad_sequence(x, batch_first=True, padding_value=0.0)
|
||||
x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(x_item_seqlens):
|
||||
x_attn_mask[i, :seq_len] = 1
|
||||
x_seqlens = [len(xi) for xi in x]
|
||||
x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](torch.cat(x, dim=0)) # embed
|
||||
x, x_freqs, x_mask, _, x_noise_tensor = self._prepare_sequence(
|
||||
list(x.split(x_seqlens, dim=0)), x_pos_ids, x_pad_mask, self.x_pad_token, x_noise_mask, device
|
||||
)
|
||||
|
||||
for layer in self.noise_refiner:
|
||||
x = gradient_checkpoint_forward(
|
||||
layer,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
x=x,
|
||||
attn_mask=x_attn_mask,
|
||||
freqs_cis=x_freqs_cis,
|
||||
adaln_input=adaln_input,
|
||||
x=x, attn_mask=x_mask, freqs_cis=x_freqs, adaln_input=adaln_input, noise_mask=x_noise_tensor, adaln_noisy=t_noisy, adaln_clean=t_clean,
|
||||
)
|
||||
|
||||
# cap embed & refine
|
||||
cap_item_seqlens = [len(_) for _ in cap_feats]
|
||||
assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens)
|
||||
cap_max_item_seqlen = max(cap_item_seqlens)
|
||||
|
||||
cap_feats = torch.cat(cap_feats, dim=0)
|
||||
cap_feats = self.cap_embedder(cap_feats)
|
||||
cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token.to(dtype=x.dtype, device=x.device)
|
||||
cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
|
||||
cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
|
||||
|
||||
cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
|
||||
cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(cap_item_seqlens):
|
||||
cap_attn_mask[i, :seq_len] = 1
|
||||
# Cap embed & refine
|
||||
cap_seqlens = [len(ci) for ci in cap_feats]
|
||||
cap_feats = self.cap_embedder(torch.cat(cap_feats, dim=0)) # embed
|
||||
cap_feats, cap_freqs, cap_mask, _, _ = self._prepare_sequence(
|
||||
list(cap_feats.split(cap_seqlens, dim=0)), cap_pos_ids, cap_pad_mask, self.cap_pad_token, None, device
|
||||
)
|
||||
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = gradient_checkpoint_forward(
|
||||
@@ -581,41 +1015,68 @@ class ZImageDiT(nn.Module):
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
x=cap_feats,
|
||||
attn_mask=cap_attn_mask,
|
||||
freqs_cis=cap_freqs_cis,
|
||||
attn_mask=cap_mask,
|
||||
freqs_cis=cap_freqs,
|
||||
)
|
||||
|
||||
# unified
|
||||
unified = []
|
||||
unified_freqs_cis = []
|
||||
for i in range(bsz):
|
||||
x_len = x_item_seqlens[i]
|
||||
cap_len = cap_item_seqlens[i]
|
||||
unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
|
||||
unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
|
||||
unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
|
||||
assert unified_item_seqlens == [len(_) for _ in unified]
|
||||
unified_max_item_seqlen = max(unified_item_seqlens)
|
||||
# Siglip embed & refine
|
||||
siglip_seqlens = siglip_freqs = None
|
||||
if omni_mode and siglip_feats[0] is not None and self.siglip_embedder is not None:
|
||||
siglip_seqlens = [len(si) for si in siglip_feats]
|
||||
siglip_feats = self.siglip_embedder(torch.cat(siglip_feats, dim=0)) # embed
|
||||
siglip_feats, siglip_freqs, siglip_mask, _, _ = self._prepare_sequence(
|
||||
list(siglip_feats.split(siglip_seqlens, dim=0)),
|
||||
siglip_pos_ids,
|
||||
siglip_pad_mask,
|
||||
self.siglip_pad_token,
|
||||
None,
|
||||
device,
|
||||
)
|
||||
|
||||
unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
|
||||
unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(unified_item_seqlens):
|
||||
unified_attn_mask[i, :seq_len] = 1
|
||||
for layer in self.siglip_refiner:
|
||||
siglip_feats = gradient_checkpoint_forward(
|
||||
layer,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
x=siglip_feats, attn_mask=siglip_mask, freqs_cis=siglip_freqs,
|
||||
)
|
||||
|
||||
for layer in self.layers:
|
||||
# Unified sequence
|
||||
unified, unified_freqs, unified_mask, unified_noise_tensor = self._build_unified_sequence(
|
||||
x,
|
||||
x_freqs,
|
||||
x_seqlens,
|
||||
x_noise_mask,
|
||||
cap_feats,
|
||||
cap_freqs,
|
||||
cap_seqlens,
|
||||
cap_noise_mask,
|
||||
siglip_feats,
|
||||
siglip_freqs,
|
||||
siglip_seqlens,
|
||||
siglip_noise_mask,
|
||||
omni_mode,
|
||||
device,
|
||||
)
|
||||
|
||||
# Main transformer layers
|
||||
for layer_idx, layer in enumerate(self.layers):
|
||||
unified = gradient_checkpoint_forward(
|
||||
layer,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
x=unified,
|
||||
attn_mask=unified_attn_mask,
|
||||
freqs_cis=unified_freqs_cis,
|
||||
adaln_input=adaln_input,
|
||||
x=unified, attn_mask=unified_mask, freqs_cis=unified_freqs, adaln_input=adaln_input, noise_mask=unified_noise_tensor, adaln_noisy=t_noisy, adaln_clean=t_clean
|
||||
)
|
||||
|
||||
unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
|
||||
unified = list(unified.unbind(dim=0))
|
||||
x = self.unpatchify(unified, x_size, patch_size, f_patch_size)
|
||||
unified = (
|
||||
self.all_final_layer[f"{patch_size}-{f_patch_size}"](
|
||||
unified, noise_mask=unified_noise_tensor, c_noisy=t_noisy, c_clean=t_clean
|
||||
)
|
||||
if omni_mode
|
||||
else self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, c=adaln_input)
|
||||
)
|
||||
|
||||
return x, {}
|
||||
# Unpatchify
|
||||
x = self.unpatchify(list(unified.unbind(dim=0)), x_size, patch_size, f_patch_size, x_pos_offsets)
|
||||
|
||||
return x
|
||||
|
||||
@@ -4,16 +4,18 @@ from typing import Union
|
||||
from tqdm import tqdm
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
from typing import Union, List, Optional, Tuple
|
||||
from typing import Union, List, Optional, Tuple, Iterable
|
||||
|
||||
from ..diffusion import FlowMatchScheduler
|
||||
from ..core import ModelConfig, gradient_checkpoint_forward
|
||||
from ..core.data.operators import ImageCropAndResize
|
||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
from ..models.z_image_text_encoder import ZImageTextEncoder
|
||||
from ..models.z_image_dit import ZImageDiT
|
||||
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
||||
from ..models.siglip2_image_encoder import Siglip2ImageEncoder428M
|
||||
|
||||
|
||||
class ZImagePipeline(BasePipeline):
|
||||
@@ -28,6 +30,7 @@ class ZImagePipeline(BasePipeline):
|
||||
self.dit: ZImageDiT = None
|
||||
self.vae_encoder: FluxVAEEncoder = None
|
||||
self.vae_decoder: FluxVAEDecoder = None
|
||||
self.image_encoder: Siglip2ImageEncoder428M = None
|
||||
self.tokenizer: AutoTokenizer = None
|
||||
self.in_iteration_models = ("dit",)
|
||||
self.units = [
|
||||
@@ -35,6 +38,9 @@ class ZImagePipeline(BasePipeline):
|
||||
ZImageUnit_PromptEmbedder(),
|
||||
ZImageUnit_NoiseInitializer(),
|
||||
ZImageUnit_InputImageEmbedder(),
|
||||
ZImageUnit_EditImageAutoResize(),
|
||||
ZImageUnit_EditImageEmbedderVAE(),
|
||||
ZImageUnit_EditImageEmbedderSiglip(),
|
||||
]
|
||||
self.model_fn = model_fn_z_image
|
||||
|
||||
@@ -56,6 +62,7 @@ class ZImagePipeline(BasePipeline):
|
||||
pipe.dit = model_pool.fetch_model("z_image_dit")
|
||||
pipe.vae_encoder = model_pool.fetch_model("flux_vae_encoder")
|
||||
pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder")
|
||||
pipe.image_encoder = model_pool.fetch_model("siglip_vision_model_428m")
|
||||
if tokenizer_config is not None:
|
||||
tokenizer_config.download_if_necessary()
|
||||
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
|
||||
@@ -75,6 +82,9 @@ class ZImagePipeline(BasePipeline):
|
||||
# Image
|
||||
input_image: Image.Image = None,
|
||||
denoising_strength: float = 1.0,
|
||||
# Edit
|
||||
edit_image: Image.Image = None,
|
||||
edit_image_auto_resize: bool = True,
|
||||
# Shape
|
||||
height: int = 1024,
|
||||
width: int = 1024,
|
||||
@@ -83,11 +93,12 @@ class ZImagePipeline(BasePipeline):
|
||||
rand_device: str = "cpu",
|
||||
# Steps
|
||||
num_inference_steps: int = 8,
|
||||
sigma_shift: float = None,
|
||||
# Progress bar
|
||||
progress_bar_cmd = tqdm,
|
||||
):
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
|
||||
|
||||
# Parameters
|
||||
inputs_posi = {
|
||||
@@ -102,6 +113,7 @@ class ZImagePipeline(BasePipeline):
|
||||
"height": height, "width": width,
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
@@ -143,12 +155,13 @@ class ZImageUnit_PromptEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params=("edit_image",),
|
||||
input_params_posi={"prompt": "prompt"},
|
||||
input_params_nega={"prompt": "negative_prompt"},
|
||||
output_params=("prompt_embeds",),
|
||||
onload_model_names=("text_encoder",)
|
||||
)
|
||||
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
pipe,
|
||||
@@ -194,10 +207,81 @@ class ZImageUnit_PromptEmbedder(PipelineUnit):
|
||||
embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
|
||||
|
||||
return embeddings_list
|
||||
|
||||
def encode_prompt_omni(
|
||||
self,
|
||||
pipe,
|
||||
prompt: Union[str, List[str]],
|
||||
edit_image=None,
|
||||
device: Optional[torch.device] = None,
|
||||
max_sequence_length: int = 512,
|
||||
) -> List[torch.FloatTensor]:
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
|
||||
def process(self, pipe: ZImagePipeline, prompt):
|
||||
if edit_image is None:
|
||||
num_condition_images = 0
|
||||
elif isinstance(edit_image, list):
|
||||
num_condition_images = len(edit_image)
|
||||
else:
|
||||
num_condition_images = 1
|
||||
|
||||
for i, prompt_item in enumerate(prompt):
|
||||
if num_condition_images == 0:
|
||||
prompt[i] = ["<|im_start|>user\n" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n"]
|
||||
elif num_condition_images > 0:
|
||||
prompt_list = ["<|im_start|>user\n<|vision_start|>"]
|
||||
prompt_list += ["<|vision_end|><|vision_start|>"] * (num_condition_images - 1)
|
||||
prompt_list += ["<|vision_end|>" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"]
|
||||
prompt_list += ["<|vision_end|><|im_end|>"]
|
||||
prompt[i] = prompt_list
|
||||
|
||||
flattened_prompt = []
|
||||
prompt_list_lengths = []
|
||||
|
||||
for i in range(len(prompt)):
|
||||
prompt_list_lengths.append(len(prompt[i]))
|
||||
flattened_prompt.extend(prompt[i])
|
||||
|
||||
text_inputs = pipe.tokenizer(
|
||||
flattened_prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids.to(device)
|
||||
prompt_masks = text_inputs.attention_mask.to(device).bool()
|
||||
|
||||
prompt_embeds = pipe.text_encoder(
|
||||
input_ids=text_input_ids,
|
||||
attention_mask=prompt_masks,
|
||||
output_hidden_states=True,
|
||||
).hidden_states[-2]
|
||||
|
||||
embeddings_list = []
|
||||
start_idx = 0
|
||||
for i in range(len(prompt_list_lengths)):
|
||||
batch_embeddings = []
|
||||
end_idx = start_idx + prompt_list_lengths[i]
|
||||
for j in range(start_idx, end_idx):
|
||||
batch_embeddings.append(prompt_embeds[j][prompt_masks[j]])
|
||||
embeddings_list.append(batch_embeddings)
|
||||
start_idx = end_idx
|
||||
|
||||
return embeddings_list
|
||||
|
||||
def process(self, pipe: ZImagePipeline, prompt, edit_image):
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device)
|
||||
if hasattr(pipe, "dit") and pipe.dit.siglip_embedder is not None:
|
||||
# Z-Image-Turbo and Z-Image-Omni-Base use different prompt encoding methods.
|
||||
# We determine which encoding method to use based on the model architecture.
|
||||
# If you are using two-stage split training,
|
||||
# please use `--offload_models` instead of skipping the DiT model loading.
|
||||
prompt_embeds = self.encode_prompt_omni(pipe, prompt, edit_image, pipe.device)
|
||||
else:
|
||||
prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device)
|
||||
return {"prompt_embeds": prompt_embeds}
|
||||
|
||||
|
||||
@@ -234,24 +318,197 @@ class ZImageUnit_InputImageEmbedder(PipelineUnit):
|
||||
return {"latents": latents, "input_latents": input_latents}
|
||||
|
||||
|
||||
class ZImageUnit_EditImageAutoResize(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("edit_image", "edit_image_auto_resize"),
|
||||
output_params=("edit_image",),
|
||||
)
|
||||
|
||||
def process(self, pipe: ZImagePipeline, edit_image, edit_image_auto_resize):
|
||||
if edit_image is None:
|
||||
return {}
|
||||
if edit_image_auto_resize is None or not edit_image_auto_resize:
|
||||
return {}
|
||||
operator = ImageCropAndResize(max_pixels=1024*1024, height_division_factor=16, width_division_factor=16)
|
||||
edit_image = operator(edit_image)
|
||||
return {"edit_image": edit_image}
|
||||
|
||||
|
||||
class ZImageUnit_EditImageEmbedderSiglip(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("edit_image",),
|
||||
output_params=("image_embeds",),
|
||||
onload_model_names=("image_encoder",)
|
||||
)
|
||||
|
||||
def process(self, pipe: ZImagePipeline, edit_image):
|
||||
if edit_image is None:
|
||||
return {}
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
if not isinstance(edit_image, list):
|
||||
edit_image = [edit_image]
|
||||
image_emb = []
|
||||
for image_ in edit_image:
|
||||
image_emb.append(pipe.image_encoder(image_, device=pipe.device))
|
||||
return {"image_embeds": image_emb}
|
||||
|
||||
|
||||
class ZImageUnit_EditImageEmbedderVAE(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("edit_image",),
|
||||
output_params=("image_latents",),
|
||||
onload_model_names=("vae_encoder",)
|
||||
)
|
||||
|
||||
def process(self, pipe: ZImagePipeline, edit_image):
|
||||
if edit_image is None:
|
||||
return {}
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
if not isinstance(edit_image, list):
|
||||
edit_image = [edit_image]
|
||||
image_latents = []
|
||||
for image_ in edit_image:
|
||||
image_ = pipe.preprocess_image(image_)
|
||||
image_latents.append(pipe.vae_encoder(image_))
|
||||
return {"image_latents": image_latents}
|
||||
|
||||
|
||||
def model_fn_z_image(
|
||||
dit: ZImageDiT,
|
||||
latents=None,
|
||||
timestep=None,
|
||||
prompt_embeds=None,
|
||||
image_embeds=None,
|
||||
image_latents=None,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
**kwargs,
|
||||
):
|
||||
# Due to the complex and verbose codebase of Z-Image,
|
||||
# we are temporarily using this inelegant structure.
|
||||
# We will refactor this part in the future (if time permits).
|
||||
if dit.siglip_embedder is None:
|
||||
return model_fn_z_image_turbo(
|
||||
dit,
|
||||
latents,
|
||||
timestep,
|
||||
prompt_embeds,
|
||||
image_embeds,
|
||||
image_latents,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
**kwargs,
|
||||
)
|
||||
latents = [rearrange(latents, "B C H W -> C B H W")]
|
||||
if dit.siglip_embedder is not None:
|
||||
if image_latents is not None:
|
||||
image_latents = [rearrange(image_latent, "B C H W -> C B H W") for image_latent in image_latents]
|
||||
latents = [image_latents + latents]
|
||||
image_noise_mask = [[0] * len(image_latents) + [1]]
|
||||
else:
|
||||
latents = [latents]
|
||||
image_noise_mask = [[1]]
|
||||
image_embeds = [image_embeds]
|
||||
else:
|
||||
image_noise_mask = None
|
||||
timestep = (1000 - timestep) / 1000
|
||||
model_output = dit(
|
||||
latents,
|
||||
timestep,
|
||||
prompt_embeds,
|
||||
siglip_feats=image_embeds,
|
||||
image_noise_mask=image_noise_mask,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)[0][0]
|
||||
)[0]
|
||||
model_output = -model_output
|
||||
model_output = rearrange(model_output, "C B H W -> B C H W")
|
||||
return model_output
|
||||
|
||||
|
||||
def model_fn_z_image_turbo(
|
||||
dit: ZImageDiT,
|
||||
latents=None,
|
||||
timestep=None,
|
||||
prompt_embeds=None,
|
||||
image_embeds=None,
|
||||
image_latents=None,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
**kwargs,
|
||||
):
|
||||
while isinstance(prompt_embeds, list):
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
while isinstance(latents, list):
|
||||
latents = latents[0]
|
||||
while isinstance(image_embeds, list):
|
||||
image_embeds = image_embeds[0]
|
||||
|
||||
# Timestep
|
||||
timestep = 1000 - timestep
|
||||
t_noisy = dit.t_embedder(timestep)
|
||||
t_clean = dit.t_embedder(torch.ones_like(timestep) * 1000)
|
||||
|
||||
# Patchify
|
||||
latents = rearrange(latents, "B C H W -> C B H W")
|
||||
x, cap_feats, patch_metadata = dit.patchify_and_embed([latents], [prompt_embeds])
|
||||
x = x[0]
|
||||
cap_feats = cap_feats[0]
|
||||
|
||||
# Noise refine
|
||||
x = dit.all_x_embedder["2-1"](x)
|
||||
x_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("x_pos_ids"), dim=0))
|
||||
x = rearrange(x, "L C -> 1 L C")
|
||||
x_freqs_cis = rearrange(x_freqs_cis, "L C -> 1 L C")
|
||||
|
||||
for layer in dit.noise_refiner:
|
||||
x = gradient_checkpoint_forward(
|
||||
layer,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
x=x,
|
||||
attn_mask=None,
|
||||
freqs_cis=x_freqs_cis,
|
||||
adaln_input=t_noisy,
|
||||
)
|
||||
|
||||
# Prompt refine
|
||||
cap_feats = dit.cap_embedder(cap_feats)
|
||||
cap_feats[torch.cat(patch_metadata.get("cap_pad_mask"))] = dit.cap_pad_token.to(dtype=x.dtype, device=x.device)
|
||||
cap_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("cap_pos_ids"), dim=0))
|
||||
cap_feats = rearrange(cap_feats, "L C -> 1 L C")
|
||||
cap_freqs_cis = rearrange(cap_freqs_cis, "L C -> 1 L C")
|
||||
|
||||
for layer in dit.context_refiner:
|
||||
cap_feats = gradient_checkpoint_forward(
|
||||
layer,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
x=cap_feats,
|
||||
attn_mask=None,
|
||||
freqs_cis=cap_freqs_cis,
|
||||
)
|
||||
|
||||
# Unified
|
||||
unified = torch.cat([x, cap_feats], dim=1)
|
||||
unified_freqs_cis = torch.cat([x_freqs_cis, cap_freqs_cis], dim=1)
|
||||
for layer in dit.layers:
|
||||
unified = gradient_checkpoint_forward(
|
||||
layer,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
x=unified,
|
||||
attn_mask=None,
|
||||
freqs_cis=unified_freqs_cis,
|
||||
adaln_input=t_noisy,
|
||||
)
|
||||
|
||||
# Output
|
||||
unified = dit.all_final_layer["2-1"](unified, t_noisy)
|
||||
x = dit.unpatchify([unified[0]], patch_metadata.get("x_size"))[0]
|
||||
x = rearrange(x, "C B H W -> B C H W")
|
||||
x = -x
|
||||
return x
|
||||
|
||||
23
examples/z_image/model_inference/Z-Image-Omni-Base.py
Normal file
23
examples/z_image/model_inference/Z-Image-Omni-Base.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
|
||||
pipe = ZImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors"),
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
|
||||
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4)
|
||||
image.save("image_Z-Image-Omni-Base.jpg")
|
||||
|
||||
image = Image.open("image_Z-Image-Omni-Base.jpg")
|
||||
prompt = "Change the women's clothes to white cheongsam, keep other content unchanged"
|
||||
image = pipe(prompt=prompt, edit_image=image, seed=42, rand_device="cuda", num_inference_steps=40, cfg_scale=4)
|
||||
image.save("image_edit_Z-Image-Omni-Base.jpg")
|
||||
Reference in New Issue
Block a user