From bac39b1cd281ad0c464a95c4eec7f74237bec710 Mon Sep 17 00:00:00 2001
From: Artiprocher
Date: Wed, 7 Jan 2026 15:56:53 +0800
Subject: [PATCH] support z-image controlnet

---
 diffsynth/configs/model_configs.py            |   6 +
 .../configs/vram_management_module_maps.py    |   4 +
 diffsynth/models/z_image_controlnet.py        | 154 ++++++++++++++++++
 diffsynth/models/z_image_dit.py               |  66 ++++++++
 diffsynth/pipelines/z_image.py                |  93 +++++++++--
 .../utils/controlnet/controlnet_input.py      |   1 +
 ...ge-Turbo-Fun-Controlnet-Tile-2.1-8steps.py |  27 +++
 ...e-Turbo-Fun-Controlnet-Union-2.1-8steps.py |  40 +++++
 .../Z-Image-Turbo-Fun-Controlnet-Union-2.1.py |  46 ++++++
 ...ge-Turbo-Fun-Controlnet-Tile-2.1-8steps.py |  37 +++++
 ...e-Turbo-Fun-Controlnet-Union-2.1-8steps.py |  50 ++++++
 .../Z-Image-Turbo-Fun-Controlnet-Union-2.1.py |  56 +++++++
 .../model_training/full/Z-Image-Omni-Base.sh  |  18 ++
 ...ge-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh |  15 ++
 ...e-Turbo-Fun-Controlnet-Union-2.1-8steps.sh |  15 ++
 .../Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh |  15 ++
 .../model_training/lora/Z-Image-Omni-Base.sh  |  20 +++
 ...ge-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh |  17 ++
 ...e-Turbo-Fun-Controlnet-Union-2.1-8steps.sh |  17 ++
 .../Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh |  17 ++
 .../validate_full/Z-Image-Omni-Base.py        |  12 ++
 ...ge-Turbo-Fun-Controlnet-Tile-2.1-8steps.py |  24 +++
 ...e-Turbo-Fun-Controlnet-Union-2.1-8steps.py |  24 +++
 .../Z-Image-Turbo-Fun-Controlnet-Union-2.1.py |  24 +++
 .../validate_lora/Z-Image-Omni-Base.py        |  12 ++
 ...ge-Turbo-Fun-Controlnet-Tile-2.1-8steps.py |  23 +++
 ...e-Turbo-Fun-Controlnet-Union-2.1-8steps.py |  23 +++
 .../Z-Image-Turbo-Fun-Controlnet-Union-2.1.py |  23 +++
 28 files changed, 868 insertions(+), 11 deletions(-)
 create mode 100644 diffsynth/models/z_image_controlnet.py
 create mode 100644 examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py
 create mode 100644 examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py
 create mode 100644 examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py
 create mode 100644 examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py
 create mode 100644 examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py
 create mode 100644 examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py
 create mode 100644 examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh
 create mode 100644 examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh
 create mode 100644 examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh
 create mode 100644 examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh
 create mode 100644 examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh
 create mode 100644 examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh
 create mode 100644 examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py
 create mode 100644 examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py
 create mode 100644 examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py
 create mode 100644 examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py
 create mode 100644
examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py create mode 100644 examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index e6c7741..0b6c61a 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -540,6 +540,12 @@ z_image_series = [ "model_name": "siglip_vision_model_428m", "model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M", }, + { + # ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors") + "model_hash": "1677708d40029ab380a95f6c731a57d7", + "model_name": "z_image_controlnet", + "model_class": "diffsynth.models.z_image_controlnet.ZImageControlNet", + } ] MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series diff --git a/diffsynth/configs/vram_management_module_maps.py b/diffsynth/configs/vram_management_module_maps.py index 5f1b595..6a6b6c4 100644 --- a/diffsynth/configs/vram_management_module_maps.py +++ b/diffsynth/configs/vram_management_module_maps.py @@ -195,4 +195,8 @@ VRAM_MANAGEMENT_MODULE_MAPS = { "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", "diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", }, + "diffsynth.models.z_image_controlnet.ZImageControlNet": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, } diff --git a/diffsynth/models/z_image_controlnet.py b/diffsynth/models/z_image_controlnet.py new file mode 100644 index 0000000..5105534 --- /dev/null +++ b/diffsynth/models/z_image_controlnet.py @@ -0,0 +1,154 @@ +from .z_image_dit import ZImageTransformerBlock +from ..core.gradient import gradient_checkpoint_forward +from torch.nn.utils.rnn import pad_sequence +import torch +from torch import nn + + +class ZImageControlTransformerBlock(ZImageTransformerBlock): + def __init__( + self, + layer_id: int = 1000, + dim: int = 3840, + n_heads: int = 30, + n_kv_heads: int = 30, + norm_eps: float = 1e-5, + qk_norm: bool = True, + modulation = True, + block_id = 0 + ): + super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation) + self.block_id = block_id + if block_id == 0: + self.before_proj = nn.Linear(self.dim, self.dim) + self.after_proj = nn.Linear(self.dim, self.dim) + + def forward(self, c, x, **kwargs): + if self.block_id == 0: + c = self.before_proj(c) + x + all_c = [] + else: + all_c = list(torch.unbind(c)) + c = all_c.pop(-1) + + c = super().forward(c, **kwargs) + c_skip = self.after_proj(c) + all_c += [c_skip, c] + c = torch.stack(all_c) + return c + + +class ZImageControlNet(torch.nn.Module): + def __init__( + self, + control_layers_places=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28), + control_in_dim=33, + dim=3840, + n_refiner_layers=2, + ): + super().__init__() + self.control_layers = nn.ModuleList([ZImageControlTransformerBlock(layer_id=i, block_id=i) for i in control_layers_places]) + self.control_all_x_embedder = nn.ModuleDict({"2-1": nn.Linear(1 * 2 * 2 * control_in_dim, dim, bias=True)}) + self.control_noise_refiner = nn.ModuleList([ZImageControlTransformerBlock(block_id=layer_id) for layer_id in range(n_refiner_layers)]) + self.control_layers_mapping = {0: 0, 2: 1, 4: 2, 6: 3, 8: 4, 10: 5, 12: 6, 14: 
7, 16: 8, 18: 9, 20: 10, 22: 11, 24: 12, 26: 13, 28: 14} + + def forward_layers( + self, + x, + cap_feats, + control_context, + control_context_item_seqlens, + kwargs, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + ): + bsz = len(control_context) + # unified + cap_item_seqlens = [len(_) for _ in cap_feats] + control_context_unified = [] + for i in range(bsz): + control_context_len = control_context_item_seqlens[i] + cap_len = cap_item_seqlens[i] + control_context_unified.append(torch.cat([control_context[i][:control_context_len], cap_feats[i][:cap_len]])) + c = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0) + + # arguments + new_kwargs = dict(x=x) + new_kwargs.update(kwargs) + + for layer in self.control_layers: + c = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + c=c, **new_kwargs + ) + + hints = torch.unbind(c)[:-1] + return hints + + def forward_refiner( + self, + dit, + x, + cap_feats, + control_context, + kwargs, + t=None, + patch_size=2, + f_patch_size=1, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + ): + # embeddings + bsz = len(control_context) + device = control_context[0].device + ( + control_context, + control_context_size, + control_context_pos_ids, + control_context_inner_pad_mask, + ) = dit.patchify_controlnet(control_context, patch_size, f_patch_size, cap_feats[0].size(0)) + + # control_context embed & refine + control_context_item_seqlens = [len(_) for _ in control_context] + assert all(_ % 2 == 0 for _ in control_context_item_seqlens) + control_context_max_item_seqlen = max(control_context_item_seqlens) + + control_context = torch.cat(control_context, dim=0) + control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context) + + # Match t_embedder output dtype to control_context for layerwise casting compatibility + adaln_input = t.type_as(control_context) + control_context[torch.cat(control_context_inner_pad_mask)] = dit.x_pad_token.to(dtype=control_context.dtype, device=control_context.device) + control_context = list(control_context.split(control_context_item_seqlens, dim=0)) + control_context_freqs_cis = list(dit.rope_embedder(torch.cat(control_context_pos_ids, dim=0)).split(control_context_item_seqlens, dim=0)) + + control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0) + control_context_freqs_cis = pad_sequence(control_context_freqs_cis, batch_first=True, padding_value=0.0) + control_context_attn_mask = torch.zeros((bsz, control_context_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(control_context_item_seqlens): + control_context_attn_mask[i, :seq_len] = 1 + c = control_context + + # arguments + new_kwargs = dict( + x=x, + attn_mask=control_context_attn_mask, + freqs_cis=control_context_freqs_cis, + adaln_input=adaln_input, + ) + new_kwargs.update(kwargs) + + for layer in self.control_noise_refiner: + c = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + c=c, **new_kwargs + ) + + hints = torch.unbind(c)[:-1] + control_context = torch.unbind(c)[-1] + + return hints, control_context, control_context_item_seqlens \ No newline at end of file diff --git a/diffsynth/models/z_image_dit.py b/diffsynth/models/z_image_dit.py index d141e02..9744ddb 
100644 --- a/diffsynth/models/z_image_dit.py +++ b/diffsynth/models/z_image_dit.py @@ -609,6 +609,72 @@ class ZImageDiT(nn.Module): # all_img_pad_mask, # all_cap_pad_mask, # ) + + def patchify_controlnet( + self, + all_image: List[torch.Tensor], + patch_size: int = 2, + f_patch_size: int = 1, + cap_padding_len: int = None, + ): + pH = pW = patch_size + pF = f_patch_size + device = all_image[0].device + + all_image_out = [] + all_image_size = [] + all_image_pos_ids = [] + all_image_pad_mask = [] + + for i, image in enumerate(all_image): + ### Process Image + C, F, H, W = image.size() + all_image_size.append((F, H, W)) + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + + image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + # "c f pf h ph w pw -> (f h w) (pf ph pw c)" + image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + + image_ori_len = len(image) + image_padding_len = (-image_ori_len) % SEQ_MULTI_OF + + image_ori_pos_ids = self.create_coordinate_grid( + size=(F_tokens, H_tokens, W_tokens), + start=(cap_padding_len + 1, 0, 0), + device=device, + ).flatten(0, 2) + image_padding_pos_ids = ( + self.create_coordinate_grid( + size=(1, 1, 1), + start=(0, 0, 0), + device=device, + ) + .flatten(0, 2) + .repeat(image_padding_len, 1) + ) + image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0) + all_image_pos_ids.append(image_padded_pos_ids) + # pad mask + all_image_pad_mask.append( + torch.cat( + [ + torch.zeros((image_ori_len,), dtype=torch.bool, device=device), + torch.ones((image_padding_len,), dtype=torch.bool, device=device), + ], + dim=0, + ) + ) + # padded feature + image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) + all_image_out.append(image_padded_feat) + + return ( + all_image_out, + all_image_size, + all_image_pos_ids, + all_image_pad_mask, + ) def _prepare_sequence( self, diff --git a/diffsynth/pipelines/z_image.py b/diffsynth/pipelines/z_image.py index d119cbf..23d94ec 100644 --- a/diffsynth/pipelines/z_image.py +++ b/diffsynth/pipelines/z_image.py @@ -16,6 +16,7 @@ from ..models.z_image_text_encoder import ZImageTextEncoder from ..models.z_image_dit import ZImageDiT from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder from ..models.siglip2_image_encoder import Siglip2ImageEncoder428M +from ..models.z_image_controlnet import ZImageControlNet class ZImagePipeline(BasePipeline): @@ -31,8 +32,9 @@ class ZImagePipeline(BasePipeline): self.vae_encoder: FluxVAEEncoder = None self.vae_decoder: FluxVAEDecoder = None self.image_encoder: Siglip2ImageEncoder428M = None + self.controlnet: ZImageControlNet = None self.tokenizer: AutoTokenizer = None - self.in_iteration_models = ("dit",) + self.in_iteration_models = ("dit", "controlnet") self.units = [ ZImageUnit_ShapeChecker(), ZImageUnit_PromptEmbedder(), @@ -41,6 +43,7 @@ class ZImagePipeline(BasePipeline): ZImageUnit_EditImageAutoResize(), ZImageUnit_EditImageEmbedderVAE(), ZImageUnit_EditImageEmbedderSiglip(), + ZImageUnit_PAIControlNet(), ] self.model_fn = model_fn_z_image @@ -63,6 +66,7 @@ class ZImagePipeline(BasePipeline): pipe.vae_encoder = model_pool.fetch_model("flux_vae_encoder") pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder") pipe.image_encoder = model_pool.fetch_model("siglip_vision_model_428m") + pipe.controlnet = model_pool.fetch_model("z_image_controlnet") if tokenizer_config is not None: tokenizer_config.download_if_necessary() pipe.tokenizer = 
AutoTokenizer.from_pretrained(tokenizer_config.path) @@ -94,6 +98,8 @@ class ZImagePipeline(BasePipeline): # Steps num_inference_steps: int = 8, sigma_shift: float = None, + # ControlNet + controlnet_inputs: List[ControlNetInput] = None, # Progress bar progress_bar_cmd = tqdm, ): @@ -114,6 +120,7 @@ class ZImagePipeline(BasePipeline): "seed": seed, "rand_device": rand_device, "num_inference_steps": num_inference_steps, "edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize, + "controlnet_inputs": controlnet_inputs, } for unit in self.units: inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) @@ -331,7 +338,9 @@ class ZImageUnit_EditImageAutoResize(PipelineUnit): if edit_image_auto_resize is None or not edit_image_auto_resize: return {} operator = ImageCropAndResize(max_pixels=1024*1024, height_division_factor=16, width_division_factor=16) - edit_image = operator(edit_image) + if not isinstance(edit_image, list): + edit_image = [edit_image] + edit_image = [operator(i) for i in edit_image] return {"edit_image": edit_image} @@ -376,8 +385,49 @@ class ZImageUnit_EditImageEmbedderVAE(PipelineUnit): return {"image_latents": image_latents} +class ZImageUnit_PAIControlNet(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("controlnet_inputs", "height", "width"), + output_params=("control_context", "control_scale"), + onload_model_names=("vae_encoder",) + ) + + def process(self, pipe: ZImagePipeline, controlnet_inputs: List[ControlNetInput], height, width): + if controlnet_inputs is None: + return {} + if len(controlnet_inputs) != 1: + print("Z-Image ControlNet doesn't support multi-ControlNet. Only one image will be used.") + controlnet_input = controlnet_inputs[0] + pipe.load_models_to_device(self.onload_model_names) + + control_image = controlnet_input.image + if control_image is not None: + control_image = pipe.preprocess_image(control_image) + control_latents = pipe.vae_encoder(control_image) + else: + control_latents = torch.ones((1, 16, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device) * -1 + + inpaint_mask = controlnet_input.inpaint_mask + if inpaint_mask is not None: + inpaint_mask = pipe.preprocess_image(inpaint_mask, min_value=0, max_value=1) + inpaint_image = controlnet_input.inpaint_image + inpaint_image = pipe.preprocess_image(inpaint_image) + inpaint_image = inpaint_image * (inpaint_mask < 0.5) + inpaint_mask = torch.nn.functional.interpolate(1 - inpaint_mask, (height // 8, width // 8), mode='nearest')[:, :1] + else: + inpaint_mask = torch.zeros((1, 1, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device) + inpaint_image = torch.zeros((1, 3, height, width), dtype=pipe.torch_dtype, device=pipe.device) + inpaint_latent = pipe.vae_encoder(inpaint_image) + + control_context = torch.concat([control_latents, inpaint_mask, inpaint_latent], dim=1) + control_context = rearrange(control_context, "B C H W -> B C 1 H W") + return {"control_context": control_context, "control_scale": controlnet_input.scale} + + def model_fn_z_image( dit: ZImageDiT, + controlnet: ZImageControlNet = None, latents=None, timestep=None, prompt_embeds=None, @@ -393,13 +443,14 @@ def model_fn_z_image( if dit.siglip_embedder is None: return model_fn_z_image_turbo( dit, - latents, - timestep, - prompt_embeds, - image_embeds, - image_latents, - use_gradient_checkpointing, - use_gradient_checkpointing_offload, + controlnet=controlnet, + latents=latents, + timestep=timestep, + 
prompt_embeds=prompt_embeds, + image_embeds=image_embeds, + image_latents=image_latents, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, **kwargs, ) latents = [rearrange(latents, "B C H W -> C B H W")] @@ -431,11 +482,14 @@ def model_fn_z_image( def model_fn_z_image_turbo( dit: ZImageDiT, + controlnet: ZImageControlNet = None, latents=None, timestep=None, prompt_embeds=None, image_embeds=None, image_latents=None, + control_context=None, + control_scale=None, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False, **kwargs, @@ -460,11 +514,17 @@ def model_fn_z_image_turbo( # Noise refine x = dit.all_x_embedder["2-1"](x) + x[torch.cat(patch_metadata.get("x_pad_mask"))] = dit.x_pad_token.to(dtype=x.dtype, device=x.device) x_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("x_pos_ids"), dim=0)) x = rearrange(x, "L C -> 1 L C") x_freqs_cis = rearrange(x_freqs_cis, "L C -> 1 L C") + + if control_context is not None: + kwargs = dict(attn_mask=None, freqs_cis=x_freqs_cis, adaln_input=t_noisy) + refiner_hints, control_context, control_context_item_seqlens = controlnet.forward_refiner( + dit, x, [cap_feats], control_context, kwargs, t=t_noisy, patch_size=2, f_patch_size=1) - for layer in dit.noise_refiner: + for layer_id, layer in enumerate(dit.noise_refiner): x = gradient_checkpoint_forward( layer, use_gradient_checkpointing=use_gradient_checkpointing, @@ -474,6 +534,8 @@ def model_fn_z_image_turbo( freqs_cis=x_freqs_cis, adaln_input=t_noisy, ) + if control_context is not None: + x = x + refiner_hints[layer_id] * control_scale # Prompt refine cap_feats = dit.cap_embedder(cap_feats) @@ -495,7 +557,13 @@ def model_fn_z_image_turbo( # Unified unified = torch.cat([x, cap_feats], dim=1) unified_freqs_cis = torch.cat([x_freqs_cis, cap_freqs_cis], dim=1) - for layer in dit.layers: + + if control_context is not None: + kwargs = dict(attn_mask=None, freqs_cis=unified_freqs_cis, adaln_input=t_noisy) + hints = controlnet.forward_layers( + unified, cap_feats, control_context, control_context_item_seqlens, kwargs) + + for layer_id, layer in enumerate(dit.layers): unified = gradient_checkpoint_forward( layer, use_gradient_checkpointing=use_gradient_checkpointing, @@ -505,6 +573,9 @@ def model_fn_z_image_turbo( freqs_cis=unified_freqs_cis, adaln_input=t_noisy, ) + if control_context is not None: + if layer_id in controlnet.control_layers_mapping: + unified = unified + hints[controlnet.control_layers_mapping[layer_id]] * control_scale # Output unified = dit.all_final_layer["2-1"](unified, t_noisy) diff --git a/diffsynth/utils/controlnet/controlnet_input.py b/diffsynth/utils/controlnet/controlnet_input.py index 1a2949b..a79064b 100644 --- a/diffsynth/utils/controlnet/controlnet_input.py +++ b/diffsynth/utils/controlnet/controlnet_input.py @@ -9,5 +9,6 @@ class ControlNetInput: start: float = 1.0 end: float = 0.0 image: Image.Image = None + inpaint_image: Image.Image = None inpaint_mask: Image.Image = None processor_id: str = None diff --git a/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py b/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py new file mode 100644 index 0000000..21b387e --- /dev/null +++ b/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py @@ -0,0 +1,27 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput +from modelscope import dataset_snapshot_download +from PIL 
import Image +import torch + + +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern="data/examples/upscale/low_res.png" +) +controlnet_image = Image.open("data/examples/upscale/low_res.png").resize((1024, 1024)) +prompt = "这是一张充满都市气息的户外人物肖像照片。画面中是一位年轻男性,他展现出时尚而自信的形象。人物拥有精心打理的短发发型,两侧修剪得较短,顶部保留一定长度,呈现出流行的Undercut造型。他佩戴着一副时尚的浅色墨镜或透明镜框眼镜,为整体造型增添了潮流感。脸上洋溢着温和友善的笑容,神情放松自然,给人以阳光开朗的印象。他身穿一件经典的牛仔外套,这件单品永不过时,展现出休闲又有型的穿衣风格。牛仔外套的蓝色调与整体氛围十分协调,领口处隐约可见内搭的衣物。照片的背景是典型的城市街景,可以看到模糊的建筑物、街道和行人,营造出繁华都市的氛围。背景经过了恰当的虚化处理,使人物主体更加突出。光线明亮而柔和,可能是白天的自然光,为照片带来清新通透的视觉效果。整张照片构图专业,景深控制得当,完美捕捉了一个现代都市年轻人充满活力和自信的瞬间,展现出积极向上的生活态度。" +image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)]) +image.save("image_tile.jpg") diff --git a/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py b/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py new file mode 100644 index 0000000..54adbea --- /dev/null +++ b/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py @@ -0,0 +1,40 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput +from modelscope import dataset_snapshot_download +from PIL import Image +import torch + + +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) + +# Control +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_image_dataset", + local_dir="./data/example_image_dataset", + allow_file_pattern="depth/image_1.jpg" +) +controlnet_image = Image.open("data/example_image_dataset/depth/image_1.jpg").resize((1024, 1024)) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)]) +image.save("image_control.jpg") + +# Inpaint +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_image_dataset", + local_dir="./data/example_image_dataset", + allow_file_pattern="inpaint/*.jpg" +) +inpaint_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1024, 1024)) +inpaint_mask = 
Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1024, 1024)) +prompt = "一只戴着墨镜的猫" +image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(inpaint_image=inpaint_image, inpaint_mask=inpaint_mask, scale=0.7)]) +image.save("image_inpaint.jpg") diff --git a/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py b/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py new file mode 100644 index 0000000..2f872d0 --- /dev/null +++ b/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py @@ -0,0 +1,46 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput +from modelscope import dataset_snapshot_download +from PIL import Image +import torch + + +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) + +# Control +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_image_dataset", + local_dir="./data/example_image_dataset", + allow_file_pattern="depth/image_1.jpg" +) +controlnet_image = Image.open("data/example_image_dataset/depth/image_1.jpg").resize((1024, 1024)) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe( + prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)], + num_inference_steps=30, +) +image.save("image_control.jpg") + +# Inpaint +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_image_dataset", + local_dir="./data/example_image_dataset", + allow_file_pattern="inpaint/*.jpg" +) +inpaint_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1024, 1024)) +inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1024, 1024)) +prompt = "一只戴着墨镜的猫" +image = pipe( + prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(inpaint_image=inpaint_image, inpaint_mask=inpaint_mask, scale=0.7)], + num_inference_steps=30, +) +image.save("image_inpaint.jpg") diff --git a/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py b/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py new file mode 100644 index 0000000..cd4276f --- /dev/null +++ b/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py @@ -0,0 +1,37 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput +from modelscope import dataset_snapshot_download +from PIL import Image +import torch + + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = 
ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern="data/examples/upscale/low_res.png" +) +controlnet_image = Image.open("data/examples/upscale/low_res.png").resize((1024, 1024)) +prompt = "这是一张充满都市气息的户外人物肖像照片。画面中是一位年轻男性,他展现出时尚而自信的形象。人物拥有精心打理的短发发型,两侧修剪得较短,顶部保留一定长度,呈现出流行的Undercut造型。他佩戴着一副时尚的浅色墨镜或透明镜框眼镜,为整体造型增添了潮流感。脸上洋溢着温和友善的笑容,神情放松自然,给人以阳光开朗的印象。他身穿一件经典的牛仔外套,这件单品永不过时,展现出休闲又有型的穿衣风格。牛仔外套的蓝色调与整体氛围十分协调,领口处隐约可见内搭的衣物。照片的背景是典型的城市街景,可以看到模糊的建筑物、街道和行人,营造出繁华都市的氛围。背景经过了恰当的虚化处理,使人物主体更加突出。光线明亮而柔和,可能是白天的自然光,为照片带来清新通透的视觉效果。整张照片构图专业,景深控制得当,完美捕捉了一个现代都市年轻人充满活力和自信的瞬间,展现出积极向上的生活态度。" +image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)]) +image.save("image_tile.jpg") diff --git a/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py b/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py new file mode 100644 index 0000000..f325508 --- /dev/null +++ b/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py @@ -0,0 +1,50 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput +from modelscope import dataset_snapshot_download +from PIL import Image +import torch + + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) + +# Control +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_image_dataset", + local_dir="./data/example_image_dataset", + allow_file_pattern="depth/image_1.jpg" +) +controlnet_image = Image.open("data/example_image_dataset/depth/image_1.jpg").resize((1024, 1024)) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt=prompt, seed=0, height=1024, width=1024, 
controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)]) +image.save("image_control.jpg") + +# Inpaint +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_image_dataset", + local_dir="./data/example_image_dataset", + allow_file_pattern="inpaint/*.jpg" +) +inpaint_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1024, 1024)) +inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1024, 1024)) +prompt = "一只戴着墨镜的猫" +image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(inpaint_image=inpaint_image, inpaint_mask=inpaint_mask, scale=0.7)]) +image.save("image_inpaint.jpg") diff --git a/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py b/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py new file mode 100644 index 0000000..6fe170f --- /dev/null +++ b/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py @@ -0,0 +1,56 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput +from modelscope import dataset_snapshot_download +from PIL import Image +import torch + + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) + +# Control +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_image_dataset", + local_dir="./data/example_image_dataset", + allow_file_pattern="depth/image_1.jpg" +) +controlnet_image = Image.open("data/example_image_dataset/depth/image_1.jpg").resize((1024, 1024)) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe( + prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)], + num_inference_steps=30, +) +image.save("image_control.jpg") + +# Inpaint +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_image_dataset", + local_dir="./data/example_image_dataset", + allow_file_pattern="inpaint/*.jpg" +) +inpaint_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1024, 1024)) +inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1024, 1024)) +prompt = "一只戴着墨镜的猫" +image = pipe( + prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(inpaint_image=inpaint_image, inpaint_mask=inpaint_mask, scale=0.7)], + num_inference_steps=30, +) +image.save("image_inpaint.jpg") diff --git a/examples/z_image/model_training/full/Z-Image-Omni-Base.sh 
b/examples/z_image/model_training/full/Z-Image-Omni-Base.sh index 4f2d1da..cc74b2a 100644 --- a/examples/z_image/model_training/full/Z-Image-Omni-Base.sh +++ b/examples/z_image/model_training/full/Z-Image-Omni-Base.sh @@ -1,4 +1,5 @@ # This example is tested on 8*A100 +# Text to image training accelerate launch --config_file examples/z_image/model_training/full/accelerate_config.yaml examples/z_image/model_training/train.py \ --dataset_base_path data/example_image_dataset \ --dataset_metadata_path data/example_image_dataset/metadata.csv \ @@ -12,3 +13,20 @@ accelerate launch --config_file examples/z_image/model_training/full/accelerate_ --trainable_models "dit" \ --use_gradient_checkpointing \ --dataset_num_workers 8 + +# Image(s) to image training +# accelerate launch --config_file examples/z_image/model_training/full/accelerate_config.yaml examples/z_image/model_training/train.py \ +# --dataset_base_path data/example_image_dataset \ +# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \ +# --data_file_keys "image,edit_image" \ +# --extra_inputs "edit_image" \ +# --max_pixels 1048576 \ +# --dataset_repeat 400 \ +# --model_id_with_origin_paths "Tongyi-MAI/Z-Image-Omni-Base:transformer/*.safetensors,Tongyi-MAI/Z-Image-Omni-Base:siglip/model.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \ +# --learning_rate 1e-5 \ +# --num_epochs 2 \ +# --remove_prefix_in_ckpt "pipe.dit." \ +# --output_path "./models/train/Z-Image-Omni-Base_full_edit" \ +# --trainable_models "dit" \ +# --use_gradient_checkpointing \ +# --dataset_num_workers 8 diff --git a/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh b/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh new file mode 100644 index 0000000..1f0f928 --- /dev/null +++ b/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh @@ -0,0 +1,15 @@ +accelerate launch examples/z_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_controlnet_upscale.csv \ + --data_file_keys "image,controlnet_image" \ + --max_pixels 1048576 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1:Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.safetensors,Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.controlnet." 
\ + --output_path "./models/train/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps_full" \ + --trainable_models "controlnet" \ + --extra_inputs "controlnet_image" \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 diff --git a/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh b/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh new file mode 100644 index 0000000..69d0958 --- /dev/null +++ b/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh @@ -0,0 +1,15 @@ +accelerate launch examples/z_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_controlnet_canny.csv \ + --data_file_keys "image,controlnet_image" \ + --max_pixels 1048576 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1:Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors,Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.controlnet." \ + --output_path "./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps_full" \ + --trainable_models "controlnet" \ + --extra_inputs "controlnet_image" \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 diff --git a/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh b/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh new file mode 100644 index 0000000..c56e735 --- /dev/null +++ b/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh @@ -0,0 +1,15 @@ +accelerate launch examples/z_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_controlnet_canny.csv \ + --data_file_keys "image,controlnet_image" \ + --max_pixels 1048576 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1:Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors,Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.controlnet." 
\ + --output_path "./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1_full" \ + --trainable_models "controlnet" \ + --extra_inputs "controlnet_image" \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 diff --git a/examples/z_image/model_training/lora/Z-Image-Omni-Base.sh b/examples/z_image/model_training/lora/Z-Image-Omni-Base.sh index 3f0b158..ef4d524 100644 --- a/examples/z_image/model_training/lora/Z-Image-Omni-Base.sh +++ b/examples/z_image/model_training/lora/Z-Image-Omni-Base.sh @@ -1,3 +1,4 @@ +# Text to image training accelerate launch examples/z_image/model_training/train.py \ --dataset_base_path data/example_image_dataset \ --dataset_metadata_path data/example_image_dataset/metadata.csv \ @@ -13,3 +14,22 @@ accelerate launch examples/z_image/model_training/train.py \ --lora_rank 32 \ --use_gradient_checkpointing \ --dataset_num_workers 8 + +# Image(s) to image training +# accelerate launch examples/z_image/model_training/train.py \ +# --dataset_base_path data/example_image_dataset \ +# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \ +# --data_file_keys "image,edit_image" \ +# --extra_inputs "edit_image" \ +# --max_pixels 1048576 \ +# --dataset_repeat 50 \ +# --model_id_with_origin_paths "Tongyi-MAI/Z-Image-Omni-Base:transformer/*.safetensors,Tongyi-MAI/Z-Image-Omni-Base:siglip/model.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \ +# --learning_rate 1e-4 \ +# --num_epochs 5 \ +# --remove_prefix_in_ckpt "pipe.dit." \ +# --output_path "./models/train/Z-Image-Omni-Base_lora_edit" \ +# --lora_base_model "dit" \ +# --lora_target_modules "to_q,to_k,to_v,to_out.0,w1,w2,w3" \ +# --lora_rank 32 \ +# --use_gradient_checkpointing \ +# --dataset_num_workers 8 diff --git a/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh b/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh new file mode 100644 index 0000000..9f2032f --- /dev/null +++ b/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh @@ -0,0 +1,17 @@ +accelerate launch examples/z_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_controlnet_upscale.csv \ + --data_file_keys "image,controlnet_image" \ + --max_pixels 1048576 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1:Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.safetensors,Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,to_out.0,w1,w2,w3" \ + --lora_rank 32 \ + --extra_inputs "controlnet_image" \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 diff --git a/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh b/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh new file mode 100644 index 0000000..22c46ce --- /dev/null +++ b/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh @@ -0,0 +1,17 @@ +accelerate launch examples/z_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_controlnet_canny.csv \ + --data_file_keys "image,controlnet_image" \ + --max_pixels 1048576 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1:Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors,Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,to_out.0,w1,w2,w3" \ + --lora_rank 32 \ + --extra_inputs "controlnet_image" \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 diff --git a/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh b/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh new file mode 100644 index 0000000..97de2a0 --- /dev/null +++ b/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh @@ -0,0 +1,17 @@ +accelerate launch examples/z_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_controlnet_canny.csv \ + --data_file_keys "image,controlnet_image" \ + --max_pixels 1048576 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1:Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors,Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,to_out.0,w1,w2,w3" \ + --lora_rank 32 \ + --extra_inputs "controlnet_image" \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 diff --git a/examples/z_image/model_training/validate_full/Z-Image-Omni-Base.py b/examples/z_image/model_training/validate_full/Z-Image-Omni-Base.py index b51095c..efa58db 100644 --- a/examples/z_image/model_training/validate_full/Z-Image-Omni-Base.py +++ b/examples/z_image/model_training/validate_full/Z-Image-Omni-Base.py @@ -14,8 +14,20 @@ pipe = ZImagePipeline.from_pretrained( ], tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), ) + state_dict = load_state_dict("./models/train/Z-Image-Omni-Base_full/epoch-1.safetensors", torch_dtype=torch.bfloat16) pipe.dit.load_state_dict(state_dict) prompt = "a dog" image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=40, cfg_scale=4) image.save("image.jpg") + +# Edit +# state_dict = load_state_dict("./models/train/Z-Image-Omni-Base_full_edit/epoch-1.safetensors", torch_dtype=torch.bfloat16) +# pipe.dit.load_state_dict(state_dict) +# prompt = "Change the color of the dress in Figure 1 to the color shown in Figure 2." +# images = [ +# Image.open("data/example_image_dataset/edit/image1.jpg").resize((1024, 1024)), +# Image.open("data/example_image_dataset/edit/image_color.jpg").resize((1024, 1024)), +# ] +# image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=40, cfg_scale=4, edit_image=images) +# image.save("image.jpg") diff --git a/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py b/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py new file mode 100644 index 0000000..e3c4d8b --- /dev/null +++ b/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py @@ -0,0 +1,24 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput +from diffsynth import load_state_dict +from PIL import Image +import torch + + +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) +state_dict = load_state_dict("./models/train/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps_full/epoch-1.safetensors") +pipe.controlnet.load_state_dict(state_dict) + +controlnet_image = Image.open("data/example_image_dataset/upscale/image_1.jpg").resize((1024, 1024)) +prompt = "a dog" +image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=1)]) +image.save("image_tile.jpg") diff --git a/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py b/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py new file 
mode 100644 index 0000000..c24fc33 --- /dev/null +++ b/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py @@ -0,0 +1,24 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput +from diffsynth import load_state_dict +from PIL import Image +import torch + + +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) +state_dict = load_state_dict("./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps_full/epoch-1.safetensors") +pipe.controlnet.load_state_dict(state_dict) + +controlnet_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1024, 1024)) +prompt = "a dog" +image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)]) +image.save("image_control.jpg") diff --git a/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py b/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py new file mode 100644 index 0000000..c5712c6 --- /dev/null +++ b/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py @@ -0,0 +1,24 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput +from diffsynth import load_state_dict +from PIL import Image +import torch + + +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) +state_dict = load_state_dict("./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1_full/epoch-1.safetensors") +pipe.controlnet.load_state_dict(state_dict) + +controlnet_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1024, 1024)) +prompt = "a dog" +image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)]) +image.save("image_control.jpg") diff --git a/examples/z_image/model_training/validate_lora/Z-Image-Omni-Base.py b/examples/z_image/model_training/validate_lora/Z-Image-Omni-Base.py index 77ee72f..be144cf 100644 --- a/examples/z_image/model_training/validate_lora/Z-Image-Omni-Base.py +++ b/examples/z_image/model_training/validate_lora/Z-Image-Omni-Base.py @@ -1,4 +1,5 @@ from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig +from PIL import Image 
import torch @@ -13,7 +14,18 @@ pipe = ZImagePipeline.from_pretrained( ], tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), ) + pipe.load_lora(pipe.dit, "./models/train/Z-Image-Omni-Base_lora/epoch-4.safetensors") prompt = "a dog" image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=40, cfg_scale=4) image.save("image.jpg") + +# Edit +# pipe.load_lora(pipe.dit, "./models/train/Z-Image-Omni-Base_lora_edit/epoch-4.safetensors") +# prompt = "Change the color of the dress in Figure 1 to the color shown in Figure 2." +# images = [ +# Image.open("data/example_image_dataset/edit/image1.jpg").resize((1024, 1024)), +# Image.open("data/example_image_dataset/edit/image_color.jpg").resize((1024, 1024)), +# ] +# image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=40, cfg_scale=4, edit_image=images) +# image.save("image.jpg") diff --git a/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py b/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py new file mode 100644 index 0000000..b70726a --- /dev/null +++ b/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py @@ -0,0 +1,23 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput +from diffsynth import load_state_dict +from PIL import Image +import torch + + +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.dit, "./models/train/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps_lora/epoch-4.safetensors") + +controlnet_image = Image.open("data/example_image_dataset/upscale/image_1.jpg").resize((1024, 1024)) +prompt = "a dog" +image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=1)]) +image.save("image_tile.jpg") diff --git a/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py b/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py new file mode 100644 index 0000000..c66e753 --- /dev/null +++ b/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py @@ -0,0 +1,23 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput +from diffsynth import load_state_dict +from PIL import Image +import torch + + +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"), 
+ ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.dit, "./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps_lora/epoch-4.safetensors") + +controlnet_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1024, 1024)) +prompt = "a dog" +image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)]) +image.save("image_control.jpg") diff --git a/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py b/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py new file mode 100644 index 0000000..22d48e8 --- /dev/null +++ b/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py @@ -0,0 +1,23 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput +from diffsynth import load_state_dict +from PIL import Image +import torch + + +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.dit, "./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1_lora/epoch-4.safetensors") + +controlnet_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1024, 1024)) +prompt = "a dog" +image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)]) +image.save("image_control.jpg")
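The control path this patch adds is residual hint injection: ZImageControlNet runs its own transformer blocks over the packed control context (control-image latents, inpaint mask, inpaint latents), produces one "hint" per entry in control_layers_mapping, and model_fn_z_image_turbo adds each hint to the DiT hidden states at the mapped layer, scaled by ControlNetInput.scale. The sketch below is a minimal schematic of that injection only, not the real DiffSynth classes; the Linear blocks, dim=64, 4/2 layer counts, and 16 tokens are illustrative placeholders for ZImageDiT.layers and ZImageControlNet.control_layers.

import torch
from torch import nn

dim = 64
dit_layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(4)])      # stand-in for dit.layers
control_layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(2)])  # stand-in for controlnet.control_layers
control_layers_mapping = {0: 0, 2: 1}  # DiT layer id -> hint index, mirroring ZImageControlNet

def forward_with_hints(x, control_context, control_scale=0.7):
    # ControlNet pass: each control block emits a skip ("hint"); the real model uses after_proj(c).
    hints, c = [], control_context
    for layer in control_layers:
        c = layer(c)
        hints.append(c)
    # DiT pass: hints are added only at the mapped layers, scaled by ControlNetInput.scale.
    for layer_id, layer in enumerate(dit_layers):
        x = layer(x)
        if layer_id in control_layers_mapping:
            x = x + hints[control_layers_mapping[layer_id]] * control_scale
    return x

x = torch.randn(1, 16, dim)                # unified image/text tokens
control_context = torch.randn(1, 16, dim)  # embedded control context
print(forward_with_hints(x, control_context).shape)  # torch.Size([1, 16, 64])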