diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index 08ec023..e6c7741 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -527,6 +527,19 @@ z_image_series = [ "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverterDiffusers", "extra_kwargs": {"use_conv_attention": False}, }, + { + # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors") + "model_hash": "aa3563718e5c3ecde3dfbb020ca61180", + "model_name": "z_image_dit", + "model_class": "diffsynth.models.z_image_dit.ZImageDiT", + "extra_kwargs": {"siglip_feat_dim": 1152}, + }, + { + # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors") + "model_hash": "89d48e420f45cff95115a9f3e698d44a", + "model_name": "siglip_vision_model_428m", + "model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M", + }, ] MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series
diff --git a/diffsynth/models/siglip2_image_encoder.py b/diffsynth/models/siglip2_image_encoder.py index 10184f8..76441d0 100644 --- a/diffsynth/models/siglip2_image_encoder.py +++ b/diffsynth/models/siglip2_image_encoder.py @@ -1,5 +1,5 @@ from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer, SiglipVisionConfig -from transformers import SiglipImageProcessor +from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast import torch @@ -68,3 +68,68 @@ class Siglip2ImageEncoder(SiglipVisionTransformer): pooler_output = self.head(last_hidden_state) if self.use_head else None return pooler_output + + +class Siglip2ImageEncoder428M(Siglip2VisionModel): + def __init__(self): + config = Siglip2VisionConfig( + attention_dropout=0.0, + dtype="bfloat16", + hidden_act="gelu_pytorch_tanh", + hidden_size=1152, + intermediate_size=4304, + layer_norm_eps=1e-06, + model_type="siglip2_vision_model", + num_attention_heads=16, + num_channels=3, + num_hidden_layers=27, + num_patches=256, + patch_size=16, + transformers_version="4.57.1" + ) + super().__init__(config) + self.processor = Siglip2ImageProcessorFast( + **{ + "crop_size": None, + "data_format": "channels_first", + "default_to_square": True, + "device": None, + "disable_grouping": None, + "do_center_crop": None, + "do_convert_rgb": None, + "do_normalize": True, + "do_pad": None, + "do_rescale": True, + "do_resize": True, + "image_mean": [0.5, 0.5, 0.5], + "image_processor_type": "Siglip2ImageProcessorFast", + "image_std": [0.5, 0.5, 0.5], + "input_data_format": None, + "max_num_patches": 256, + "pad_size": None, + "patch_size": 16, + "processor_class": "Siglip2Processor", + "resample": 2, + "rescale_factor": 0.00392156862745098, + "return_tensors": None, + "size": None + } + ) + + def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"): + # Encodes a single image (batch size 1); the processor reports the real patch grid via spatial_shapes. + siglip_inputs = self.processor(images=[image], return_tensors="pt").to(device) + shape = siglip_inputs.spatial_shapes[0] + hidden_state = super().forward(**siglip_inputs).last_hidden_state + _, _, C = hidden_state.shape + # Drop padded patch tokens, then fold the flat sequence into an (H, W, C) feature grid. + hidden_state = hidden_state[:, : shape[0] * shape[1]] + hidden_state = hidden_state.view(shape[0], shape[1], C) + hidden_state = hidden_state.to(torch_dtype) + return hidden_state
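Note (for review, not part of the patch): a minimal sketch of the new encoder's interface. The bare constructor yields randomly initialized weights; in the pipeline the state dict is fetched through the siglip_vision_model_428m config entry above. The input path below is a hypothetical placeholder.

# Hedged sketch: standalone use of Siglip2ImageEncoder428M, kept on CPU/float32 for simplicity.
import torch
from PIL import Image
from diffsynth.models.siglip2_image_encoder import Siglip2ImageEncoder428M

encoder = Siglip2ImageEncoder428M().eval()
image = Image.open("example.jpg").convert("RGB")  # hypothetical input image
with torch.no_grad():
    feats = encoder(image, torch_dtype=torch.bfloat16, device="cpu")
print(feats.shape)  # (H_patches, W_patches, 1152), with H_patches * W_patches <= max_num_patches = 256

diff --git a/diffsynth/models/z_image_dit.py b/diffsynth/models/z_image_dit.py index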
7664fc5..d0e392e 100644 --- a/diffsynth/models/z_image_dit.py +++ b/diffsynth/models/z_image_dit.py @@ -13,6 +13,7 @@ from ..core.gradient import gradient_checkpoint_forward ADALN_EMBED_DIM = 256 SEQ_MULTI_OF = 32 +X_PAD_DIM = 64 class TimestepEmbedder(nn.Module): @@ -86,7 +87,7 @@ class Attention(torch.nn.Module): self.norm_q = RMSNorm(head_dim, eps=1e-5) self.norm_k = RMSNorm(head_dim, eps=1e-5) - def forward(self, hidden_states, freqs_cis): + def forward(self, hidden_states, freqs_cis, attention_mask): query = self.to_q(hidden_states) key = self.to_k(hidden_states) value = self.to_v(hidden_states) @@ -123,6 +124,7 @@ class Attention(torch.nn.Module): key, value, q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d", + attn_mask=attention_mask, ) # Reshape back @@ -136,6 +138,20 @@ class Attention(torch.nn.Module): return output +def select_per_token( + value_noisy: torch.Tensor, + value_clean: torch.Tensor, + noise_mask: torch.Tensor, + seq_len: int, +) -> torch.Tensor: + noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1) + return torch.where( + noise_mask_expanded == 1, + value_noisy.unsqueeze(1).expand(-1, seq_len, -1), + value_clean.unsqueeze(1).expand(-1, seq_len, -1), + ) + + class ZImageTransformerBlock(nn.Module): def __init__( self, @@ -180,40 +196,53 @@ class ZImageTransformerBlock(nn.Module): attn_mask: torch.Tensor, freqs_cis: torch.Tensor, adaln_input: Optional[torch.Tensor] = None, + noise_mask: Optional[torch.Tensor] = None, + adaln_noisy: Optional[torch.Tensor] = None, + adaln_clean: Optional[torch.Tensor] = None, ): if self.modulation: - assert adaln_input is not None - scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2) - gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() - scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp + seq_len = x.shape[1] + + if noise_mask is not None: + # Per-token modulation: different modulation for noisy/clean tokens + mod_noisy = self.adaLN_modulation(adaln_noisy) + mod_clean = self.adaLN_modulation(adaln_clean) + + scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1) + scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1) + + gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh() + gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh() + + scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy + scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean + + scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len) + scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len) + gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len) + gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len) + else: + # Global modulation: same modulation for all tokens (avoid double select) + mod = self.adaLN_modulation(adaln_input) + scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2) + gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() + scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp # Attention block attn_out = self.attention( - self.attention_norm1(x) * scale_msa, - freqs_cis=freqs_cis, + self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis ) x = x + gate_msa * self.attention_norm2(attn_out) # FFN block - x = x + 
gate_mlp * self.ffn_norm2( - self.feed_forward( - self.ffn_norm1(x) * scale_mlp, - ) - ) + x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp)) else: # Attention block - attn_out = self.attention( - self.attention_norm1(x), - freqs_cis=freqs_cis, - ) + attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis) x = x + self.attention_norm2(attn_out) # FFN block - x = x + self.ffn_norm2( - self.feed_forward( - self.ffn_norm1(x), - ) - ) + x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x))) return x @@ -229,9 +258,21 @@ class FinalLayer(nn.Module): nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True), ) - def forward(self, x, c): - scale = 1.0 + self.adaLN_modulation(c) - x = self.norm_final(x) * scale.unsqueeze(1) + def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None): + seq_len = x.shape[1] + + if noise_mask is not None: + # Per-token modulation + scale_noisy = 1.0 + self.adaLN_modulation(c_noisy) + scale_clean = 1.0 + self.adaLN_modulation(c_clean) + scale = select_per_token(scale_noisy, scale_clean, noise_mask, seq_len) + else: + # Original global modulation + assert c is not None, "Either c or (c_noisy, c_clean) must be provided" + scale = 1.0 + self.adaLN_modulation(c) + scale = scale.unsqueeze(1) + + x = self.norm_final(x) * scale x = self.linear(x) return x @@ -299,6 +340,7 @@ class ZImageDiT(nn.Module): t_scale=1000.0, axes_dims=[32, 48, 48], axes_lens=[1024, 512, 512], + siglip_feat_dim=None, ) -> None: super().__init__() self.in_channels = in_channels @@ -359,6 +401,32 @@ class ZImageDiT(nn.Module): nn.Linear(cap_feat_dim, dim, bias=True), ) + # Optional SigLIP components (for Omni variant) + self.siglip_feat_dim = siglip_feat_dim + if siglip_feat_dim is not None: + self.siglip_embedder = nn.Sequential( + RMSNorm(siglip_feat_dim, eps=norm_eps), nn.Linear(siglip_feat_dim, dim, bias=True) + ) + self.siglip_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 2000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=False, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.siglip_pad_token = nn.Parameter(torch.empty((1, dim))) + else: + self.siglip_embedder = None + self.siglip_refiner = None + self.siglip_pad_token = None + self.x_pad_token = nn.Parameter(torch.empty((1, dim))) self.cap_pad_token = nn.Parameter(torch.empty((1, dim))) @@ -375,22 +443,57 @@ class ZImageDiT(nn.Module): self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) - def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]: + def unpatchify( + self, + x: List[torch.Tensor], + size: List[Tuple], + patch_size = 2, + f_patch_size = 1, + x_pos_offsets: Optional[List[Tuple[int, int]]] = None, + ) -> List[torch.Tensor]: pH = pW = patch_size pF = f_patch_size bsz = len(x) assert len(size) == bsz - for i in range(bsz): - F, H, W = size[i] - ori_len = (F // pF) * (H // pH) * (W // pW) - # "f h w pf ph pw c -> c (f pf) (h ph) (w pw)" - x[i] = ( - x[i][:ori_len] - .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels) - .permute(6, 0, 3, 1, 4, 2, 5) - .reshape(self.out_channels, F, H, W) - ) - return x + + if x_pos_offsets is not None: + # Omni: extract target image from unified sequence (cond_images + target) + result = [] + for i in range(bsz): + unified_x = x[i][x_pos_offsets[i][0] : x_pos_offsets[i][1]] + cu_len = 0 + x_item = None + for j in range(len(size[i])): + 
if size[i][j] is None: + ori_len = 0 + pad_len = SEQ_MULTI_OF + cu_len += pad_len + ori_len + else: + F, H, W = size[i][j] + ori_len = (F // pF) * (H // pH) * (W // pW) + pad_len = (-ori_len) % SEQ_MULTI_OF + x_item = ( + unified_x[cu_len : cu_len + ori_len] + .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels) + .permute(6, 0, 3, 1, 4, 2, 5) + .reshape(self.out_channels, F, H, W) + ) + cu_len += ori_len + pad_len + result.append(x_item) # x_item holds the last (target) image after the loop + return result + else: + # Original mode: simple unpatchify + for i in range(bsz): + F, H, W = size[i] + ori_len = (F // pF) * (H // pH) * (W // pW) + # "f h w pf ph pw c -> c (f pf) (h ph) (w pw)" + x[i] = ( + x[i][:ori_len] + .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels) + .permute(6, 0, 3, 1, 4, 2, 5) + .reshape(self.out_channels, F, H, W) + ) + return x @staticmethod def create_coordinate_grid(size, start=None, device=None): @@ -405,8 +508,8 @@ class ZImageDiT(nn.Module): self, all_image: List[torch.Tensor], all_cap_feats: List[torch.Tensor], - patch_size: int, - f_patch_size: int, + patch_size: int = 2, + f_patch_size: int = 1, ): pH = pW = patch_size pF = f_patch_size @@ -490,90 +593,421 @@ class ZImageDiT(nn.Module): image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) all_image_out.append(image_padded_feat) + return all_image_out, all_cap_feats_out, { + "x_size": all_image_size, + "x_pos_ids": all_image_pos_ids, + "cap_pos_ids": all_cap_pos_ids, + "x_pad_mask": all_image_pad_mask, + "cap_pad_mask": all_cap_pad_mask + } + + def _prepare_sequence( + self, + feats: List[torch.Tensor], + pos_ids: List[torch.Tensor], + inner_pad_mask: List[torch.Tensor], + pad_token: torch.nn.Parameter, + noise_mask: Optional[List[List[int]]] = None, + device: torch.device = None, + ): + """Prepare sequence: apply pad token, RoPE embed, pad to batch, create attention mask.""" + item_seqlens = [len(f) for f in feats] + max_seqlen = max(item_seqlens) + bsz = len(feats) + + # Pad token + feats_cat = torch.cat(feats, dim=0) + feats_cat[torch.cat(inner_pad_mask)] = pad_token + feats = list(feats_cat.split(item_seqlens, dim=0)) + + # RoPE + freqs_cis = list(self.rope_embedder(torch.cat(pos_ids, dim=0)).split([len(p) for p in pos_ids], dim=0)) + + # Pad to batch + feats = pad_sequence(feats, batch_first=True, padding_value=0.0) + freqs_cis = pad_sequence(freqs_cis, batch_first=True, padding_value=0.0)[:, : feats.shape[1]] + + # Attention mask + attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(item_seqlens): + attn_mask[i, :seq_len] = 1 + + # Noise mask + noise_mask_tensor = None + if noise_mask is not None: + noise_mask_tensor = pad_sequence( + [torch.tensor(m, dtype=torch.long, device=device) for m in noise_mask], + batch_first=True, + padding_value=0, + )[:, : feats.shape[1]] + + return feats, freqs_cis, attn_mask, item_seqlens, noise_mask_tensor + + def _build_unified_sequence( + self, + x: torch.Tensor, + x_freqs: torch.Tensor, + x_seqlens: List[int], + x_noise_mask: Optional[List[List[int]]], + cap: torch.Tensor, + cap_freqs: torch.Tensor, + cap_seqlens: List[int], + cap_noise_mask: Optional[List[List[int]]], + siglip: Optional[torch.Tensor], + siglip_freqs: Optional[torch.Tensor], + siglip_seqlens: Optional[List[int]], + siglip_noise_mask: Optional[List[List[int]]],
+ omni_mode: bool, + device: torch.device, + ): + """Build unified sequence: x, cap, and optionally siglip. + Basic mode order: [x, cap]; Omni mode order: [cap, x, siglip] + """ + bsz = len(x_seqlens) + unified = [] + unified_freqs = [] + unified_noise_mask = [] + + for i in range(bsz): + x_len, cap_len = x_seqlens[i], cap_seqlens[i] + + if omni_mode: + # Omni: [cap, x, siglip] + if siglip is not None and siglip_seqlens is not None: + sig_len = siglip_seqlens[i] + unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len], siglip[i][:sig_len]])) + unified_freqs.append( + torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len], siglip_freqs[i][:sig_len]]) + ) + unified_noise_mask.append( + torch.tensor( + cap_noise_mask[i] + x_noise_mask[i] + siglip_noise_mask[i], dtype=torch.long, device=device + ) + ) + else: + unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len]])) + unified_freqs.append(torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len]])) + unified_noise_mask.append( + torch.tensor(cap_noise_mask[i] + x_noise_mask[i], dtype=torch.long, device=device) + ) + else: + # Basic: [x, cap] + unified.append(torch.cat([x[i][:x_len], cap[i][:cap_len]])) + unified_freqs.append(torch.cat([x_freqs[i][:x_len], cap_freqs[i][:cap_len]])) + + # Compute unified seqlens + if omni_mode: + if siglip is not None and siglip_seqlens is not None: + unified_seqlens = [a + b + c for a, b, c in zip(cap_seqlens, x_seqlens, siglip_seqlens)] + else: + unified_seqlens = [a + b for a, b in zip(cap_seqlens, x_seqlens)] + else: + unified_seqlens = [a + b for a, b in zip(x_seqlens, cap_seqlens)] + + max_seqlen = max(unified_seqlens) + + # Pad to batch + unified = pad_sequence(unified, batch_first=True, padding_value=0.0) + unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0) + + # Attention mask + attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(unified_seqlens): + attn_mask[i, :seq_len] = 1 + + # Noise mask + noise_mask_tensor = None + if omni_mode: + noise_mask_tensor = pad_sequence(unified_noise_mask, batch_first=True, padding_value=0)[ + :, : unified.shape[1] + ] + + return unified, unified_freqs, attn_mask, noise_mask_tensor + + def _pad_with_ids( + self, + feat: torch.Tensor, + pos_grid_size: Tuple, + pos_start: Tuple, + device: torch.device, + noise_mask_val: Optional[int] = None, + ): + """Pad feature to SEQ_MULTI_OF, create position IDs and pad mask.""" + ori_len = len(feat) + pad_len = (-ori_len) % SEQ_MULTI_OF + total_len = ori_len + pad_len + + # Pos IDs + ori_pos_ids = self.create_coordinate_grid(size=pos_grid_size, start=pos_start, device=device).flatten(0, 2) + if pad_len > 0: + pad_pos_ids = ( + self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) + .flatten(0, 2) + .repeat(pad_len, 1) + ) + pos_ids = torch.cat([ori_pos_ids, pad_pos_ids], dim=0) + padded_feat = torch.cat([feat, feat[-1:].repeat(pad_len, 1)], dim=0) + pad_mask = torch.cat( + [ + torch.zeros(ori_len, dtype=torch.bool, device=device), + torch.ones(pad_len, dtype=torch.bool, device=device), + ] + ) + else: + pos_ids = ori_pos_ids + padded_feat = feat + pad_mask = torch.zeros(ori_len, dtype=torch.bool, device=device) + + noise_mask = [noise_mask_val] * total_len if noise_mask_val is not None else None # token level + return padded_feat, pos_ids, pad_mask, total_len, noise_mask + + def _patchify_image(self, image: torch.Tensor, patch_size: int, f_patch_size: int): + """Patchify a single image tensor: (C, F, H, W) -> (num_patches, 
patch_dim).""" + pH, pW, pF = patch_size, patch_size, f_patch_size + C, F, H, W = image.size() + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + return image, (F, H, W), (F_tokens, H_tokens, W_tokens) + + def patchify_and_embed_omni( + self, + all_x: List[List[torch.Tensor]], + all_cap_feats: List[List[torch.Tensor]], + all_siglip_feats: List[List[torch.Tensor]], + patch_size: int = 2, + f_patch_size: int = 1, + images_noise_mask: List[List[int]] = None, + ): + """Patchify for omni mode: multiple images per batch item with noise masks.""" + bsz = len(all_x) + device = all_x[0][-1].device + dtype = all_x[0][-1].dtype + + all_x_out, all_x_size, all_x_pos_ids, all_x_pad_mask, all_x_len, all_x_noise_mask = [], [], [], [], [], [] + all_cap_out, all_cap_pos_ids, all_cap_pad_mask, all_cap_len, all_cap_noise_mask = [], [], [], [], [] + all_sig_out, all_sig_pos_ids, all_sig_pad_mask, all_sig_len, all_sig_noise_mask = [], [], [], [], [] + + for i in range(bsz): + num_images = len(all_x[i]) + cap_feats_list, cap_pos_list, cap_mask_list, cap_lens, cap_noise = [], [], [], [], [] + cap_end_pos = [] + cap_cu_len = 1 + + # Process captions + for j, cap_item in enumerate(all_cap_feats[i]): + noise_val = images_noise_mask[i][j] if j < len(images_noise_mask[i]) else 1 + cap_out, cap_pos, cap_mask, cap_len, cap_nm = self._pad_with_ids( + cap_item, + (len(cap_item) + (-len(cap_item)) % SEQ_MULTI_OF, 1, 1), + (cap_cu_len, 0, 0), + device, + noise_val, + ) + cap_feats_list.append(cap_out) + cap_pos_list.append(cap_pos) + cap_mask_list.append(cap_mask) + cap_lens.append(cap_len) + cap_noise.extend(cap_nm) + cap_cu_len += len(cap_item) + cap_end_pos.append(cap_cu_len) + cap_cu_len += 2 # for image vae and siglip tokens + + all_cap_out.append(torch.cat(cap_feats_list, dim=0)) + all_cap_pos_ids.append(torch.cat(cap_pos_list, dim=0)) + all_cap_pad_mask.append(torch.cat(cap_mask_list, dim=0)) + all_cap_len.append(cap_lens) + all_cap_noise_mask.append(cap_noise) + + # Process images + x_feats_list, x_pos_list, x_mask_list, x_lens, x_size, x_noise = [], [], [], [], [], [] + for j, x_item in enumerate(all_x[i]): + noise_val = images_noise_mask[i][j] + if x_item is not None: + x_patches, size, (F_t, H_t, W_t) = self._patchify_image(x_item, patch_size, f_patch_size) + x_out, x_pos, x_mask, x_len, x_nm = self._pad_with_ids( + x_patches, (F_t, H_t, W_t), (cap_end_pos[j], 0, 0), device, noise_val + ) + x_size.append(size) + else: + x_len = SEQ_MULTI_OF + x_out = torch.zeros((x_len, X_PAD_DIM), dtype=dtype, device=device) + x_pos = self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(x_len, 1) + x_mask = torch.ones(x_len, dtype=torch.bool, device=device) + x_nm = [noise_val] * x_len + x_size.append(None) + x_feats_list.append(x_out) + x_pos_list.append(x_pos) + x_mask_list.append(x_mask) + x_lens.append(x_len) + x_noise.extend(x_nm) + + all_x_out.append(torch.cat(x_feats_list, dim=0)) + all_x_pos_ids.append(torch.cat(x_pos_list, dim=0)) + all_x_pad_mask.append(torch.cat(x_mask_list, dim=0)) + all_x_size.append(x_size) + all_x_len.append(x_lens) + all_x_noise_mask.append(x_noise) + + # Process siglip + if all_siglip_feats[i] is None: + all_sig_len.append([0] * num_images) + all_sig_out.append(None) + else: + sig_feats_list, sig_pos_list, sig_mask_list, sig_lens, sig_noise = [], [], [], [], [] + for j, sig_item in 
enumerate(all_siglip_feats[i]): + noise_val = images_noise_mask[i][j] + if sig_item is not None: + sig_H, sig_W, sig_C = sig_item.size() + sig_flat = sig_item.permute(2, 0, 1).reshape(sig_H * sig_W, sig_C) + sig_out, sig_pos, sig_mask, sig_len, sig_nm = self._pad_with_ids( + sig_flat, (1, sig_H, sig_W), (cap_end_pos[j] + 1, 0, 0), device, noise_val + ) + # Scale position IDs to match x resolution + if x_size[j] is not None: + sig_pos = sig_pos.float() + sig_pos[..., 1] = sig_pos[..., 1] / max(sig_H - 1, 1) * (x_size[j][1] - 1) + sig_pos[..., 2] = sig_pos[..., 2] / max(sig_W - 1, 1) * (x_size[j][2] - 1) + sig_pos = sig_pos.to(torch.int32) + else: + sig_len = SEQ_MULTI_OF + sig_out = torch.zeros((sig_len, self.siglip_feat_dim), dtype=dtype, device=device) + sig_pos = ( + self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(sig_len, 1) + ) + sig_mask = torch.ones(sig_len, dtype=torch.bool, device=device) + sig_nm = [noise_val] * sig_len + sig_feats_list.append(sig_out) + sig_pos_list.append(sig_pos) + sig_mask_list.append(sig_mask) + sig_lens.append(sig_len) + sig_noise.extend(sig_nm) + + all_sig_out.append(torch.cat(sig_feats_list, dim=0)) + all_sig_pos_ids.append(torch.cat(sig_pos_list, dim=0)) + all_sig_pad_mask.append(torch.cat(sig_mask_list, dim=0)) + all_sig_len.append(sig_lens) + all_sig_noise_mask.append(sig_noise) + + # Compute x position offsets + all_x_pos_offsets = [(sum(all_cap_len[i]), sum(all_cap_len[i]) + sum(all_x_len[i])) for i in range(bsz)] + return ( - all_image_out, - all_cap_feats_out, - all_image_size, - all_image_pos_ids, + all_x_out, + all_cap_out, + all_sig_out, + all_x_size, + all_x_pos_ids, all_cap_pos_ids, - all_image_pad_mask, + all_sig_pos_ids, + all_x_pad_mask, all_cap_pad_mask, + all_sig_pad_mask, + all_x_pos_offsets, + all_x_noise_mask, + all_cap_noise_mask, + all_sig_noise_mask, ) def forward( self, x: List[torch.Tensor], t, cap_feats: List[torch.Tensor], + siglip_feats=None, + image_noise_mask=None, patch_size=2, f_patch_size=1, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False, ): - assert patch_size in self.all_patch_size - assert f_patch_size in self.all_f_patch_size + assert patch_size in self.all_patch_size and f_patch_size in self.all_f_patch_size + omni_mode = isinstance(x[0], list) + device = x[0][-1].device if omni_mode else x[0].device - bsz = len(x) - device = x[0].device - t = t * self.t_scale - t = self.t_embedder(t) + if omni_mode: + # Dual embeddings: noisy (t) and clean (t=1) + t_noisy = self.t_embedder(t * self.t_scale).type_as(x[0][-1]) + t_clean = self.t_embedder(torch.ones_like(t) * self.t_scale).type_as(x[0][-1]) + adaln_input = None + else: + # Single embedding for all tokens + adaln_input = self.t_embedder(t * self.t_scale).type_as(x[0]) + t_noisy = t_clean = None - adaln_input = t - - ( - x, - cap_feats, - x_size, - x_pos_ids, - cap_pos_ids, - x_inner_pad_mask, - cap_inner_pad_mask, - ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
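+ # Input conventions: in basic mode `x` is a List[Tensor] with one latent per sample; in omni mode `x` is a + # List[List[Tensor]] holding the condition-image latents followed by the target latent, and `image_noise_mask` + # marks each image as clean condition (0) or noisy target (1) for the per-token modulation in the blocks. + # Patchify + if omni_mode: + ( + x, + cap_feats, + siglip_feats, + x_size, + x_pos_ids, + cap_pos_ids,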
siglip_pos_ids, + x_pad_mask, + cap_pad_mask, + siglip_pad_mask, + x_pos_offsets, + x_noise_mask, + cap_noise_mask, + siglip_noise_mask, + ) = self.patchify_and_embed_omni(x, cap_feats, siglip_feats, patch_size, f_patch_size, image_noise_mask) + else: + ( + x, + cap_feats, + x_size, + x_pos_ids, + cap_pos_ids, + x_pad_mask, + cap_pad_mask, + ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size) + x_pos_offsets = x_noise_mask = cap_noise_mask = siglip_noise_mask = None # x embed & refine - x_item_seqlens = [len(_) for _ in x] - assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens) - x_max_item_seqlen = max(x_item_seqlens) - - x = torch.cat(x, dim=0) - x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x) - x[torch.cat(x_inner_pad_mask)] = self.x_pad_token.to(dtype=x.dtype, device=x.device) - x = list(x.split(x_item_seqlens, dim=0)) - x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)) - - x = pad_sequence(x, batch_first=True, padding_value=0.0) - x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) - x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) - for i, seq_len in enumerate(x_item_seqlens): - x_attn_mask[i, :seq_len] = 1 + x_seqlens = [len(xi) for xi in x] + x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](torch.cat(x, dim=0)) # embed + x, x_freqs, x_mask, _, x_noise_tensor = self._prepare_sequence( + list(x.split(x_seqlens, dim=0)), x_pos_ids, x_pad_mask, self.x_pad_token, x_noise_mask, device + ) for layer in self.noise_refiner: x = gradient_checkpoint_forward( layer, use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, - x=x, - attn_mask=x_attn_mask, - freqs_cis=x_freqs_cis, - adaln_input=adaln_input, + x=x, attn_mask=x_mask, freqs_cis=x_freqs, adaln_input=adaln_input, noise_mask=x_noise_tensor, adaln_noisy=t_noisy, adaln_clean=t_clean, ) - # cap embed & refine - cap_item_seqlens = [len(_) for _ in cap_feats] - assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens) - cap_max_item_seqlen = max(cap_item_seqlens) - - cap_feats = torch.cat(cap_feats, dim=0) - cap_feats = self.cap_embedder(cap_feats) - cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token.to(dtype=x.dtype, device=x.device) - cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0)) - cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0)) - - cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0) - cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0) - cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device) - for i, seq_len in enumerate(cap_item_seqlens): - cap_attn_mask[i, :seq_len] = 1 + # Cap embed & refine + cap_seqlens = [len(ci) for ci in cap_feats] + cap_feats = self.cap_embedder(torch.cat(cap_feats, dim=0)) # embed + cap_feats, cap_freqs, cap_mask, _, _ = self._prepare_sequence( + list(cap_feats.split(cap_seqlens, dim=0)), cap_pos_ids, cap_pad_mask, self.cap_pad_token, None, device + ) for layer in self.context_refiner: cap_feats = gradient_checkpoint_forward( @@ -581,41 +1015,68 @@ class ZImageDiT(nn.Module): use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, x=cap_feats, - attn_mask=cap_attn_mask, - freqs_cis=cap_freqs_cis, + attn_mask=cap_mask, + freqs_cis=cap_freqs, ) - # unified - unified = [] 
- unified_freqs_cis = [] - for i in range(bsz): - x_len = x_item_seqlens[i] - cap_len = cap_item_seqlens[i] - unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]])) - unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]])) - unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)] - assert unified_item_seqlens == [len(_) for _ in unified] - unified_max_item_seqlen = max(unified_item_seqlens) + # Siglip embed & refine + siglip_seqlens = siglip_freqs = None + if omni_mode and siglip_feats[0] is not None and self.siglip_embedder is not None: + siglip_seqlens = [len(si) for si in siglip_feats] + siglip_feats = self.siglip_embedder(torch.cat(siglip_feats, dim=0)) # embed + siglip_feats, siglip_freqs, siglip_mask, _, _ = self._prepare_sequence( + list(siglip_feats.split(siglip_seqlens, dim=0)), + siglip_pos_ids, + siglip_pad_mask, + self.siglip_pad_token, + None, + device, + ) - unified = pad_sequence(unified, batch_first=True, padding_value=0.0) - unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0) - unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device) - for i, seq_len in enumerate(unified_item_seqlens): - unified_attn_mask[i, :seq_len] = 1 + for layer in self.siglip_refiner: + siglip_feats = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + x=siglip_feats, attn_mask=siglip_mask, freqs_cis=siglip_freqs, + ) - for layer in self.layers: + # Unified sequence + unified, unified_freqs, unified_mask, unified_noise_tensor = self._build_unified_sequence( + x, + x_freqs, + x_seqlens, + x_noise_mask, + cap_feats, + cap_freqs, + cap_seqlens, + cap_noise_mask, + siglip_feats, + siglip_freqs, + siglip_seqlens, + siglip_noise_mask, + omni_mode, + device, + ) + + # Main transformer layers + for layer_idx, layer in enumerate(self.layers): unified = gradient_checkpoint_forward( layer, use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, - x=unified, - attn_mask=unified_attn_mask, - freqs_cis=unified_freqs_cis, - adaln_input=adaln_input, + x=unified, attn_mask=unified_mask, freqs_cis=unified_freqs, adaln_input=adaln_input, noise_mask=unified_noise_tensor, adaln_noisy=t_noisy, adaln_clean=t_clean ) - unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input) - unified = list(unified.unbind(dim=0)) - x = self.unpatchify(unified, x_size, patch_size, f_patch_size) + unified = ( + self.all_final_layer[f"{patch_size}-{f_patch_size}"]( + unified, noise_mask=unified_noise_tensor, c_noisy=t_noisy, c_clean=t_clean + ) + if omni_mode + else self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, c=adaln_input) + ) - return x, {} + # Unpatchify + x = self.unpatchify(list(unified.unbind(dim=0)), x_size, patch_size, f_patch_size, x_pos_offsets) + + return x diff --git a/diffsynth/pipelines/z_image.py b/diffsynth/pipelines/z_image.py index f87254f..d119cbf 100644 --- a/diffsynth/pipelines/z_image.py +++ b/diffsynth/pipelines/z_image.py @@ -4,16 +4,18 @@ from typing import Union from tqdm import tqdm from einops import rearrange import numpy as np -from typing import Union, List, Optional, Tuple +from typing import Union, List, Optional, Tuple, Iterable from ..diffusion import FlowMatchScheduler from ..core import ModelConfig, gradient_checkpoint_forward +from 
..core.data.operators import ImageCropAndResize from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput from transformers import AutoTokenizer from ..models.z_image_text_encoder import ZImageTextEncoder from ..models.z_image_dit import ZImageDiT from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder +from ..models.siglip2_image_encoder import Siglip2ImageEncoder428M class ZImagePipeline(BasePipeline): @@ -28,6 +30,7 @@ class ZImagePipeline(BasePipeline): self.dit: ZImageDiT = None self.vae_encoder: FluxVAEEncoder = None self.vae_decoder: FluxVAEDecoder = None + self.image_encoder: Siglip2ImageEncoder428M = None self.tokenizer: AutoTokenizer = None self.in_iteration_models = ("dit",) self.units = [ @@ -35,6 +38,9 @@ class ZImagePipeline(BasePipeline): ZImageUnit_PromptEmbedder(), ZImageUnit_NoiseInitializer(), ZImageUnit_InputImageEmbedder(), + ZImageUnit_EditImageAutoResize(), + ZImageUnit_EditImageEmbedderVAE(), + ZImageUnit_EditImageEmbedderSiglip(), ] self.model_fn = model_fn_z_image @@ -56,6 +62,7 @@ class ZImagePipeline(BasePipeline): pipe.dit = model_pool.fetch_model("z_image_dit") pipe.vae_encoder = model_pool.fetch_model("flux_vae_encoder") pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder") + pipe.image_encoder = model_pool.fetch_model("siglip_vision_model_428m") if tokenizer_config is not None: tokenizer_config.download_if_necessary() pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path) @@ -75,6 +82,9 @@ class ZImagePipeline(BasePipeline): # Image input_image: Image.Image = None, denoising_strength: float = 1.0, + # Edit + edit_image: Image.Image = None, + edit_image_auto_resize: bool = True, # Shape height: int = 1024, width: int = 1024, @@ -83,11 +93,12 @@ class ZImagePipeline(BasePipeline): rand_device: str = "cpu", # Steps num_inference_steps: int = 8, + sigma_shift: float = None, # Progress bar progress_bar_cmd = tqdm, ): # Scheduler - self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength) + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) # Parameters inputs_posi = { @@ -102,6 +113,7 @@ class ZImagePipeline(BasePipeline): "height": height, "width": width, "seed": seed, "rand_device": rand_device, "num_inference_steps": num_inference_steps, + "edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize, } for unit in self.units: inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) @@ -143,12 +155,13 @@ class ZImageUnit_PromptEmbedder(PipelineUnit): def __init__(self): super().__init__( seperate_cfg=True, + input_params=("edit_image",), input_params_posi={"prompt": "prompt"}, input_params_nega={"prompt": "negative_prompt"}, output_params=("prompt_embeds",), onload_model_names=("text_encoder",) ) - + def encode_prompt( self, pipe, @@ -194,10 +207,81 @@ class ZImageUnit_PromptEmbedder(PipelineUnit): embeddings_list.append(prompt_embeds[i][prompt_masks[i]]) return embeddings_list + + def encode_prompt_omni( + self, + pipe, + prompt: Union[str, List[str]], + edit_image=None, + device: Optional[torch.device] = None, + max_sequence_length: int = 512, + ) -> List[torch.FloatTensor]: + if isinstance(prompt, str): + prompt = [prompt] - def process(self, pipe: ZImagePipeline, prompt): + if edit_image is None: + num_condition_images = 0 + elif isinstance(edit_image, list): + num_condition_images = len(edit_image) + else: + num_condition_images = 1 + + for i, 
prompt_item in enumerate(prompt): + if num_condition_images == 0: + prompt[i] = ["<|im_start|>user\n" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n"] + else: + prompt_list = ["<|im_start|>user\n<|vision_start|>"] + prompt_list += ["<|vision_end|><|vision_start|>"] * (num_condition_images - 1) + prompt_list += ["<|vision_end|>" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"] + prompt_list += ["<|vision_end|><|im_end|>"] + prompt[i] = prompt_list + + flattened_prompt = [] + prompt_list_lengths = [] + + for i in range(len(prompt)): + prompt_list_lengths.append(len(prompt[i])) + flattened_prompt.extend(prompt[i]) + + text_inputs = pipe.tokenizer( + flattened_prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = pipe.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + start_idx = 0 + for i in range(len(prompt_list_lengths)): + batch_embeddings = [] + end_idx = start_idx + prompt_list_lengths[i] + for j in range(start_idx, end_idx): + batch_embeddings.append(prompt_embeds[j][prompt_masks[j]]) + embeddings_list.append(batch_embeddings) + start_idx = end_idx + + return embeddings_list + + def process(self, pipe: ZImagePipeline, prompt, edit_image): pipe.load_models_to_device(self.onload_model_names) - prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device) + if pipe.dit is not None and pipe.dit.siglip_embedder is not None: + # Z-Image-Turbo and Z-Image-Omni-Base use different prompt encoding methods. + # We determine which encoding method to use based on the model architecture. + # If you are using two-stage split training, + # please use `--offload_models` instead of skipping the DiT model loading.
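+ # With one edit image, each prompt is rendered into three segments (one per caption slot), e.g.: + # ['<|im_start|>user\n<|vision_start|>', '<|vision_end|>' + prompt + '<|im_end|>\n<|im_start|>assistant\n<|vision_start|>', '<|vision_end|><|im_end|>'] + # Each segment is embedded separately so VAE and SigLIP image tokens can be interleaved between them.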
+ prompt_embeds = self.encode_prompt_omni(pipe, prompt, edit_image, pipe.device) + else: + prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device) return {"prompt_embeds": prompt_embeds} @@ -234,24 +318,197 @@ class ZImageUnit_InputImageEmbedder(PipelineUnit): return {"latents": latents, "input_latents": input_latents} +class ZImageUnit_EditImageAutoResize(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("edit_image", "edit_image_auto_resize"), + output_params=("edit_image",), + ) + + def process(self, pipe: ZImagePipeline, edit_image, edit_image_auto_resize): + # `not edit_image_auto_resize` also covers None; resize each image when a list is passed. + if edit_image is None or not edit_image_auto_resize: + return {} + operator = ImageCropAndResize(max_pixels=1024*1024, height_division_factor=16, width_division_factor=16) + if isinstance(edit_image, list): + edit_image = [operator(image_) for image_ in edit_image] + else: + edit_image = operator(edit_image) + return {"edit_image": edit_image} + + +class ZImageUnit_EditImageEmbedderSiglip(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("edit_image",), + output_params=("image_embeds",), + onload_model_names=("image_encoder",) + ) + + def process(self, pipe: ZImagePipeline, edit_image): + if edit_image is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + if not isinstance(edit_image, list): + edit_image = [edit_image] + image_emb = [] + for image_ in edit_image: + image_emb.append(pipe.image_encoder(image_, device=pipe.device)) + return {"image_embeds": image_emb} + + +class ZImageUnit_EditImageEmbedderVAE(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("edit_image",), + output_params=("image_latents",), + onload_model_names=("vae_encoder",) + ) + + def process(self, pipe: ZImagePipeline, edit_image): + if edit_image is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + if not isinstance(edit_image, list): + edit_image = [edit_image] + image_latents = [] + for image_ in edit_image: + image_ = pipe.preprocess_image(image_) + image_latents.append(pipe.vae_encoder(image_)) + return {"image_latents": image_latents} + + def model_fn_z_image( dit: ZImageDiT, latents=None, timestep=None, prompt_embeds=None, + image_embeds=None, + image_latents=None, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False, **kwargs, ): + # Due to the complex and verbose codebase of Z-Image, + # we are temporarily using this inelegant structure. + # We will refactor this part in the future (if time permits).
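+ # Dispatch: the Omni checkpoint carries a siglip_embedder while Turbo does not, so the embedder's + # presence selects the forward path. In omni mode each sample is packed as [condition latents..., target latent] + # with image_noise_mask = [0, ..., 0, 1] (0 = clean condition, 1 = noisy target), and SigLIP features + # ride along as an extra conditioning modality.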
+ if dit.siglip_embedder is None: + return model_fn_z_image_turbo( + dit, + latents, + timestep, + prompt_embeds, + image_embeds, + image_latents, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + **kwargs, + ) latents = [rearrange(latents, "B C H W -> C B H W")] + # The early return above guarantees omni mode (siglip_embedder present) from here on. + if image_latents is not None: + image_latents = [rearrange(image_latent, "B C H W -> C B H W") for image_latent in image_latents] + latents = [image_latents + latents] + image_noise_mask = [[0] * len(image_latents) + [1]] + else: + latents = [latents] + image_noise_mask = [[1]] + image_embeds = [image_embeds] timestep = (1000 - timestep) / 1000 model_output = dit( latents, timestep, prompt_embeds, + siglip_feats=image_embeds, + image_noise_mask=image_noise_mask, use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, - )[0][0] + )[0] model_output = -model_output model_output = rearrange(model_output, "C B H W -> B C H W") return model_output + + +def model_fn_z_image_turbo( + dit: ZImageDiT, + latents=None, + timestep=None, + prompt_embeds=None, + image_embeds=None, + image_latents=None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + **kwargs, +): + while isinstance(prompt_embeds, list): + prompt_embeds = prompt_embeds[0] + while isinstance(latents, list): + latents = latents[0] + while isinstance(image_embeds, list): + image_embeds = image_embeds[0] + + # Timestep (t_scale is pre-applied here, matching t_embedder(t * t_scale) in ZImageDiT.forward) + timestep = 1000 - timestep + t_noisy = dit.t_embedder(timestep) + + # Patchify + latents = rearrange(latents, "B C H W -> C B H W") + x, cap_feats, patch_metadata = dit.patchify_and_embed([latents], [prompt_embeds]) + x = x[0] + cap_feats = cap_feats[0] + + # Noise refine + x = dit.all_x_embedder["2-1"](x) + # Apply the learned pad token to inner padding, mirroring ZImageDiT.forward and the cap path below. + x[torch.cat(patch_metadata.get("x_pad_mask"))] = dit.x_pad_token.to(dtype=x.dtype, device=x.device) + x_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("x_pos_ids"), dim=0)) + x = rearrange(x, "L C -> 1 L C") + x_freqs_cis = rearrange(x_freqs_cis, "L C -> 1 L C") + + for layer in dit.noise_refiner: + x = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + x=x, + attn_mask=None, + freqs_cis=x_freqs_cis, + adaln_input=t_noisy, + ) + + # Prompt refine + cap_feats = dit.cap_embedder(cap_feats) + cap_feats[torch.cat(patch_metadata.get("cap_pad_mask"))] = dit.cap_pad_token.to(dtype=x.dtype, device=x.device) + cap_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("cap_pos_ids"), dim=0)) + cap_feats = rearrange(cap_feats, "L C -> 1 L C") + cap_freqs_cis = rearrange(cap_freqs_cis, "L C -> 1 L C") + + for layer in dit.context_refiner: + cap_feats = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + x=cap_feats, + attn_mask=None, + freqs_cis=cap_freqs_cis, + ) + + # Unified + unified = torch.cat([x, cap_feats], dim=1) + unified_freqs_cis = torch.cat([x_freqs_cis, cap_freqs_cis], dim=1) + for layer in dit.layers: + unified = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + x=unified, + attn_mask=None, + freqs_cis=unified_freqs_cis, + adaln_input=t_noisy, + ) + + # Output + unified = dit.all_final_layer["2-1"](unified, t_noisy) + x = dit.unpatchify([unified[0]], patch_metadata.get("x_size"))[0] + x = rearrange(x, "C B H W -> B C H W") + x = -x + return x
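Usage note (appended for review, not part of the patch): a minimal end-to-end sketch of the new edit path. The loader call mirrors DiffSynth's usual ModelConfig pattern; the loader name and argument names (from_pretrained, model_configs) are assumptions here, and the text-encoder/VAE/tokenizer entries are elided for brevity.

# Hedged sketch under the assumptions above.
import torch
from PIL import Image
from diffsynth.core import ModelConfig
from diffsynth.pipelines.z_image import ZImagePipeline

pipe = ZImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors"),
        ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors"),
        # ... text encoder, VAE encoder/decoder, and tokenizer configs omitted ...
    ],
)
image = pipe(
    prompt="turn the sky into a pink sunset",
    edit_image=Image.open("input.jpg"),  # hypothetical input; auto-resized to <=1024*1024 px, /16-divisible
    edit_image_auto_resize=True,
    num_inference_steps=8,
    seed=0,
)
image.save("output.jpg")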