diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index f97896c..7da7a9d 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -534,6 +534,32 @@ z_image_series = [ "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverterDiffusers", "extra_kwargs": {"use_conv_attention": False}, }, + { + # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors") + "model_hash": "aa3563718e5c3ecde3dfbb020ca61180", + "model_name": "z_image_dit", + "model_class": "diffsynth.models.z_image_dit.ZImageDiT", + "extra_kwargs": {"siglip_feat_dim": 1152}, + }, + { + # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors") + "model_hash": "89d48e420f45cff95115a9f3e698d44a", + "model_name": "siglip_vision_model_428m", + "model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M", + }, + { + # Example: ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors") + "model_hash": "1677708d40029ab380a95f6c731a57d7", + "model_name": "z_image_controlnet", + "model_class": "diffsynth.models.z_image_controlnet.ZImageControlNet", + }, + { + # Example: ??? + "model_hash": "9510cb8cd1dd34ee0e4f111c24905510", + "model_name": "z_image_image2lora_style", + "model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel", + "extra_kwargs": {"compress_dim": 128}, + }, ] MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series diff --git a/diffsynth/configs/vram_management_module_maps.py b/diffsynth/configs/vram_management_module_maps.py index 5f1b595..a1813fb 100644 --- a/diffsynth/configs/vram_management_module_maps.py +++ b/diffsynth/configs/vram_management_module_maps.py @@ -195,4 +195,19 @@ VRAM_MANAGEMENT_MODULE_MAPS = { "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", "diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", }, + "diffsynth.models.z_image_controlnet.ZImageControlNet": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + }, + "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M": { + "transformers.models.siglip2.modeling_siglip2.Siglip2VisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.siglip2.modeling_siglip2.Siglip2MultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + }, } diff --git a/diffsynth/core/loader/config.py b/diffsynth/core/loader/config.py index 562675f..88b46a0 100644 --- a/diffsynth/core/loader/config.py +++ b/diffsynth/core/loader/config.py @@ -97,6 +97,7 @@ class ModelConfig: self.reset_local_model_path() if self.require_downloading(): self.download() + if self.path is None: if self.origin_file_pattern is None or self.origin_file_pattern == "": self.path = 
os.path.join(self.local_model_path, self.model_id) else: diff --git a/diffsynth/diffusion/base_pipeline.py b/diffsynth/diffusion/base_pipeline.py index fa355a1..e37f381 100644 --- a/diffsynth/diffusion/base_pipeline.py +++ b/diffsynth/diffusion/base_pipeline.py @@ -235,6 +235,7 @@ class BasePipeline(torch.nn.Module): alpha=1, hotload=None, state_dict=None, + verbose=1, ): if state_dict is None: if isinstance(lora_config, str): @@ -261,12 +262,13 @@ class BasePipeline(torch.nn.Module): updated_num += 1 module.lora_A_weights.append(lora[lora_a_name] * alpha) module.lora_B_weights.append(lora[lora_b_name]) - print(f"{updated_num} tensors are patched by LoRA. You can use `pipe.clear_lora()` to clear all LoRA layers.") + if verbose >= 1: + print(f"{updated_num} tensors are patched by LoRA. You can use `pipe.clear_lora()` to clear all LoRA layers.") else: lora_loader.fuse_lora_to_base_model(module, lora, alpha=alpha) - def clear_lora(self): + def clear_lora(self, verbose=1): cleared_num = 0 for name, module in self.named_modules(): if isinstance(module, AutoWrappedLinear): @@ -276,7 +278,8 @@ class BasePipeline(torch.nn.Module): module.lora_A_weights.clear() if hasattr(module, "lora_B_weights"): module.lora_B_weights.clear() - print(f"{cleared_num} LoRA layers are cleared.") + if verbose >= 1: + print(f"{cleared_num} LoRA layers are cleared.") def download_and_load_models(self, model_configs: list[ModelConfig] = [], vram_limit: float = None): @@ -304,8 +307,13 @@ class BasePipeline(torch.nn.Module): def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others): + if inputs_shared.get("positive_only_lora", None) is not None: + self.clear_lora(verbose=0) + self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0) noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others) if cfg_scale != 1.0: + if inputs_shared.get("positive_only_lora", None) is not None: + self.clear_lora(verbose=0) noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) else: diff --git a/diffsynth/models/siglip2_image_encoder.py b/diffsynth/models/siglip2_image_encoder.py index 10184f8..87df855 100644 --- a/diffsynth/models/siglip2_image_encoder.py +++ b/diffsynth/models/siglip2_image_encoder.py @@ -1,5 +1,5 @@ from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer, SiglipVisionConfig -from transformers import SiglipImageProcessor +from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast import torch @@ -68,3 +68,65 @@ class Siglip2ImageEncoder(SiglipVisionTransformer): pooler_output = self.head(last_hidden_state) if self.use_head else None return pooler_output + + +class Siglip2ImageEncoder428M(Siglip2VisionModel): + def __init__(self): + config = Siglip2VisionConfig( + attention_dropout = 0.0, + dtype = "bfloat16", + hidden_act = "gelu_pytorch_tanh", + hidden_size = 1152, + intermediate_size = 4304, + layer_norm_eps = 1e-06, + model_type = "siglip2_vision_model", + num_attention_heads = 16, + num_channels = 3, + num_hidden_layers = 27, + num_patches = 256, + patch_size = 16, + transformers_version = "4.57.1" + ) + super().__init__(config) + self.processor = Siglip2ImageProcessorFast( + **{ + "data_format": "channels_first", + "default_to_square": True, + "device": None, + "disable_grouping": None, + "do_convert_rgb": None, + "do_normalize": True, + "do_pad": 
None, + "do_rescale": True, + "do_resize": True, + "image_mean": [ + 0.5, + 0.5, + 0.5 + ], + "image_processor_type": "Siglip2ImageProcessorFast", + "image_std": [ + 0.5, + 0.5, + 0.5 + ], + "input_data_format": None, + "max_num_patches": 256, + "pad_size": None, + "patch_size": 16, + "processor_class": "Siglip2Processor", + "resample": 2, + "rescale_factor": 0.00392156862745098, + "return_tensors": None, + } + ) + + def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"): + siglip_inputs = self.processor(images=[image], return_tensors="pt").to(device) + shape = siglip_inputs.spatial_shapes[0] + hidden_state = super().forward(**siglip_inputs).last_hidden_state + B, N, C = hidden_state.shape + hidden_state = hidden_state[:, : shape[0] * shape[1]] + hidden_state = hidden_state.view(shape[0], shape[1], C) + hidden_state = hidden_state.to(torch_dtype) + return hidden_state diff --git a/diffsynth/models/z_image_controlnet.py b/diffsynth/models/z_image_controlnet.py new file mode 100644 index 0000000..5105534 --- /dev/null +++ b/diffsynth/models/z_image_controlnet.py @@ -0,0 +1,154 @@ +from .z_image_dit import ZImageTransformerBlock +from ..core.gradient import gradient_checkpoint_forward +from torch.nn.utils.rnn import pad_sequence +import torch +from torch import nn + + +class ZImageControlTransformerBlock(ZImageTransformerBlock): + def __init__( + self, + layer_id: int = 1000, + dim: int = 3840, + n_heads: int = 30, + n_kv_heads: int = 30, + norm_eps: float = 1e-5, + qk_norm: bool = True, + modulation = True, + block_id = 0 + ): + super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation) + self.block_id = block_id + if block_id == 0: + self.before_proj = nn.Linear(self.dim, self.dim) + self.after_proj = nn.Linear(self.dim, self.dim) + + def forward(self, c, x, **kwargs): + if self.block_id == 0: + c = self.before_proj(c) + x + all_c = [] + else: + all_c = list(torch.unbind(c)) + c = all_c.pop(-1) + + c = super().forward(c, **kwargs) + c_skip = self.after_proj(c) + all_c += [c_skip, c] + c = torch.stack(all_c) + return c + + +class ZImageControlNet(torch.nn.Module): + def __init__( + self, + control_layers_places=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28), + control_in_dim=33, + dim=3840, + n_refiner_layers=2, + ): + super().__init__() + self.control_layers = nn.ModuleList([ZImageControlTransformerBlock(layer_id=i, block_id=i) for i in control_layers_places]) + self.control_all_x_embedder = nn.ModuleDict({"2-1": nn.Linear(1 * 2 * 2 * control_in_dim, dim, bias=True)}) + self.control_noise_refiner = nn.ModuleList([ZImageControlTransformerBlock(block_id=layer_id) for layer_id in range(n_refiner_layers)]) + self.control_layers_mapping = {0: 0, 2: 1, 4: 2, 6: 3, 8: 4, 10: 5, 12: 6, 14: 7, 16: 8, 18: 9, 20: 10, 22: 11, 24: 12, 26: 13, 28: 14} + + def forward_layers( + self, + x, + cap_feats, + control_context, + control_context_item_seqlens, + kwargs, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + ): + bsz = len(control_context) + # unified + cap_item_seqlens = [len(_) for _ in cap_feats] + control_context_unified = [] + for i in range(bsz): + control_context_len = control_context_item_seqlens[i] + cap_len = cap_item_seqlens[i] + control_context_unified.append(torch.cat([control_context[i][:control_context_len], cap_feats[i][:cap_len]])) + c = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0) + + # arguments + new_kwargs = dict(x=x) + new_kwargs.update(kwargs) + + for layer in 
self.control_layers: + c = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + c=c, **new_kwargs + ) + + hints = torch.unbind(c)[:-1] + return hints + + def forward_refiner( + self, + dit, + x, + cap_feats, + control_context, + kwargs, + t=None, + patch_size=2, + f_patch_size=1, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + ): + # embeddings + bsz = len(control_context) + device = control_context[0].device + ( + control_context, + control_context_size, + control_context_pos_ids, + control_context_inner_pad_mask, + ) = dit.patchify_controlnet(control_context, patch_size, f_patch_size, cap_feats[0].size(0)) + + # control_context embed & refine + control_context_item_seqlens = [len(_) for _ in control_context] + assert all(_ % 2 == 0 for _ in control_context_item_seqlens) + control_context_max_item_seqlen = max(control_context_item_seqlens) + + control_context = torch.cat(control_context, dim=0) + control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context) + + # Match t_embedder output dtype to control_context for layerwise casting compatibility + adaln_input = t.type_as(control_context) + control_context[torch.cat(control_context_inner_pad_mask)] = dit.x_pad_token.to(dtype=control_context.dtype, device=control_context.device) + control_context = list(control_context.split(control_context_item_seqlens, dim=0)) + control_context_freqs_cis = list(dit.rope_embedder(torch.cat(control_context_pos_ids, dim=0)).split(control_context_item_seqlens, dim=0)) + + control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0) + control_context_freqs_cis = pad_sequence(control_context_freqs_cis, batch_first=True, padding_value=0.0) + control_context_attn_mask = torch.zeros((bsz, control_context_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(control_context_item_seqlens): + control_context_attn_mask[i, :seq_len] = 1 + c = control_context + + # arguments + new_kwargs = dict( + x=x, + attn_mask=control_context_attn_mask, + freqs_cis=control_context_freqs_cis, + adaln_input=adaln_input, + ) + new_kwargs.update(kwargs) + + for layer in self.control_noise_refiner: + c = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + c=c, **new_kwargs + ) + + hints = torch.unbind(c)[:-1] + control_context = torch.unbind(c)[-1] + + return hints, control_context, control_context_item_seqlens \ No newline at end of file diff --git a/diffsynth/models/z_image_dit.py b/diffsynth/models/z_image_dit.py index 7664fc5..9744ddb 100644 --- a/diffsynth/models/z_image_dit.py +++ b/diffsynth/models/z_image_dit.py @@ -13,6 +13,7 @@ from ..core.gradient import gradient_checkpoint_forward ADALN_EMBED_DIM = 256 SEQ_MULTI_OF = 32 +X_PAD_DIM = 64 class TimestepEmbedder(nn.Module): @@ -86,7 +87,7 @@ class Attention(torch.nn.Module): self.norm_q = RMSNorm(head_dim, eps=1e-5) self.norm_k = RMSNorm(head_dim, eps=1e-5) - def forward(self, hidden_states, freqs_cis): + def forward(self, hidden_states, freqs_cis, attention_mask): query = self.to_q(hidden_states) key = self.to_k(hidden_states) value = self.to_v(hidden_states) @@ -123,6 +124,7 @@ class Attention(torch.nn.Module): key, value, q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d", + 
attn_mask=attention_mask, ) # Reshape back @@ -136,6 +138,20 @@ class Attention(torch.nn.Module): return output +def select_per_token( + value_noisy: torch.Tensor, + value_clean: torch.Tensor, + noise_mask: torch.Tensor, + seq_len: int, +) -> torch.Tensor: + noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1) + return torch.where( + noise_mask_expanded == 1, + value_noisy.unsqueeze(1).expand(-1, seq_len, -1), + value_clean.unsqueeze(1).expand(-1, seq_len, -1), + ) + + class ZImageTransformerBlock(nn.Module): def __init__( self, @@ -180,40 +196,53 @@ class ZImageTransformerBlock(nn.Module): attn_mask: torch.Tensor, freqs_cis: torch.Tensor, adaln_input: Optional[torch.Tensor] = None, + noise_mask: Optional[torch.Tensor] = None, + adaln_noisy: Optional[torch.Tensor] = None, + adaln_clean: Optional[torch.Tensor] = None, ): if self.modulation: - assert adaln_input is not None - scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2) - gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() - scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp + seq_len = x.shape[1] + + if noise_mask is not None: + # Per-token modulation: different modulation for noisy/clean tokens + mod_noisy = self.adaLN_modulation(adaln_noisy) + mod_clean = self.adaLN_modulation(adaln_clean) + + scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1) + scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1) + + gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh() + gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh() + + scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy + scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean + + scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len) + scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len) + gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len) + gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len) + else: + # Global modulation: same modulation for all tokens (avoid double select) + mod = self.adaLN_modulation(adaln_input) + scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2) + gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() + scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp # Attention block attn_out = self.attention( - self.attention_norm1(x) * scale_msa, - freqs_cis=freqs_cis, + self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis ) x = x + gate_msa * self.attention_norm2(attn_out) # FFN block - x = x + gate_mlp * self.ffn_norm2( - self.feed_forward( - self.ffn_norm1(x) * scale_mlp, - ) - ) + x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp)) else: # Attention block - attn_out = self.attention( - self.attention_norm1(x), - freqs_cis=freqs_cis, - ) + attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis) x = x + self.attention_norm2(attn_out) # FFN block - x = x + self.ffn_norm2( - self.feed_forward( - self.ffn_norm1(x), - ) - ) + x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x))) return x @@ -229,9 +258,21 @@ class FinalLayer(nn.Module): nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True), ) - def forward(self, x, c): - scale = 1.0 + 
self.adaLN_modulation(c) - x = self.norm_final(x) * scale.unsqueeze(1) + def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None): + seq_len = x.shape[1] + + if noise_mask is not None: + # Per-token modulation + scale_noisy = 1.0 + self.adaLN_modulation(c_noisy) + scale_clean = 1.0 + self.adaLN_modulation(c_clean) + scale = select_per_token(scale_noisy, scale_clean, noise_mask, seq_len) + else: + # Original global modulation + assert c is not None, "Either c or (c_noisy, c_clean) must be provided" + scale = 1.0 + self.adaLN_modulation(c) + scale = scale.unsqueeze(1) + + x = self.norm_final(x) * scale x = self.linear(x) return x @@ -299,6 +340,7 @@ class ZImageDiT(nn.Module): t_scale=1000.0, axes_dims=[32, 48, 48], axes_lens=[1024, 512, 512], + siglip_feat_dim=None, ) -> None: super().__init__() self.in_channels = in_channels @@ -359,6 +401,32 @@ class ZImageDiT(nn.Module): nn.Linear(cap_feat_dim, dim, bias=True), ) + # Optional SigLIP components (for Omni variant) + self.siglip_feat_dim = siglip_feat_dim + if siglip_feat_dim is not None: + self.siglip_embedder = nn.Sequential( + RMSNorm(siglip_feat_dim, eps=norm_eps), nn.Linear(siglip_feat_dim, dim, bias=True) + ) + self.siglip_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 2000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=False, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.siglip_pad_token = nn.Parameter(torch.empty((1, dim))) + else: + self.siglip_embedder = None + self.siglip_refiner = None + self.siglip_pad_token = None + self.x_pad_token = nn.Parameter(torch.empty((1, dim))) self.cap_pad_token = nn.Parameter(torch.empty((1, dim))) @@ -375,22 +443,57 @@ class ZImageDiT(nn.Module): self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) - def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]: + def unpatchify( + self, + x: List[torch.Tensor], + size: List[Tuple], + patch_size = 2, + f_patch_size = 1, + x_pos_offsets: Optional[List[Tuple[int, int]]] = None, + ) -> List[torch.Tensor]: pH = pW = patch_size pF = f_patch_size bsz = len(x) assert len(size) == bsz - for i in range(bsz): - F, H, W = size[i] - ori_len = (F // pF) * (H // pH) * (W // pW) - # "f h w pf ph pw c -> c (f pf) (h ph) (w pw)" - x[i] = ( - x[i][:ori_len] - .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels) - .permute(6, 0, 3, 1, 4, 2, 5) - .reshape(self.out_channels, F, H, W) - ) - return x + + if x_pos_offsets is not None: + # Omni: extract target image from unified sequence (cond_images + target) + result = [] + for i in range(bsz): + unified_x = x[i][x_pos_offsets[i][0] : x_pos_offsets[i][1]] + cu_len = 0 + x_item = None + for j in range(len(size[i])): + if size[i][j] is None: + ori_len = 0 + pad_len = SEQ_MULTI_OF + cu_len += pad_len + ori_len + else: + F, H, W = size[i][j] + ori_len = (F // pF) * (H // pH) * (W // pW) + pad_len = (-ori_len) % SEQ_MULTI_OF + x_item = ( + unified_x[cu_len : cu_len + ori_len] + .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels) + .permute(6, 0, 3, 1, 4, 2, 5) + .reshape(self.out_channels, F, H, W) + ) + cu_len += ori_len + pad_len + result.append(x_item) # Return only the last (target) image + return result + else: + # Original mode: simple unpatchify + for i in range(bsz): + F, H, W = size[i] + ori_len = (F // pF) * (H // pH) * (W // pW) + # "f h w pf ph pw c -> c (f pf) (h ph) (w pw)" + x[i] = ( + x[i][:ori_len] + .view(F // pF, H // pH, 
W // pW, pF, pH, pW, self.out_channels) + .permute(6, 0, 3, 1, 4, 2, 5) + .reshape(self.out_channels, F, H, W) + ) + return x @staticmethod def create_coordinate_grid(size, start=None, device=None): @@ -405,8 +508,8 @@ class ZImageDiT(nn.Module): self, all_image: List[torch.Tensor], all_cap_feats: List[torch.Tensor], - patch_size: int, - f_patch_size: int, + patch_size: int = 2, + f_patch_size: int = 1, ): pH = pW = patch_size pF = f_patch_size @@ -490,90 +593,487 @@ class ZImageDiT(nn.Module): image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) all_image_out.append(image_padded_feat) + return all_image_out, all_cap_feats_out, { + "x_size": all_image_size, + "x_pos_ids": all_image_pos_ids, + "cap_pos_ids": all_cap_pos_ids, + "x_pad_mask": all_image_pad_mask, + "cap_pad_mask": all_cap_pad_mask + } + # ( + # all_img_out, + # all_cap_out, + # all_img_size, + # all_img_pos_ids, + # all_cap_pos_ids, + # all_img_pad_mask, + # all_cap_pad_mask, + # ) + + def patchify_controlnet( + self, + all_image: List[torch.Tensor], + patch_size: int = 2, + f_patch_size: int = 1, + cap_padding_len: int = None, + ): + pH = pW = patch_size + pF = f_patch_size + device = all_image[0].device + + all_image_out = [] + all_image_size = [] + all_image_pos_ids = [] + all_image_pad_mask = [] + + for i, image in enumerate(all_image): + ### Process Image + C, F, H, W = image.size() + all_image_size.append((F, H, W)) + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + + image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + # "c f pf h ph w pw -> (f h w) (pf ph pw c)" + image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + + image_ori_len = len(image) + image_padding_len = (-image_ori_len) % SEQ_MULTI_OF + + image_ori_pos_ids = self.create_coordinate_grid( + size=(F_tokens, H_tokens, W_tokens), + start=(cap_padding_len + 1, 0, 0), + device=device, + ).flatten(0, 2) + image_padding_pos_ids = ( + self.create_coordinate_grid( + size=(1, 1, 1), + start=(0, 0, 0), + device=device, + ) + .flatten(0, 2) + .repeat(image_padding_len, 1) + ) + image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0) + all_image_pos_ids.append(image_padded_pos_ids) + # pad mask + all_image_pad_mask.append( + torch.cat( + [ + torch.zeros((image_ori_len,), dtype=torch.bool, device=device), + torch.ones((image_padding_len,), dtype=torch.bool, device=device), + ], + dim=0, + ) + ) + # padded feature + image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) + all_image_out.append(image_padded_feat) + return ( all_image_out, - all_cap_feats_out, all_image_size, all_image_pos_ids, - all_cap_pos_ids, all_image_pad_mask, - all_cap_pad_mask, ) + + def _prepare_sequence( + self, + feats: List[torch.Tensor], + pos_ids: List[torch.Tensor], + inner_pad_mask: List[torch.Tensor], + pad_token: torch.nn.Parameter, + noise_mask: Optional[List[List[int]]] = None, + device: torch.device = None, + ): + """Prepare sequence: apply pad token, RoPE embed, pad to batch, create attention mask.""" + item_seqlens = [len(f) for f in feats] + max_seqlen = max(item_seqlens) + bsz = len(feats) + + # Pad token + feats_cat = torch.cat(feats, dim=0) + feats_cat[torch.cat(inner_pad_mask)] = pad_token.to(dtype=feats_cat.dtype, device=feats_cat.device) + feats = list(feats_cat.split(item_seqlens, dim=0)) + + # RoPE + freqs_cis = list(self.rope_embedder(torch.cat(pos_ids, dim=0)).split([len(p) for p in pos_ids], dim=0)) + + # Pad 
to batch + feats = pad_sequence(feats, batch_first=True, padding_value=0.0) + freqs_cis = pad_sequence(freqs_cis, batch_first=True, padding_value=0.0)[:, : feats.shape[1]] + + # Attention mask + attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(item_seqlens): + attn_mask[i, :seq_len] = 1 + + # Noise mask + noise_mask_tensor = None + if noise_mask is not None: + noise_mask_tensor = pad_sequence( + [torch.tensor(m, dtype=torch.long, device=device) for m in noise_mask], + batch_first=True, + padding_value=0, + )[:, : feats.shape[1]] + + return feats, freqs_cis, attn_mask, item_seqlens, noise_mask_tensor + + def _build_unified_sequence( + self, + x: torch.Tensor, + x_freqs: torch.Tensor, + x_seqlens: List[int], + x_noise_mask: Optional[List[List[int]]], + cap: torch.Tensor, + cap_freqs: torch.Tensor, + cap_seqlens: List[int], + cap_noise_mask: Optional[List[List[int]]], + siglip: Optional[torch.Tensor], + siglip_freqs: Optional[torch.Tensor], + siglip_seqlens: Optional[List[int]], + siglip_noise_mask: Optional[List[List[int]]], + omni_mode: bool, + device: torch.device, + ): + """Build unified sequence: x, cap, and optionally siglip. + Basic mode order: [x, cap]; Omni mode order: [cap, x, siglip] + """ + bsz = len(x_seqlens) + unified = [] + unified_freqs = [] + unified_noise_mask = [] + + for i in range(bsz): + x_len, cap_len = x_seqlens[i], cap_seqlens[i] + + if omni_mode: + # Omni: [cap, x, siglip] + if siglip is not None and siglip_seqlens is not None: + sig_len = siglip_seqlens[i] + unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len], siglip[i][:sig_len]])) + unified_freqs.append( + torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len], siglip_freqs[i][:sig_len]]) + ) + unified_noise_mask.append( + torch.tensor( + cap_noise_mask[i] + x_noise_mask[i] + siglip_noise_mask[i], dtype=torch.long, device=device + ) + ) + else: + unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len]])) + unified_freqs.append(torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len]])) + unified_noise_mask.append( + torch.tensor(cap_noise_mask[i] + x_noise_mask[i], dtype=torch.long, device=device) + ) + else: + # Basic: [x, cap] + unified.append(torch.cat([x[i][:x_len], cap[i][:cap_len]])) + unified_freqs.append(torch.cat([x_freqs[i][:x_len], cap_freqs[i][:cap_len]])) + + # Compute unified seqlens + if omni_mode: + if siglip is not None and siglip_seqlens is not None: + unified_seqlens = [a + b + c for a, b, c in zip(cap_seqlens, x_seqlens, siglip_seqlens)] + else: + unified_seqlens = [a + b for a, b in zip(cap_seqlens, x_seqlens)] + else: + unified_seqlens = [a + b for a, b in zip(x_seqlens, cap_seqlens)] + + max_seqlen = max(unified_seqlens) + + # Pad to batch + unified = pad_sequence(unified, batch_first=True, padding_value=0.0) + unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0) + + # Attention mask + attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(unified_seqlens): + attn_mask[i, :seq_len] = 1 + + # Noise mask + noise_mask_tensor = None + if omni_mode: + noise_mask_tensor = pad_sequence(unified_noise_mask, batch_first=True, padding_value=0)[ + :, : unified.shape[1] + ] + + return unified, unified_freqs, attn_mask, noise_mask_tensor + + def _pad_with_ids( + self, + feat: torch.Tensor, + pos_grid_size: Tuple, + pos_start: Tuple, + device: torch.device, + noise_mask_val: Optional[int] = None, + ): + """Pad feature to SEQ_MULTI_OF, create position IDs and pad 
mask.""" + ori_len = len(feat) + pad_len = (-ori_len) % SEQ_MULTI_OF + total_len = ori_len + pad_len + + # Pos IDs + ori_pos_ids = self.create_coordinate_grid(size=pos_grid_size, start=pos_start, device=device).flatten(0, 2) + if pad_len > 0: + pad_pos_ids = ( + self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) + .flatten(0, 2) + .repeat(pad_len, 1) + ) + pos_ids = torch.cat([ori_pos_ids, pad_pos_ids], dim=0) + padded_feat = torch.cat([feat, feat[-1:].repeat(pad_len, 1)], dim=0) + pad_mask = torch.cat( + [ + torch.zeros(ori_len, dtype=torch.bool, device=device), + torch.ones(pad_len, dtype=torch.bool, device=device), + ] + ) + else: + pos_ids = ori_pos_ids + padded_feat = feat + pad_mask = torch.zeros(ori_len, dtype=torch.bool, device=device) + + noise_mask = [noise_mask_val] * total_len if noise_mask_val is not None else None # token level + return padded_feat, pos_ids, pad_mask, total_len, noise_mask + + def _patchify_image(self, image: torch.Tensor, patch_size: int, f_patch_size: int): + """Patchify a single image tensor: (C, F, H, W) -> (num_patches, patch_dim).""" + pH, pW, pF = patch_size, patch_size, f_patch_size + C, F, H, W = image.size() + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + return image, (F, H, W), (F_tokens, H_tokens, W_tokens) + + def patchify_and_embed_omni( + self, + all_x: List[List[torch.Tensor]], + all_cap_feats: List[List[torch.Tensor]], + all_siglip_feats: List[List[torch.Tensor]], + patch_size: int = 2, + f_patch_size: int = 1, + images_noise_mask: List[List[int]] = None, + ): + """Patchify for omni mode: multiple images per batch item with noise masks.""" + bsz = len(all_x) + device = all_x[0][-1].device + dtype = all_x[0][-1].dtype + + all_x_out, all_x_size, all_x_pos_ids, all_x_pad_mask, all_x_len, all_x_noise_mask = [], [], [], [], [], [] + all_cap_out, all_cap_pos_ids, all_cap_pad_mask, all_cap_len, all_cap_noise_mask = [], [], [], [], [] + all_sig_out, all_sig_pos_ids, all_sig_pad_mask, all_sig_len, all_sig_noise_mask = [], [], [], [], [] + + for i in range(bsz): + num_images = len(all_x[i]) + cap_feats_list, cap_pos_list, cap_mask_list, cap_lens, cap_noise = [], [], [], [], [] + cap_end_pos = [] + cap_cu_len = 1 + + # Process captions + for j, cap_item in enumerate(all_cap_feats[i]): + noise_val = images_noise_mask[i][j] if j < len(images_noise_mask[i]) else 1 + cap_out, cap_pos, cap_mask, cap_len, cap_nm = self._pad_with_ids( + cap_item, + (len(cap_item) + (-len(cap_item)) % SEQ_MULTI_OF, 1, 1), + (cap_cu_len, 0, 0), + device, + noise_val, + ) + cap_feats_list.append(cap_out) + cap_pos_list.append(cap_pos) + cap_mask_list.append(cap_mask) + cap_lens.append(cap_len) + cap_noise.extend(cap_nm) + cap_cu_len += len(cap_item) + cap_end_pos.append(cap_cu_len) + cap_cu_len += 2 # for image vae and siglip tokens + + all_cap_out.append(torch.cat(cap_feats_list, dim=0)) + all_cap_pos_ids.append(torch.cat(cap_pos_list, dim=0)) + all_cap_pad_mask.append(torch.cat(cap_mask_list, dim=0)) + all_cap_len.append(cap_lens) + all_cap_noise_mask.append(cap_noise) + + # Process images + x_feats_list, x_pos_list, x_mask_list, x_lens, x_size, x_noise = [], [], [], [], [], [] + for j, x_item in enumerate(all_x[i]): + noise_val = images_noise_mask[i][j] + if x_item is not None: + x_patches, size, (F_t, H_t, W_t) = self._patchify_image(x_item, patch_size, f_patch_size) + 
x_out, x_pos, x_mask, x_len, x_nm = self._pad_with_ids( + x_patches, (F_t, H_t, W_t), (cap_end_pos[j], 0, 0), device, noise_val + ) + x_size.append(size) + else: + x_len = SEQ_MULTI_OF + x_out = torch.zeros((x_len, X_PAD_DIM), dtype=dtype, device=device) + x_pos = self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(x_len, 1) + x_mask = torch.ones(x_len, dtype=torch.bool, device=device) + x_nm = [noise_val] * x_len + x_size.append(None) + x_feats_list.append(x_out) + x_pos_list.append(x_pos) + x_mask_list.append(x_mask) + x_lens.append(x_len) + x_noise.extend(x_nm) + + all_x_out.append(torch.cat(x_feats_list, dim=0)) + all_x_pos_ids.append(torch.cat(x_pos_list, dim=0)) + all_x_pad_mask.append(torch.cat(x_mask_list, dim=0)) + all_x_size.append(x_size) + all_x_len.append(x_lens) + all_x_noise_mask.append(x_noise) + + # Process siglip + if all_siglip_feats[i] is None: + all_sig_len.append([0] * num_images) + all_sig_out.append(None) + else: + sig_feats_list, sig_pos_list, sig_mask_list, sig_lens, sig_noise = [], [], [], [], [] + for j, sig_item in enumerate(all_siglip_feats[i]): + noise_val = images_noise_mask[i][j] + if sig_item is not None: + sig_H, sig_W, sig_C = sig_item.size() + sig_flat = sig_item.permute(2, 0, 1).reshape(sig_H * sig_W, sig_C) + sig_out, sig_pos, sig_mask, sig_len, sig_nm = self._pad_with_ids( + sig_flat, (1, sig_H, sig_W), (cap_end_pos[j] + 1, 0, 0), device, noise_val + ) + # Scale position IDs to match x resolution + if x_size[j] is not None: + sig_pos = sig_pos.float() + sig_pos[..., 1] = sig_pos[..., 1] / max(sig_H - 1, 1) * (x_size[j][1] - 1) + sig_pos[..., 2] = sig_pos[..., 2] / max(sig_W - 1, 1) * (x_size[j][2] - 1) + sig_pos = sig_pos.to(torch.int32) + else: + sig_len = SEQ_MULTI_OF + sig_out = torch.zeros((sig_len, self.siglip_feat_dim), dtype=dtype, device=device) + sig_pos = ( + self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(sig_len, 1) + ) + sig_mask = torch.ones(sig_len, dtype=torch.bool, device=device) + sig_nm = [noise_val] * sig_len + sig_feats_list.append(sig_out) + sig_pos_list.append(sig_pos) + sig_mask_list.append(sig_mask) + sig_lens.append(sig_len) + sig_noise.extend(sig_nm) + + all_sig_out.append(torch.cat(sig_feats_list, dim=0)) + all_sig_pos_ids.append(torch.cat(sig_pos_list, dim=0)) + all_sig_pad_mask.append(torch.cat(sig_mask_list, dim=0)) + all_sig_len.append(sig_lens) + all_sig_noise_mask.append(sig_noise) + + # Compute x position offsets + all_x_pos_offsets = [(sum(all_cap_len[i]), sum(all_cap_len[i]) + sum(all_x_len[i])) for i in range(bsz)] + + return ( + all_x_out, + all_cap_out, + all_sig_out, + all_x_size, + all_x_pos_ids, + all_cap_pos_ids, + all_sig_pos_ids, + all_x_pad_mask, + all_cap_pad_mask, + all_sig_pad_mask, + all_x_pos_offsets, + all_x_noise_mask, + all_cap_noise_mask, + all_sig_noise_mask, + ) + return all_x_out, all_cap_out, all_sig_out, { + "x_size": x_size, + "x_pos_ids": all_x_pos_ids, + "cap_pos_ids": all_cap_pos_ids, + "sig_pos_ids": all_sig_pos_ids, + "x_pad_mask": all_x_pad_mask, + "cap_pad_mask": all_cap_pad_mask, + "sig_pad_mask": all_sig_pad_mask, + "x_pos_offsets": all_x_pos_offsets, + "x_noise_mask": all_x_noise_mask, + "cap_noise_mask": all_cap_noise_mask, + "sig_noise_mask": all_sig_noise_mask, + } def forward( self, x: List[torch.Tensor], t, cap_feats: List[torch.Tensor], + siglip_feats = None, + image_noise_mask = None, patch_size=2, f_patch_size=1, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False, ): - assert patch_size in 
self.all_patch_size - assert f_patch_size in self.all_f_patch_size + assert patch_size in self.all_patch_size and f_patch_size in self.all_f_patch_size + omni_mode = isinstance(x[0], list) + device = x[0][-1].device if omni_mode else x[0].device - bsz = len(x) - device = x[0].device - t = t * self.t_scale - t = self.t_embedder(t) + if omni_mode: + # Dual embeddings: noisy (t) and clean (t=1) + t_noisy = self.t_embedder(t * self.t_scale).type_as(x[0][-1]) + t_clean = self.t_embedder(torch.ones_like(t) * self.t_scale).type_as(x[0][-1]) + adaln_input = None + else: + # Single embedding for all tokens + adaln_input = self.t_embedder(t * self.t_scale).type_as(x[0]) + t_noisy = t_clean = None - adaln_input = t - - ( - x, - cap_feats, - x_size, - x_pos_ids, - cap_pos_ids, - x_inner_pad_mask, - cap_inner_pad_mask, - ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size) + # Patchify + if omni_mode: + ( + x, + cap_feats, + siglip_feats, + x_size, + x_pos_ids, + cap_pos_ids, + siglip_pos_ids, + x_pad_mask, + cap_pad_mask, + siglip_pad_mask, + x_pos_offsets, + x_noise_mask, + cap_noise_mask, + siglip_noise_mask, + ) = self.patchify_and_embed_omni(x, cap_feats, siglip_feats, patch_size, f_patch_size, image_noise_mask) + else: + ( + x, + cap_feats, + x_size, + x_pos_ids, + cap_pos_ids, + x_pad_mask, + cap_pad_mask, + ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size) + x_pos_offsets = x_noise_mask = cap_noise_mask = siglip_noise_mask = None # x embed & refine - x_item_seqlens = [len(_) for _ in x] - assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens) - x_max_item_seqlen = max(x_item_seqlens) - - x = torch.cat(x, dim=0) - x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x) - x[torch.cat(x_inner_pad_mask)] = self.x_pad_token.to(dtype=x.dtype, device=x.device) - x = list(x.split(x_item_seqlens, dim=0)) - x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)) - - x = pad_sequence(x, batch_first=True, padding_value=0.0) - x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) - x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) - for i, seq_len in enumerate(x_item_seqlens): - x_attn_mask[i, :seq_len] = 1 + x_seqlens = [len(xi) for xi in x] + x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](torch.cat(x, dim=0)) # embed + x, x_freqs, x_mask, _, x_noise_tensor = self._prepare_sequence( + list(x.split(x_seqlens, dim=0)), x_pos_ids, x_pad_mask, self.x_pad_token, x_noise_mask, device + ) for layer in self.noise_refiner: x = gradient_checkpoint_forward( layer, use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, - x=x, - attn_mask=x_attn_mask, - freqs_cis=x_freqs_cis, - adaln_input=adaln_input, + x=x, attn_mask=x_mask, freqs_cis=x_freqs, adaln_input=adaln_input, noise_mask=x_noise_tensor, adaln_noisy=t_noisy, adaln_clean=t_clean, ) - # cap embed & refine - cap_item_seqlens = [len(_) for _ in cap_feats] - assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens) - cap_max_item_seqlen = max(cap_item_seqlens) - - cap_feats = torch.cat(cap_feats, dim=0) - cap_feats = self.cap_embedder(cap_feats) - cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token.to(dtype=x.dtype, device=x.device) - cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0)) - cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0)) - - cap_feats = pad_sequence(cap_feats, 
batch_first=True, padding_value=0.0) - cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0) - cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device) - for i, seq_len in enumerate(cap_item_seqlens): - cap_attn_mask[i, :seq_len] = 1 + # Cap embed & refine + cap_seqlens = [len(ci) for ci in cap_feats] + cap_feats = self.cap_embedder(torch.cat(cap_feats, dim=0)) # embed + cap_feats, cap_freqs, cap_mask, _, _ = self._prepare_sequence( + list(cap_feats.split(cap_seqlens, dim=0)), cap_pos_ids, cap_pad_mask, self.cap_pad_token, None, device + ) for layer in self.context_refiner: cap_feats = gradient_checkpoint_forward( @@ -581,41 +1081,68 @@ class ZImageDiT(nn.Module): use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, x=cap_feats, - attn_mask=cap_attn_mask, - freqs_cis=cap_freqs_cis, + attn_mask=cap_mask, + freqs_cis=cap_freqs, ) - # unified - unified = [] - unified_freqs_cis = [] - for i in range(bsz): - x_len = x_item_seqlens[i] - cap_len = cap_item_seqlens[i] - unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]])) - unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]])) - unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)] - assert unified_item_seqlens == [len(_) for _ in unified] - unified_max_item_seqlen = max(unified_item_seqlens) + # Siglip embed & refine + siglip_seqlens = siglip_freqs = None + if omni_mode and siglip_feats[0] is not None and self.siglip_embedder is not None: + siglip_seqlens = [len(si) for si in siglip_feats] + siglip_feats = self.siglip_embedder(torch.cat(siglip_feats, dim=0)) # embed + siglip_feats, siglip_freqs, siglip_mask, _, _ = self._prepare_sequence( + list(siglip_feats.split(siglip_seqlens, dim=0)), + siglip_pos_ids, + siglip_pad_mask, + self.siglip_pad_token, + None, + device, + ) - unified = pad_sequence(unified, batch_first=True, padding_value=0.0) - unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0) - unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device) - for i, seq_len in enumerate(unified_item_seqlens): - unified_attn_mask[i, :seq_len] = 1 + for layer in self.siglip_refiner: + siglip_feats = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + x=siglip_feats, attn_mask=siglip_mask, freqs_cis=siglip_freqs, + ) - for layer in self.layers: + # Unified sequence + unified, unified_freqs, unified_mask, unified_noise_tensor = self._build_unified_sequence( + x, + x_freqs, + x_seqlens, + x_noise_mask, + cap_feats, + cap_freqs, + cap_seqlens, + cap_noise_mask, + siglip_feats, + siglip_freqs, + siglip_seqlens, + siglip_noise_mask, + omni_mode, + device, + ) + + # Main transformer layers + for layer_idx, layer in enumerate(self.layers): unified = gradient_checkpoint_forward( layer, use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, - x=unified, - attn_mask=unified_attn_mask, - freqs_cis=unified_freqs_cis, - adaln_input=adaln_input, + x=unified, attn_mask=unified_mask, freqs_cis=unified_freqs, adaln_input=adaln_input, noise_mask=unified_noise_tensor, adaln_noisy=t_noisy, adaln_clean=t_clean ) - unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input) - unified = 
list(unified.unbind(dim=0)) - x = self.unpatchify(unified, x_size, patch_size, f_patch_size) + unified = ( + self.all_final_layer[f"{patch_size}-{f_patch_size}"]( + unified, noise_mask=unified_noise_tensor, c_noisy=t_noisy, c_clean=t_clean + ) + if omni_mode + else self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, c=adaln_input) + ) - return x, {} + # Unpatchify + x = self.unpatchify(list(unified.unbind(dim=0)), x_size, patch_size, f_patch_size, x_pos_offsets) + + return x diff --git a/diffsynth/models/z_image_image2lora.py b/diffsynth/models/z_image_image2lora.py new file mode 100644 index 0000000..757f3f6 --- /dev/null +++ b/diffsynth/models/z_image_image2lora.py @@ -0,0 +1,189 @@ +import torch +from .qwen_image_image2lora import ImageEmbeddingToLoraMatrix, SequencialMLP + + +class LoRATrainerBlock(torch.nn.Module): + def __init__(self, lora_patterns, in_dim=1536+4096, compress_dim=128, rank=4, block_id=0, use_residual=True, residual_length=64+7, residual_dim=3584, residual_mid_dim=1024, prefix="transformer_blocks"): + super().__init__() + self.prefix = prefix + self.lora_patterns = lora_patterns + self.block_id = block_id + self.layers = [] + for name, lora_a_dim, lora_b_dim in self.lora_patterns: + self.layers.append(ImageEmbeddingToLoraMatrix(in_dim, compress_dim, lora_a_dim, lora_b_dim, rank)) + self.layers = torch.nn.ModuleList(self.layers) + if use_residual: + self.proj_residual = SequencialMLP(residual_length, residual_dim, residual_mid_dim, compress_dim) + else: + self.proj_residual = None + + def forward(self, x, residual=None): + lora = {} + if self.proj_residual is not None: residual = self.proj_residual(residual) + for lora_pattern, layer in zip(self.lora_patterns, self.layers): + name = lora_pattern[0] + lora_a, lora_b = layer(x, residual=residual) + lora[f"{self.prefix}.{self.block_id}.{name}.lora_A.default.weight"] = lora_a + lora[f"{self.prefix}.{self.block_id}.{name}.lora_B.default.weight"] = lora_b + return lora + + +class ZImageImage2LoRAComponent(torch.nn.Module): + def __init__(self, lora_patterns, prefix, num_blocks=60, use_residual=True, compress_dim=128, rank=4, residual_length=64+7, residual_mid_dim=1024): + super().__init__() + self.lora_patterns = lora_patterns + self.num_blocks = num_blocks + self.blocks = [] + for lora_patterns in self.lora_patterns: + for block_id in range(self.num_blocks): + self.blocks.append(LoRATrainerBlock(lora_patterns, block_id=block_id, use_residual=use_residual, compress_dim=compress_dim, rank=rank, residual_length=residual_length, residual_mid_dim=residual_mid_dim, prefix=prefix)) + self.blocks = torch.nn.ModuleList(self.blocks) + self.residual_scale = 0.05 + self.use_residual = use_residual + + def forward(self, x, residual=None): + if residual is not None: + if self.use_residual: + residual = residual * self.residual_scale + else: + residual = None + lora = {} + for block in self.blocks: + lora.update(block(x, residual)) + return lora + + +class ZImageImage2LoRAModel(torch.nn.Module): + def __init__(self, use_residual=False, compress_dim=64, rank=4, residual_length=64+7, residual_mid_dim=1024): + super().__init__() + lora_patterns = [ + [ + ("attention.to_q", 3840, 3840), + ("attention.to_k", 3840, 3840), + ("attention.to_v", 3840, 3840), + ("attention.to_out.0", 3840, 3840), + ], + [ + ("feed_forward.w1", 3840, 10240), + ("feed_forward.w2", 10240, 3840), + ("feed_forward.w3", 3840, 10240), + ], + ] + config = { + "lora_patterns": lora_patterns, + "use_residual": use_residual, + "compress_dim": compress_dim, + "rank": 
rank, + "residual_length": residual_length, + "residual_mid_dim": residual_mid_dim, + } + self.layers_lora = ZImageImage2LoRAComponent( + prefix="layers", + num_blocks=30, + **config, + ) + self.context_refiner_lora = ZImageImage2LoRAComponent( + prefix="context_refiner", + num_blocks=2, + **config, + ) + self.noise_refiner_lora = ZImageImage2LoRAComponent( + prefix="noise_refiner", + num_blocks=2, + **config, + ) + + def forward(self, x, residual=None): + lora = {} + lora.update(self.layers_lora(x, residual=residual)) + lora.update(self.context_refiner_lora(x, residual=residual)) + lora.update(self.noise_refiner_lora(x, residual=residual)) + return lora + + def initialize_weights(self): + state_dict = self.state_dict() + for name in state_dict: + if ".proj_a." in name: + state_dict[name] = state_dict[name] * 0.3 + elif ".proj_b.proj_out." in name: + state_dict[name] = state_dict[name] * 0 + elif ".proj_residual.proj_out." in name: + state_dict[name] = state_dict[name] * 0.3 + self.load_state_dict(state_dict) + + +class ImageEmb2LoRAWeightCompressed(torch.nn.Module): + def __init__(self, in_dim, out_dim, emb_dim, rank): + super().__init__() + self.lora_a = torch.nn.Parameter(torch.randn((rank, in_dim))) + self.lora_b = torch.nn.Parameter(torch.randn((out_dim, rank))) + self.proj = torch.nn.Linear(emb_dim, rank * rank, bias=True) + self.rank = rank + + def forward(self, x): + x = self.proj(x).view(self.rank, self.rank) + lora_a = x @ self.lora_a + lora_b = self.lora_b + return lora_a, lora_b + + +class ZImageImage2LoRAModelCompressed(torch.nn.Module): + def __init__(self, emb_dim=1536+4096, rank=32): + super().__init__() + target_layers = [ + ("attention.to_q", 3840, 3840), + ("attention.to_k", 3840, 3840), + ("attention.to_v", 3840, 3840), + ("attention.to_out.0", 3840, 3840), + ("feed_forward.w1", 3840, 10240), + ("feed_forward.w2", 10240, 3840), + ("feed_forward.w3", 3840, 10240), + ] + self.lora_patterns = [ + { + "prefix": "layers", + "num_layers": 30, + "target_layers": target_layers, + }, + { + "prefix": "context_refiner", + "num_layers": 2, + "target_layers": target_layers, + }, + { + "prefix": "noise_refiner", + "num_layers": 2, + "target_layers": target_layers, + }, + ] + module_dict = {} + for lora_pattern in self.lora_patterns: + prefix, num_layers, target_layers = lora_pattern["prefix"], lora_pattern["num_layers"], lora_pattern["target_layers"] + for layer_id in range(num_layers): + for layer_name, in_dim, out_dim in target_layers: + name = f"{prefix}.{layer_id}.{layer_name}".replace(".", "___") + model = ImageEmb2LoRAWeightCompressed(in_dim, out_dim, emb_dim, rank) + module_dict[name] = model + self.module_dict = torch.nn.ModuleDict(module_dict) + + def forward(self, x, residual=None): + lora = {} + for name, module in self.module_dict.items(): + name = name.replace("___", ".") + name_a, name_b = f"{name}.lora_A.default.weight", f"{name}.lora_B.default.weight" + lora_a, lora_b = module(x) + lora[name_a] = lora_a + lora[name_b] = lora_b + return lora + + def initialize_weights(self): + state_dict = self.state_dict() + for name in state_dict: + if "lora_b" in name: + state_dict[name] = state_dict[name] * 0 + elif "lora_a" in name: + state_dict[name] = state_dict[name] * 0.2 + elif "proj.weight" in name: + print(name) + state_dict[name] = state_dict[name] * 0.2 + self.load_state_dict(state_dict) diff --git a/diffsynth/pipelines/z_image.py b/diffsynth/pipelines/z_image.py index f87254f..9ba182a 100644 --- a/diffsynth/pipelines/z_image.py +++ b/diffsynth/pipelines/z_image.py @@ 
-4,16 +4,23 @@ from typing import Union from tqdm import tqdm from einops import rearrange import numpy as np -from typing import Union, List, Optional, Tuple +from typing import Union, List, Optional, Tuple, Iterable, Dict from ..diffusion import FlowMatchScheduler from ..core import ModelConfig, gradient_checkpoint_forward +from ..core.data.operators import ImageCropAndResize from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput +from ..utils.lora import merge_lora from transformers import AutoTokenizer from ..models.z_image_text_encoder import ZImageTextEncoder from ..models.z_image_dit import ZImageDiT from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder +from ..models.siglip2_image_encoder import Siglip2ImageEncoder428M +from ..models.z_image_controlnet import ZImageControlNet +from ..models.siglip2_image_encoder import Siglip2ImageEncoder +from ..models.dinov3_image_encoder import DINOv3ImageEncoder +from ..models.z_image_image2lora import ZImageImage2LoRAModel class ZImagePipeline(BasePipeline): @@ -28,13 +35,22 @@ class ZImagePipeline(BasePipeline): self.dit: ZImageDiT = None self.vae_encoder: FluxVAEEncoder = None self.vae_decoder: FluxVAEDecoder = None + self.image_encoder: Siglip2ImageEncoder428M = None + self.controlnet: ZImageControlNet = None + self.siglip2_image_encoder: Siglip2ImageEncoder = None + self.dinov3_image_encoder: DINOv3ImageEncoder = None + self.image2lora_style: ZImageImage2LoRAModel = None self.tokenizer: AutoTokenizer = None - self.in_iteration_models = ("dit",) + self.in_iteration_models = ("dit", "controlnet") self.units = [ ZImageUnit_ShapeChecker(), ZImageUnit_PromptEmbedder(), ZImageUnit_NoiseInitializer(), ZImageUnit_InputImageEmbedder(), + ZImageUnit_EditImageAutoResize(), + ZImageUnit_EditImageEmbedderVAE(), + ZImageUnit_EditImageEmbedderSiglip(), + ZImageUnit_PAIControlNet(), ] self.model_fn = model_fn_z_image @@ -56,6 +72,11 @@ class ZImagePipeline(BasePipeline): pipe.dit = model_pool.fetch_model("z_image_dit") pipe.vae_encoder = model_pool.fetch_model("flux_vae_encoder") pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder") + pipe.image_encoder = model_pool.fetch_model("siglip_vision_model_428m") + pipe.controlnet = model_pool.fetch_model("z_image_controlnet") + pipe.siglip2_image_encoder = model_pool.fetch_model("siglip2_image_encoder") + pipe.dinov3_image_encoder = model_pool.fetch_model("dinov3_image_encoder") + pipe.image2lora_style = model_pool.fetch_model("z_image_image2lora_style") if tokenizer_config is not None: tokenizer_config.download_if_necessary() pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path) @@ -75,6 +96,9 @@ class ZImagePipeline(BasePipeline): # Image input_image: Image.Image = None, denoising_strength: float = 1.0, + # Edit + edit_image: Image.Image = None, + edit_image_auto_resize: bool = True, # Shape height: int = 1024, width: int = 1024, @@ -83,11 +107,17 @@ class ZImagePipeline(BasePipeline): rand_device: str = "cpu", # Steps num_inference_steps: int = 8, + sigma_shift: float = None, + # ControlNet + controlnet_inputs: List[ControlNetInput] = None, + # Image to LoRA + image2lora_images: List[Image.Image] = None, + positive_only_lora: Dict[str, torch.Tensor] = None, # Progress bar progress_bar_cmd = tqdm, ): # Scheduler - self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength) + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) # Parameters inputs_posi = { @@ -102,6 
+132,9 @@ class ZImagePipeline(BasePipeline): "height": height, "width": width, "seed": seed, "rand_device": rand_device, "num_inference_steps": num_inference_steps, + "edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize, + "controlnet_inputs": controlnet_inputs, + "image2lora_images": image2lora_images, "positive_only_lora": positive_only_lora, } for unit in self.units: inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) @@ -143,12 +176,13 @@ class ZImageUnit_PromptEmbedder(PipelineUnit): def __init__(self): super().__init__( seperate_cfg=True, + input_params=("edit_image",), input_params_posi={"prompt": "prompt"}, input_params_nega={"prompt": "negative_prompt"}, output_params=("prompt_embeds",), onload_model_names=("text_encoder",) ) - + def encode_prompt( self, pipe, @@ -194,10 +228,81 @@ class ZImageUnit_PromptEmbedder(PipelineUnit): embeddings_list.append(prompt_embeds[i][prompt_masks[i]]) return embeddings_list + + def encode_prompt_omni( + self, + pipe, + prompt: Union[str, List[str]], + edit_image=None, + device: Optional[torch.device] = None, + max_sequence_length: int = 512, + ) -> List[torch.FloatTensor]: + if isinstance(prompt, str): + prompt = [prompt] - def process(self, pipe: ZImagePipeline, prompt): + if edit_image is None: + num_condition_images = 0 + elif isinstance(edit_image, list): + num_condition_images = len(edit_image) + else: + num_condition_images = 1 + + for i, prompt_item in enumerate(prompt): + if num_condition_images == 0: + prompt[i] = ["<|im_start|>user\n" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n"] + elif num_condition_images > 0: + prompt_list = ["<|im_start|>user\n<|vision_start|>"] + prompt_list += ["<|vision_end|><|vision_start|>"] * (num_condition_images - 1) + prompt_list += ["<|vision_end|>" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"] + prompt_list += ["<|vision_end|><|im_end|>"] + prompt[i] = prompt_list + + flattened_prompt = [] + prompt_list_lengths = [] + + for i in range(len(prompt)): + prompt_list_lengths.append(len(prompt[i])) + flattened_prompt.extend(prompt[i]) + + text_inputs = pipe.tokenizer( + flattened_prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = pipe.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + start_idx = 0 + for i in range(len(prompt_list_lengths)): + batch_embeddings = [] + end_idx = start_idx + prompt_list_lengths[i] + for j in range(start_idx, end_idx): + batch_embeddings.append(prompt_embeds[j][prompt_masks[j]]) + embeddings_list.append(batch_embeddings) + start_idx = end_idx + + return embeddings_list + + def process(self, pipe: ZImagePipeline, prompt, edit_image): pipe.load_models_to_device(self.onload_model_names) - prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device) + if hasattr(pipe, "dit") and pipe.dit.siglip_embedder is not None: + # Z-Image-Turbo and Z-Image-Omni-Base use different prompt encoding methods. + # We determine which encoding method to use based on the model architecture. + # If you are using two-stage split training, + # please use `--offload_models` instead of skipping the DiT model loading. 
+ prompt_embeds = self.encode_prompt_omni(pipe, prompt, edit_image, pipe.device) + else: + prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device) return {"prompt_embeds": prompt_embeds} @@ -234,24 +339,330 @@ class ZImageUnit_InputImageEmbedder(PipelineUnit): return {"latents": latents, "input_latents": input_latents} +class ZImageUnit_EditImageAutoResize(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("edit_image", "edit_image_auto_resize"), + output_params=("edit_image",), + ) + + def process(self, pipe: ZImagePipeline, edit_image, edit_image_auto_resize): + if edit_image is None: + return {} + if edit_image_auto_resize is None or not edit_image_auto_resize: + return {} + operator = ImageCropAndResize(max_pixels=1024*1024, height_division_factor=16, width_division_factor=16) + if not isinstance(edit_image, list): + edit_image = [edit_image] + edit_image = [operator(i) for i in edit_image] + return {"edit_image": edit_image} + + +class ZImageUnit_EditImageEmbedderSiglip(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("edit_image",), + output_params=("image_embeds",), + onload_model_names=("image_encoder",) + ) + + def process(self, pipe: ZImagePipeline, edit_image): + if edit_image is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + if not isinstance(edit_image, list): + edit_image = [edit_image] + image_emb = [] + for image_ in edit_image: + image_emb.append(pipe.image_encoder(image_, device=pipe.device)) + return {"image_embeds": image_emb} + + +class ZImageUnit_EditImageEmbedderVAE(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("edit_image",), + output_params=("image_latents",), + onload_model_names=("vae_encoder",) + ) + + def process(self, pipe: ZImagePipeline, edit_image): + if edit_image is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + if not isinstance(edit_image, list): + edit_image = [edit_image] + image_latents = [] + for image_ in edit_image: + image_ = pipe.preprocess_image(image_) + image_latents.append(pipe.vae_encoder(image_)) + return {"image_latents": image_latents} + + +class ZImageUnit_PAIControlNet(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("controlnet_inputs", "height", "width"), + output_params=("control_context", "control_scale"), + onload_model_names=("vae_encoder",) + ) + + def process(self, pipe: ZImagePipeline, controlnet_inputs: List[ControlNetInput], height, width): + if controlnet_inputs is None: + return {} + if len(controlnet_inputs) != 1: + print("Z-Image ControlNet doesn't support multi-ControlNet. 
Only one image will be used.") + controlnet_input = controlnet_inputs[0] + pipe.load_models_to_device(self.onload_model_names) + + control_image = controlnet_input.image + if control_image is not None: + control_image = pipe.preprocess_image(control_image) + control_latents = pipe.vae_encoder(control_image) + else: + control_latents = torch.ones((1, 16, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device) * -1 + + inpaint_mask = controlnet_input.inpaint_mask + if inpaint_mask is not None: + inpaint_mask = pipe.preprocess_image(inpaint_mask, min_value=0, max_value=1) + inpaint_image = controlnet_input.inpaint_image + inpaint_image = pipe.preprocess_image(inpaint_image) + inpaint_image = inpaint_image * (inpaint_mask < 0.5) + inpaint_mask = torch.nn.functional.interpolate(1 - inpaint_mask, (height // 8, width // 8), mode='nearest')[:, :1] + else: + inpaint_mask = torch.zeros((1, 1, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device) + inpaint_image = torch.zeros((1, 3, height, width), dtype=pipe.torch_dtype, device=pipe.device) + inpaint_latent = pipe.vae_encoder(inpaint_image) + + control_context = torch.concat([control_latents, inpaint_mask, inpaint_latent], dim=1) + control_context = rearrange(control_context, "B C H W -> B C 1 H W") + return {"control_context": control_context, "control_scale": controlnet_input.scale} + + def model_fn_z_image( dit: ZImageDiT, + controlnet: ZImageControlNet = None, latents=None, timestep=None, prompt_embeds=None, + image_embeds=None, + image_latents=None, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False, **kwargs, ): + # Due to the complex and verbose codebase of Z-Image, + # we are temporarily using this inelegant structure. + # We will refactor this part in the future (if time permits). 
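+ # Dispatch on the loaded checkpoint: a DiT without a SigLIP embedder (Z-Image-Turbo)
+ # takes the flattened single-sequence path in `model_fn_z_image_turbo`; Omni checkpoints
+ # are handled below with list-of-lists latents and an image_noise_mask.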
+ if dit.siglip_embedder is None: + return model_fn_z_image_turbo( + dit, + controlnet=controlnet, + latents=latents, + timestep=timestep, + prompt_embeds=prompt_embeds, + image_embeds=image_embeds, + image_latents=image_latents, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + **kwargs, + ) latents = [rearrange(latents, "B C H W -> C B H W")] + if dit.siglip_embedder is not None: + if image_latents is not None: + image_latents = [rearrange(image_latent, "B C H W -> C B H W") for image_latent in image_latents] + latents = [image_latents + latents] + image_noise_mask = [[0] * len(image_latents) + [1]] + else: + latents = [latents] + image_noise_mask = [[1]] + image_embeds = [image_embeds] + else: + image_noise_mask = None timestep = (1000 - timestep) / 1000 model_output = dit( latents, timestep, prompt_embeds, + siglip_feats=image_embeds, + image_noise_mask=image_noise_mask, use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, - )[0][0] + )[0] model_output = -model_output model_output = rearrange(model_output, "C B H W -> B C H W") return model_output + + +class ZImageUnit_Image2LoRAEncode(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("image2lora_images",), + output_params=("image2lora_x",), + onload_model_names=("siglip2_image_encoder", "dinov3_image_encoder",), + ) + from ..core.data.operators import ImageCropAndResize + self.processor_highres = ImageCropAndResize(height=1024, width=1024) + + def encode_images_using_siglip2(self, pipe: ZImagePipeline, images: list[Image.Image]): + pipe.load_models_to_device(["siglip2_image_encoder"]) + embs = [] + for image in images: + image = self.processor_highres(image) + embs.append(pipe.siglip2_image_encoder(image).to(pipe.torch_dtype)) + embs = torch.stack(embs) + return embs + + def encode_images_using_dinov3(self, pipe: ZImagePipeline, images: list[Image.Image]): + pipe.load_models_to_device(["dinov3_image_encoder"]) + embs = [] + for image in images: + image = self.processor_highres(image) + embs.append(pipe.dinov3_image_encoder(image).to(pipe.torch_dtype)) + embs = torch.stack(embs) + return embs + + def encode_images(self, pipe: ZImagePipeline, images: list[Image.Image]): + if images is None: + return {} + if not isinstance(images, list): + images = [images] + embs_siglip2 = self.encode_images_using_siglip2(pipe, images) + embs_dinov3 = self.encode_images_using_dinov3(pipe, images) + x = torch.concat([embs_siglip2, embs_dinov3], dim=-1) + return x + + def process(self, pipe: ZImagePipeline, image2lora_images): + if image2lora_images is None: + return {} + x = self.encode_images(pipe, image2lora_images) + return {"image2lora_x": x} + + +class ZImageUnit_Image2LoRADecode(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("image2lora_x",), + output_params=("lora",), + onload_model_names=("image2lora_style",), + ) + + def process(self, pipe: ZImagePipeline, image2lora_x): + if image2lora_x is None: + return {} + loras = [] + if pipe.image2lora_style is not None: + pipe.load_models_to_device(["image2lora_style"]) + for x in image2lora_x: + loras.append(pipe.image2lora_style(x=x, residual=None)) + lora = merge_lora(loras, alpha=1 / len(image2lora_x)) + return {"lora": lora} + + +def model_fn_z_image_turbo( + dit: ZImageDiT, + controlnet: ZImageControlNet = None, + latents=None, + timestep=None, + prompt_embeds=None, + image_embeds=None, + 
image_latents=None, + control_context=None, + control_scale=None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + **kwargs, +): + while isinstance(prompt_embeds, list): + prompt_embeds = prompt_embeds[0] + while isinstance(latents, list): + latents = latents[0] + while isinstance(image_embeds, list): + image_embeds = image_embeds[0] + + # Timestep + timestep = 1000 - timestep + t_noisy = dit.t_embedder(timestep) + t_clean = dit.t_embedder(torch.ones_like(timestep) * 1000) + + # Patchify + latents = rearrange(latents, "B C H W -> C B H W") + x, cap_feats, patch_metadata = dit.patchify_and_embed([latents], [prompt_embeds]) + x = x[0] + cap_feats = cap_feats[0] + + # Noise refine + x = dit.all_x_embedder["2-1"](x) + x[torch.cat(patch_metadata.get("x_pad_mask"))] = dit.x_pad_token.to(dtype=x.dtype, device=x.device) + x_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("x_pos_ids"), dim=0)) + x = rearrange(x, "L C -> 1 L C") + x_freqs_cis = rearrange(x_freqs_cis, "L C -> 1 L C") + + if control_context is not None: + kwargs = dict(attn_mask=None, freqs_cis=x_freqs_cis, adaln_input=t_noisy) + refiner_hints, control_context, control_context_item_seqlens = controlnet.forward_refiner( + dit, x, [cap_feats], control_context, kwargs, t=t_noisy, patch_size=2, f_patch_size=1, + use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + + for layer_id, layer in enumerate(dit.noise_refiner): + x = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + x=x, + attn_mask=None, + freqs_cis=x_freqs_cis, + adaln_input=t_noisy, + ) + if control_context is not None: + x = x + refiner_hints[layer_id] * control_scale + + # Prompt refine + cap_feats = dit.cap_embedder(cap_feats) + cap_feats[torch.cat(patch_metadata.get("cap_pad_mask"))] = dit.cap_pad_token.to(dtype=x.dtype, device=x.device) + cap_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("cap_pos_ids"), dim=0)) + cap_feats = rearrange(cap_feats, "L C -> 1 L C") + cap_freqs_cis = rearrange(cap_freqs_cis, "L C -> 1 L C") + + for layer in dit.context_refiner: + cap_feats = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + x=cap_feats, + attn_mask=None, + freqs_cis=cap_freqs_cis, + ) + + # Unified + unified = torch.cat([x, cap_feats], dim=1) + unified_freqs_cis = torch.cat([x_freqs_cis, cap_freqs_cis], dim=1) + + if control_context is not None: + kwargs = dict(attn_mask=None, freqs_cis=unified_freqs_cis, adaln_input=t_noisy) + hints = controlnet.forward_layers( + unified, cap_feats, control_context, control_context_item_seqlens, kwargs, + use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + + for layer_id, layer in enumerate(dit.layers): + unified = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + x=unified, + attn_mask=None, + freqs_cis=unified_freqs_cis, + adaln_input=t_noisy, + ) + if control_context is not None: + if layer_id in controlnet.control_layers_mapping: + unified = unified + hints[controlnet.control_layers_mapping[layer_id]] * control_scale + + # Output + unified = 
dit.all_final_layer["2-1"](unified, t_noisy) + x = dit.unpatchify([unified[0]], patch_metadata.get("x_size"))[0] + x = rearrange(x, "C B H W -> B C H W") + x = -x + return x diff --git a/diffsynth/utils/controlnet/controlnet_input.py b/diffsynth/utils/controlnet/controlnet_input.py index 1a2949b..a79064b 100644 --- a/diffsynth/utils/controlnet/controlnet_input.py +++ b/diffsynth/utils/controlnet/controlnet_input.py @@ -9,5 +9,6 @@ class ControlNetInput: start: float = 1.0 end: float = 0.0 image: Image.Image = None + inpaint_image: Image.Image = None inpaint_mask: Image.Image = None processor_id: str = None diff --git a/examples/dev_tools/unit_test.py b/examples/dev_tools/unit_test.py index 364af47..200ced8 100644 --- a/examples/dev_tools/unit_test.py +++ b/examples/dev_tools/unit_test.py @@ -108,7 +108,14 @@ def test_flux(): run_inference("examples/flux/model_training/validate_lora") +def test_z_image(): + run_inference("examples/z_image/model_inference") + run_inference("examples/z_image/model_inference_low_vram") + run_train_multi_GPU("examples/z_image/model_training/full") + run_inference("examples/z_image/model_training/validate_full") + run_train_single_GPU("examples/z_image/model_training/lora") + run_inference("examples/z_image/model_training/validate_lora") + + if __name__ == "__main__": - test_qwen_image() - test_flux() - test_wan() + test_z_image() diff --git a/examples/z_image/model_inference/Z-Image-Omni-Base-i2L.py b/examples/z_image/model_inference/Z-Image-Omni-Base-i2L.py new file mode 100644 index 0000000..10d37ad --- /dev/null +++ b/examples/z_image/model_inference/Z-Image-Omni-Base-i2L.py @@ -0,0 +1,62 @@ +from diffsynth.pipelines.z_image import ( + ZImagePipeline, ModelConfig, + ZImageUnit_Image2LoRAEncode, ZImageUnit_Image2LoRADecode +) +from modelscope import snapshot_download +from safetensors.torch import save_file +import torch +from PIL import Image + +# Use `vram_config` to enable LoRA hot-loading +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cuda", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} + +# Load models +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Z-Image-Omni-Base-i2L", origin_file_pattern="model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) + +# Load images +snapshot_download( + model_id="DiffSynth-Studio/Z-Image-Omni-Base-i2L", + allow_file_pattern="assets/style/*", + local_dir="data/style_input" +) +images = [Image.open(f"data/style_input/assets/style/1/{i}.jpg") for i in range(6)] + +# Image to LoRA +with torch.no_grad(): + 
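+ # Encode the style images with the SigLIP2 and DINOv3 image encoders, then decode the
+ # concatenated features into a LoRA state dict for the DiT (applied via `positive_only_lora` below).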
embs = ZImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images) + lora = ZImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"] +save_file(lora, "lora.safetensors") + +# Generate images +prompt = "a cat" +negative_prompt = "泛黄,发绿,模糊,低分辨率,低质量图像,扭曲的肢体,诡异的外观,丑陋,AI感,噪点,网格感,JPEG压缩条纹,异常的肢体,水印,乱码,意义不明的字符" +image = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=0, cfg_scale=7, num_inference_steps=50, + positive_only_lora=lora, + sigma_shift=8 +) +image.save("image.jpg") diff --git a/examples/z_image/model_inference/Z-Image-Omni-Base.py b/examples/z_image/model_inference/Z-Image-Omni-Base.py new file mode 100644 index 0000000..b1d2217 --- /dev/null +++ b/examples/z_image/model_inference/Z-Image-Omni-Base.py @@ -0,0 +1,24 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig +from PIL import Image +import torch + + +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) +prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights." 
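+# Text-to-image with the base (non-Turbo) model: this run uses classifier-free guidance
+# (cfg_scale=4) and 40 sampling steps.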
+image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4) +image.save("image_Z-Image-Omni-Base.jpg") + +image = Image.open("image_Z-Image-Omni-Base.jpg") +prompt = "Change the women's clothes to white cheongsam, keep other content unchanged" +image = pipe(prompt=prompt, edit_image=image, seed=42, rand_device="cuda", num_inference_steps=40, cfg_scale=4) +image.save("image_edit_Z-Image-Omni-Base.jpg") diff --git a/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py b/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py new file mode 100644 index 0000000..21b387e --- /dev/null +++ b/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py @@ -0,0 +1,27 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput +from modelscope import dataset_snapshot_download +from PIL import Image +import torch + + +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern="data/examples/upscale/low_res.png" +) +controlnet_image = Image.open("data/examples/upscale/low_res.png").resize((1024, 1024)) +prompt = "这是一张充满都市气息的户外人物肖像照片。画面中是一位年轻男性,他展现出时尚而自信的形象。人物拥有精心打理的短发发型,两侧修剪得较短,顶部保留一定长度,呈现出流行的Undercut造型。他佩戴着一副时尚的浅色墨镜或透明镜框眼镜,为整体造型增添了潮流感。脸上洋溢着温和友善的笑容,神情放松自然,给人以阳光开朗的印象。他身穿一件经典的牛仔外套,这件单品永不过时,展现出休闲又有型的穿衣风格。牛仔外套的蓝色调与整体氛围十分协调,领口处隐约可见内搭的衣物。照片的背景是典型的城市街景,可以看到模糊的建筑物、街道和行人,营造出繁华都市的氛围。背景经过了恰当的虚化处理,使人物主体更加突出。光线明亮而柔和,可能是白天的自然光,为照片带来清新通透的视觉效果。整张照片构图专业,景深控制得当,完美捕捉了一个现代都市年轻人充满活力和自信的瞬间,展现出积极向上的生活态度。" +image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)]) +image.save("image_tile.jpg") diff --git a/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py b/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py new file mode 100644 index 0000000..54adbea --- /dev/null +++ b/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py @@ -0,0 +1,40 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput +from modelscope import dataset_snapshot_download +from PIL import Image +import torch + + +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + 
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) + +# Control +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_image_dataset", + local_dir="./data/example_image_dataset", + allow_file_pattern="depth/image_1.jpg" +) +controlnet_image = Image.open("data/example_image_dataset/depth/image_1.jpg").resize((1024, 1024)) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)]) +image.save("image_control.jpg") + +# Inpaint +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_image_dataset", + local_dir="./data/example_image_dataset", + allow_file_pattern="inpaint/*.jpg" +) +inpaint_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1024, 1024)) +inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1024, 1024)) +prompt = "一只戴着墨镜的猫" +image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(inpaint_image=inpaint_image, inpaint_mask=inpaint_mask, scale=0.7)]) +image.save("image_inpaint.jpg") diff --git a/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py b/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py new file mode 100644 index 0000000..2f872d0 --- /dev/null +++ b/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py @@ -0,0 +1,46 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput +from modelscope import dataset_snapshot_download +from PIL import Image +import torch + + +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) + +# Control +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_image_dataset", + local_dir="./data/example_image_dataset", + allow_file_pattern="depth/image_1.jpg" +) +controlnet_image = Image.open("data/example_image_dataset/depth/image_1.jpg").resize((1024, 1024)) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe( + prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)], + num_inference_steps=30, +) +image.save("image_control.jpg") + +# Inpaint +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_image_dataset", + local_dir="./data/example_image_dataset", + allow_file_pattern="inpaint/*.jpg" +) +inpaint_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1024, 1024)) +inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1024, 1024)) +prompt = "一只戴着墨镜的猫" +image = pipe( + prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(inpaint_image=inpaint_image, inpaint_mask=inpaint_mask, 
scale=0.7)], + num_inference_steps=30, +) +image.save("image_inpaint.jpg") diff --git a/examples/z_image/model_inference_low_vram/Z-Image-Omni-Base-i2L.py b/examples/z_image/model_inference_low_vram/Z-Image-Omni-Base-i2L.py new file mode 100644 index 0000000..7378ada --- /dev/null +++ b/examples/z_image/model_inference_low_vram/Z-Image-Omni-Base-i2L.py @@ -0,0 +1,62 @@ +from diffsynth.pipelines.z_image import ( + ZImagePipeline, ModelConfig, + ZImageUnit_Image2LoRAEncode, ZImageUnit_Image2LoRADecode +) +from modelscope import snapshot_download +from safetensors.torch import save_file +import torch +from PIL import Image + +# Use `vram_config` to enable LoRA hot-loading +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} + +# Load models +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/Z-Image-Omni-Base-i2L", origin_file_pattern="model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) + +# Load images +snapshot_download( + model_id="DiffSynth-Studio/Z-Image-Omni-Base-i2L", + allow_file_pattern="assets/style/*", + local_dir="data/style_input" +) +images = [Image.open(f"data/style_input/assets/style/1/{i}.jpg") for i in range(6)] + +# Image to LoRA +with torch.no_grad(): + embs = ZImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images) + lora = ZImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"] +save_file(lora, "lora.safetensors") + +# Generate images +prompt = "a cat" +negative_prompt = "泛黄,发绿,模糊,低分辨率,低质量图像,扭曲的肢体,诡异的外观,丑陋,AI感,噪点,网格感,JPEG压缩条纹,异常的肢体,水印,乱码,意义不明的字符" +image = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=0, cfg_scale=7, num_inference_steps=50, + positive_only_lora=lora, + sigma_shift=8 +) +image.save("image.jpg") diff --git a/examples/z_image/model_inference_low_vram/Z-Image-Omni-Base.py b/examples/z_image/model_inference_low_vram/Z-Image-Omni-Base.py new file mode 100644 index 0000000..0af1e53 --- /dev/null +++ b/examples/z_image/model_inference_low_vram/Z-Image-Omni-Base.py @@ -0,0 +1,33 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig +from PIL import Image +import torch + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = 
ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) +prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights." +image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4) +image.save("image_Z-Image-Omni-Base.jpg") + +image = Image.open("image_Z-Image-Omni-Base.jpg") +prompt = "Change the women's clothes to white cheongsam, keep other content unchanged" +image = pipe(prompt=prompt, edit_image=image, seed=42, rand_device="cuda", num_inference_steps=40, cfg_scale=4) +image.save("image_edit_Z-Image-Omni-Base.jpg") diff --git a/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py b/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py new file mode 100644 index 0000000..cd4276f --- /dev/null +++ b/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py @@ -0,0 +1,37 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput +from modelscope import dataset_snapshot_download +from PIL import Image +import torch + + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern="data/examples/upscale/low_res.png" +) +controlnet_image = Image.open("data/examples/upscale/low_res.png").resize((1024, 1024)) +prompt = 
"这是一张充满都市气息的户外人物肖像照片。画面中是一位年轻男性,他展现出时尚而自信的形象。人物拥有精心打理的短发发型,两侧修剪得较短,顶部保留一定长度,呈现出流行的Undercut造型。他佩戴着一副时尚的浅色墨镜或透明镜框眼镜,为整体造型增添了潮流感。脸上洋溢着温和友善的笑容,神情放松自然,给人以阳光开朗的印象。他身穿一件经典的牛仔外套,这件单品永不过时,展现出休闲又有型的穿衣风格。牛仔外套的蓝色调与整体氛围十分协调,领口处隐约可见内搭的衣物。照片的背景是典型的城市街景,可以看到模糊的建筑物、街道和行人,营造出繁华都市的氛围。背景经过了恰当的虚化处理,使人物主体更加突出。光线明亮而柔和,可能是白天的自然光,为照片带来清新通透的视觉效果。整张照片构图专业,景深控制得当,完美捕捉了一个现代都市年轻人充满活力和自信的瞬间,展现出积极向上的生活态度。" +image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)]) +image.save("image_tile.jpg") diff --git a/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py b/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py new file mode 100644 index 0000000..f325508 --- /dev/null +++ b/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py @@ -0,0 +1,50 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput +from modelscope import dataset_snapshot_download +from PIL import Image +import torch + + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) + +# Control +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_image_dataset", + local_dir="./data/example_image_dataset", + allow_file_pattern="depth/image_1.jpg" +) +controlnet_image = Image.open("data/example_image_dataset/depth/image_1.jpg").resize((1024, 1024)) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)]) +image.save("image_control.jpg") + +# Inpaint +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_image_dataset", + local_dir="./data/example_image_dataset", + allow_file_pattern="inpaint/*.jpg" +) +inpaint_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1024, 1024)) +inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1024, 1024)) +prompt = "一只戴着墨镜的猫" +image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(inpaint_image=inpaint_image, inpaint_mask=inpaint_mask, scale=0.7)]) +image.save("image_inpaint.jpg") diff --git a/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py b/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py new file mode 100644 index 0000000..6fe170f --- /dev/null +++ 
b/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py @@ -0,0 +1,56 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput +from modelscope import dataset_snapshot_download +from PIL import Image +import torch + + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) + +# Control +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_image_dataset", + local_dir="./data/example_image_dataset", + allow_file_pattern="depth/image_1.jpg" +) +controlnet_image = Image.open("data/example_image_dataset/depth/image_1.jpg").resize((1024, 1024)) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe( + prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)], + num_inference_steps=30, +) +image.save("image_control.jpg") + +# Inpaint +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_image_dataset", + local_dir="./data/example_image_dataset", + allow_file_pattern="inpaint/*.jpg" +) +inpaint_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1024, 1024)) +inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1024, 1024)) +prompt = "一只戴着墨镜的猫" +image = pipe( + prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(inpaint_image=inpaint_image, inpaint_mask=inpaint_mask, scale=0.7)], + num_inference_steps=30, +) +image.save("image_inpaint.jpg") diff --git a/examples/z_image/model_training/full/Z-Image-Omni-Base.sh b/examples/z_image/model_training/full/Z-Image-Omni-Base.sh new file mode 100644 index 0000000..cc74b2a --- /dev/null +++ b/examples/z_image/model_training/full/Z-Image-Omni-Base.sh @@ -0,0 +1,32 @@ +# This example is tested on 8*A100 +# Text to image training +accelerate launch --config_file examples/z_image/model_training/full/accelerate_config.yaml examples/z_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 400 \ + --model_id_with_origin_paths "Tongyi-MAI/Z-Image-Omni-Base:transformer/*.safetensors,Tongyi-MAI/Z-Image-Omni-Base:siglip/model.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Z-Image-Omni-Base_full" \ + --trainable_models "dit" \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 + +# Image(s) to image training +# accelerate launch --config_file examples/z_image/model_training/full/accelerate_config.yaml examples/z_image/model_training/train.py \ +# --dataset_base_path data/example_image_dataset \ +# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \ +# --data_file_keys "image,edit_image" \ +# --extra_inputs "edit_image" \ +# --max_pixels 1048576 \ +# --dataset_repeat 400 \ +# --model_id_with_origin_paths "Tongyi-MAI/Z-Image-Omni-Base:transformer/*.safetensors,Tongyi-MAI/Z-Image-Omni-Base:siglip/model.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \ +# --learning_rate 1e-5 \ +# --num_epochs 2 \ +# --remove_prefix_in_ckpt "pipe.dit." \ +# --output_path "./models/train/Z-Image-Omni-Base_full_edit" \ +# --trainable_models "dit" \ +# --use_gradient_checkpointing \ +# --dataset_num_workers 8 diff --git a/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh b/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh new file mode 100644 index 0000000..1f0f928 --- /dev/null +++ b/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh @@ -0,0 +1,15 @@ +accelerate launch examples/z_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_controlnet_upscale.csv \ + --data_file_keys "image,controlnet_image" \ + --max_pixels 1048576 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1:Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.safetensors,Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.controlnet." \ + --output_path "./models/train/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps_full" \ + --trainable_models "controlnet" \ + --extra_inputs "controlnet_image" \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 diff --git a/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh b/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh new file mode 100644 index 0000000..69d0958 --- /dev/null +++ b/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh @@ -0,0 +1,15 @@ +accelerate launch examples/z_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_controlnet_canny.csv \ + --data_file_keys "image,controlnet_image" \ + --max_pixels 1048576 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1:Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors,Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.controlnet." 
\ + --output_path "./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps_full" \ + --trainable_models "controlnet" \ + --extra_inputs "controlnet_image" \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 diff --git a/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh b/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh new file mode 100644 index 0000000..c56e735 --- /dev/null +++ b/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh @@ -0,0 +1,15 @@ +accelerate launch examples/z_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_controlnet_canny.csv \ + --data_file_keys "image,controlnet_image" \ + --max_pixels 1048576 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1:Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors,Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.controlnet." \ + --output_path "./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1_full" \ + --trainable_models "controlnet" \ + --extra_inputs "controlnet_image" \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 diff --git a/examples/z_image/model_training/lora/Z-Image-Omni-Base.sh b/examples/z_image/model_training/lora/Z-Image-Omni-Base.sh new file mode 100644 index 0000000..ef4d524 --- /dev/null +++ b/examples/z_image/model_training/lora/Z-Image-Omni-Base.sh @@ -0,0 +1,35 @@ +# Text to image training +accelerate launch examples/z_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Tongyi-MAI/Z-Image-Omni-Base:transformer/*.safetensors,Tongyi-MAI/Z-Image-Omni-Base:siglip/model.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Z-Image-Omni-Base_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,to_out.0,w1,w2,w3" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 + +# Image(s) to image training +# accelerate launch examples/z_image/model_training/train.py \ +# --dataset_base_path data/example_image_dataset \ +# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \ +# --data_file_keys "image,edit_image" \ +# --extra_inputs "edit_image" \ +# --max_pixels 1048576 \ +# --dataset_repeat 50 \ +# --model_id_with_origin_paths "Tongyi-MAI/Z-Image-Omni-Base:transformer/*.safetensors,Tongyi-MAI/Z-Image-Omni-Base:siglip/model.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \ +# --learning_rate 1e-4 \ +# --num_epochs 5 \ +# --remove_prefix_in_ckpt "pipe.dit." 
\ +# --output_path "./models/train/Z-Image-Omni-Base_lora_edit" \ +# --lora_base_model "dit" \ +# --lora_target_modules "to_q,to_k,to_v,to_out.0,w1,w2,w3" \ +# --lora_rank 32 \ +# --use_gradient_checkpointing \ +# --dataset_num_workers 8 diff --git a/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh b/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh new file mode 100644 index 0000000..9f2032f --- /dev/null +++ b/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh @@ -0,0 +1,17 @@ +accelerate launch examples/z_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_controlnet_upscale.csv \ + --data_file_keys "image,controlnet_image" \ + --max_pixels 1048576 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1:Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.safetensors,Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,to_out.0,w1,w2,w3" \ + --lora_rank 32 \ + --extra_inputs "controlnet_image" \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 diff --git a/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh b/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh new file mode 100644 index 0000000..22c46ce --- /dev/null +++ b/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh @@ -0,0 +1,17 @@ +accelerate launch examples/z_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_controlnet_canny.csv \ + --data_file_keys "image,controlnet_image" \ + --max_pixels 1048576 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1:Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors,Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,to_out.0,w1,w2,w3" \ + --lora_rank 32 \ + --extra_inputs "controlnet_image" \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 diff --git a/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh b/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh new file mode 100644 index 0000000..97de2a0 --- /dev/null +++ b/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh @@ -0,0 +1,17 @@ +accelerate launch examples/z_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_controlnet_canny.csv \ + --data_file_keys "image,controlnet_image" \ + --max_pixels 1048576 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1:Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors,Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,to_out.0,w1,w2,w3" \ + --lora_rank 32 \ + --extra_inputs "controlnet_image" \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 diff --git a/examples/z_image/model_training/validate_full/Z-Image-Omni-Base.py b/examples/z_image/model_training/validate_full/Z-Image-Omni-Base.py new file mode 100644 index 0000000..efa58db --- /dev/null +++ b/examples/z_image/model_training/validate_full/Z-Image-Omni-Base.py @@ -0,0 +1,33 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig +from diffsynth.core import load_state_dict +import torch + + +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) + +state_dict = load_state_dict("./models/train/Z-Image-Omni-Base_full/epoch-1.safetensors", torch_dtype=torch.bfloat16) +pipe.dit.load_state_dict(state_dict) +prompt = "a dog" +image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=40, cfg_scale=4) +image.save("image.jpg") + +# Edit +# state_dict = load_state_dict("./models/train/Z-Image-Omni-Base_full_edit/epoch-1.safetensors", torch_dtype=torch.bfloat16) +# pipe.dit.load_state_dict(state_dict) +# prompt = "Change the color of the dress in Figure 1 to the color shown in Figure 2." 
+# images = [ +# Image.open("data/example_image_dataset/edit/image1.jpg").resize((1024, 1024)), +# Image.open("data/example_image_dataset/edit/image_color.jpg").resize((1024, 1024)), +# ] +# image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=40, cfg_scale=4, edit_image=images) +# image.save("image.jpg") diff --git a/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py b/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py new file mode 100644 index 0000000..e3c4d8b --- /dev/null +++ b/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py @@ -0,0 +1,24 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput +from diffsynth import load_state_dict +from PIL import Image +import torch + + +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) +state_dict = load_state_dict("./models/train/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps_full/epoch-1.safetensors") +pipe.controlnet.load_state_dict(state_dict) + +controlnet_image = Image.open("data/example_image_dataset/upscale/image_1.jpg").resize((1024, 1024)) +prompt = "a dog" +image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=1)]) +image.save("image_tile.jpg") diff --git a/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py b/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py new file mode 100644 index 0000000..c24fc33 --- /dev/null +++ b/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py @@ -0,0 +1,24 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput +from diffsynth import load_state_dict +from PIL import Image +import torch + + +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) +state_dict = load_state_dict("./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps_full/epoch-1.safetensors") +pipe.controlnet.load_state_dict(state_dict) + +controlnet_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1024, 1024)) +prompt = "a dog" +image = pipe(prompt=prompt, 
seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)]) +image.save("image_control.jpg") diff --git a/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py b/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py new file mode 100644 index 0000000..c5712c6 --- /dev/null +++ b/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py @@ -0,0 +1,24 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput +from diffsynth import load_state_dict +from PIL import Image +import torch + + +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) +state_dict = load_state_dict("./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1_full/epoch-1.safetensors") +pipe.controlnet.load_state_dict(state_dict) + +controlnet_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1024, 1024)) +prompt = "a dog" +image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)]) +image.save("image_control.jpg") diff --git a/examples/z_image/model_training/validate_lora/Z-Image-Omni-Base.py b/examples/z_image/model_training/validate_lora/Z-Image-Omni-Base.py new file mode 100644 index 0000000..be144cf --- /dev/null +++ b/examples/z_image/model_training/validate_lora/Z-Image-Omni-Base.py @@ -0,0 +1,31 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig +from PIL import Image +import torch + + +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) + +pipe.load_lora(pipe.dit, "./models/train/Z-Image-Omni-Base_lora/epoch-4.safetensors") +prompt = "a dog" +image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=40, cfg_scale=4) +image.save("image.jpg") + +# Edit +# pipe.load_lora(pipe.dit, "./models/train/Z-Image-Omni-Base_lora_edit/epoch-4.safetensors") +# prompt = "Change the color of the dress in Figure 1 to the color shown in Figure 2." 
+# images = [ +# Image.open("data/example_image_dataset/edit/image1.jpg").resize((1024, 1024)), +# Image.open("data/example_image_dataset/edit/image_color.jpg").resize((1024, 1024)), +# ] +# image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=40, cfg_scale=4, edit_image=images) +# image.save("image.jpg") diff --git a/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py b/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py new file mode 100644 index 0000000..b70726a --- /dev/null +++ b/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py @@ -0,0 +1,23 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput +from diffsynth import load_state_dict +from PIL import Image +import torch + + +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.dit, "./models/train/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps_lora/epoch-4.safetensors") + +controlnet_image = Image.open("data/example_image_dataset/upscale/image_1.jpg").resize((1024, 1024)) +prompt = "a dog" +image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=1)]) +image.save("image_tile.jpg") diff --git a/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py b/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py new file mode 100644 index 0000000..c66e753 --- /dev/null +++ b/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py @@ -0,0 +1,23 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput +from diffsynth import load_state_dict +from PIL import Image +import torch + + +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.dit, "./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps_lora/epoch-4.safetensors") + +controlnet_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1024, 1024)) +prompt = "a dog" +image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, 
scale=0.7)]) +image.save("image_control.jpg") diff --git a/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py b/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py new file mode 100644 index 0000000..22d48e8 --- /dev/null +++ b/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py @@ -0,0 +1,23 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput +from diffsynth import load_state_dict +from PIL import Image +import torch + + +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.dit, "./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1_lora/epoch-4.safetensors") + +controlnet_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1024, 1024)) +prompt = "a dog" +image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)]) +image.save("image_control.jpg")
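A minimal follow-up sketch, assuming the `lora.safetensors` file written by the Z-Image-Omni-Base-i2L examples above can be reused directly: the generated LoRA is reloaded in a later session and passed as `positive_only_lora`, so the style images do not have to be re-encoded. The model list mirrors the Z-Image-Omni-Base example; the output filename is illustrative.

from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
from safetensors.torch import load_file
import torch

pipe = ZImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors"),
        ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors"),
        ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
        ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
)

# Reload the Image-to-LoRA weights saved by the i2L example and move them to the
# inference device/dtype, then apply them only to the positive CFG branch.
lora = {k: v.to(device="cuda", dtype=torch.bfloat16) for k, v in load_file("lora.safetensors").items()}
image = pipe(
    prompt="a cat",
    seed=0, cfg_scale=7, num_inference_steps=50,
    positive_only_lora=lora,
    sigma_shift=8,
)
image.save("image_reused_lora.jpg")  # illustrative output path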