From 32449a6aa0166754d2bdf44dea2e6c1bf4ec5899 Mon Sep 17 00:00:00 2001
From: Artiprocher
Date: Mon, 5 Jan 2026 20:04:00 +0800
Subject: [PATCH] Support Z-Image-Omni-Base training

---
 diffsynth/core/loader/config.py               |  1 +
 diffsynth/models/siglip2_image_encoder.py     |  3 --
 diffsynth/models/z_image_dit.py               |  2 +-
 .../model_inference/Z-Image-Omni-Base.py      | 24 ++++++++++++++
 .../Z-Image-Omni-Base.py                      | 33 +++++++++++++++++++
 .../model_training/full/Z-Image-Omni-Base.sh  | 14 ++++++++
 .../model_training/lora/Z-Image-Omni-Base.sh  | 15 +++++++++
 .../validate_full/Z-Image-Omni-Base.py        | 21 ++++++++++++
 .../validate_lora/Z-Image-Omni-Base.py        | 19 +++++++++++
 9 files changed, 128 insertions(+), 4 deletions(-)
 create mode 100644 examples/z_image/model_inference/Z-Image-Omni-Base.py
 create mode 100644 examples/z_image/model_inference_low_vram/Z-Image-Omni-Base.py
 create mode 100644 examples/z_image/model_training/full/Z-Image-Omni-Base.sh
 create mode 100644 examples/z_image/model_training/lora/Z-Image-Omni-Base.sh
 create mode 100644 examples/z_image/model_training/validate_full/Z-Image-Omni-Base.py
 create mode 100644 examples/z_image/model_training/validate_lora/Z-Image-Omni-Base.py

diff --git a/diffsynth/core/loader/config.py b/diffsynth/core/loader/config.py
index 562675f..88b46a0 100644
--- a/diffsynth/core/loader/config.py
+++ b/diffsynth/core/loader/config.py
@@ -97,6 +97,7 @@ class ModelConfig:
         self.reset_local_model_path()
         if self.require_downloading():
             self.download()
+        if self.path is None:
             if self.origin_file_pattern is None or self.origin_file_pattern == "":
                 self.path = os.path.join(self.local_model_path, self.model_id)
             else:
diff --git a/diffsynth/models/siglip2_image_encoder.py b/diffsynth/models/siglip2_image_encoder.py
index 76441d0..87df855 100644
--- a/diffsynth/models/siglip2_image_encoder.py
+++ b/diffsynth/models/siglip2_image_encoder.py
@@ -90,12 +90,10 @@ class Siglip2ImageEncoder428M(Siglip2VisionModel):
         super().__init__(config)
         self.processor = Siglip2ImageProcessorFast(
             **{
-                "crop_size": None,
                 "data_format": "channels_first",
                 "default_to_square": True,
                 "device": None,
                 "disable_grouping": None,
-                "do_center_crop": None,
                 "do_convert_rgb": None,
                 "do_normalize": True,
                 "do_pad": None,
@@ -120,7 +118,6 @@ class Siglip2ImageEncoder428M(Siglip2VisionModel):
                 "resample": 2,
                 "rescale_factor": 0.00392156862745098,
                 "return_tensors": None,
-                "size": None
             }
         )
 
diff --git a/diffsynth/models/z_image_dit.py b/diffsynth/models/z_image_dit.py
index d0e392e..d141e02 100644
--- a/diffsynth/models/z_image_dit.py
+++ b/diffsynth/models/z_image_dit.py
@@ -626,7 +626,7 @@ class ZImageDiT(nn.Module):
 
         # Pad token
         feats_cat = torch.cat(feats, dim=0)
-        feats_cat[torch.cat(inner_pad_mask)] = pad_token
+        feats_cat[torch.cat(inner_pad_mask)] = pad_token.to(dtype=feats_cat.dtype, device=feats_cat.device)
         feats = list(feats_cat.split(item_seqlens, dim=0))
 
         # RoPE
diff --git a/examples/z_image/model_inference/Z-Image-Omni-Base.py b/examples/z_image/model_inference/Z-Image-Omni-Base.py
new file mode 100644
index 0000000..b1d2217
--- /dev/null
+++ b/examples/z_image/model_inference/Z-Image-Omni-Base.py
@@ -0,0 +1,24 @@
+from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
+from PIL import Image
+import torch
+
+
+pipe = ZImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors"),
+        ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors"),
+        ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
+        ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+    ],
+    tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
+)
+prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
+image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4)
+image.save("image_Z-Image-Omni-Base.jpg")
+
+image = Image.open("image_Z-Image-Omni-Base.jpg")
+prompt = "Change the woman's clothes to a white cheongsam, keep other content unchanged"
+image = pipe(prompt=prompt, edit_image=image, seed=42, rand_device="cuda", num_inference_steps=40, cfg_scale=4)
+image.save("image_edit_Z-Image-Omni-Base.jpg")
diff --git a/examples/z_image/model_inference_low_vram/Z-Image-Omni-Base.py b/examples/z_image/model_inference_low_vram/Z-Image-Omni-Base.py
new file mode 100644
index 0000000..0af1e53
--- /dev/null
+++ b/examples/z_image/model_inference_low_vram/Z-Image-Omni-Base.py
@@ -0,0 +1,33 @@
+from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
+from PIL import Image
+import torch
+
+vram_config = {
+    "offload_dtype": torch.bfloat16,
+    "offload_device": "cpu",
+    "onload_dtype": torch.bfloat16,
+    "onload_device": "cpu",
+    "preparing_dtype": torch.bfloat16,
+    "preparing_device": "cuda",
+    "computation_dtype": torch.bfloat16,
+    "computation_device": "cuda",
+}
+pipe = ZImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors", **vram_config),
+        ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors", **vram_config),
+        ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
+        ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
+    ],
+    tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
+)
+prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
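+# Note: the vram_config above keeps model weights on the CPU ("onload_device")
+# and moves them to CUDA only for computation ("computation_device"), which
+# lowers peak VRAM usage at the cost of extra host-to-device transfers.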
+image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4)
+image.save("image_Z-Image-Omni-Base.jpg")
+
+image = Image.open("image_Z-Image-Omni-Base.jpg")
+prompt = "Change the woman's clothes to a white cheongsam, keep other content unchanged"
+image = pipe(prompt=prompt, edit_image=image, seed=42, rand_device="cuda", num_inference_steps=40, cfg_scale=4)
+image.save("image_edit_Z-Image-Omni-Base.jpg")
diff --git a/examples/z_image/model_training/full/Z-Image-Omni-Base.sh b/examples/z_image/model_training/full/Z-Image-Omni-Base.sh
new file mode 100644
index 0000000..4f2d1da
--- /dev/null
+++ b/examples/z_image/model_training/full/Z-Image-Omni-Base.sh
@@ -0,0 +1,14 @@
+# This example is tested on 8*A100
+accelerate launch --config_file examples/z_image/model_training/full/accelerate_config.yaml examples/z_image/model_training/train.py \
+  --dataset_base_path data/example_image_dataset \
+  --dataset_metadata_path data/example_image_dataset/metadata.csv \
+  --max_pixels 1048576 \
+  --dataset_repeat 400 \
+  --model_id_with_origin_paths "Tongyi-MAI/Z-Image-Omni-Base:transformer/*.safetensors,Tongyi-MAI/Z-Image-Omni-Base:siglip/model.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \
+  --learning_rate 1e-5 \
+  --num_epochs 2 \
+  --remove_prefix_in_ckpt "pipe.dit." \
+  --output_path "./models/train/Z-Image-Omni-Base_full" \
+  --trainable_models "dit" \
+  --use_gradient_checkpointing \
+  --dataset_num_workers 8
diff --git a/examples/z_image/model_training/lora/Z-Image-Omni-Base.sh b/examples/z_image/model_training/lora/Z-Image-Omni-Base.sh
new file mode 100644
index 0000000..3f0b158
--- /dev/null
+++ b/examples/z_image/model_training/lora/Z-Image-Omni-Base.sh
@@ -0,0 +1,15 @@
+accelerate launch examples/z_image/model_training/train.py \
+  --dataset_base_path data/example_image_dataset \
+  --dataset_metadata_path data/example_image_dataset/metadata.csv \
+  --max_pixels 1048576 \
+  --dataset_repeat 50 \
+  --model_id_with_origin_paths "Tongyi-MAI/Z-Image-Omni-Base:transformer/*.safetensors,Tongyi-MAI/Z-Image-Omni-Base:siglip/model.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \
+  --learning_rate 1e-4 \
+  --num_epochs 5 \
+  --remove_prefix_in_ckpt "pipe.dit." \
+  --output_path "./models/train/Z-Image-Omni-Base_lora" \
+  --lora_base_model "dit" \
+  --lora_target_modules "to_q,to_k,to_v,to_out.0,w1,w2,w3" \
+  --lora_rank 32 \
+  --use_gradient_checkpointing \
+  --dataset_num_workers 8
diff --git a/examples/z_image/model_training/validate_full/Z-Image-Omni-Base.py b/examples/z_image/model_training/validate_full/Z-Image-Omni-Base.py
new file mode 100644
index 0000000..b51095c
--- /dev/null
+++ b/examples/z_image/model_training/validate_full/Z-Image-Omni-Base.py
@@ -0,0 +1,21 @@
+from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
+from diffsynth.core import load_state_dict
+import torch
+
+
+pipe = ZImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors"),
+        ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors"),
+        ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
+        ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+    ],
+    tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
+)
+state_dict = load_state_dict("./models/train/Z-Image-Omni-Base_full/epoch-1.safetensors", torch_dtype=torch.bfloat16)
+pipe.dit.load_state_dict(state_dict)
+prompt = "a dog"
+image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=40, cfg_scale=4)
+image.save("image.jpg")
diff --git a/examples/z_image/model_training/validate_lora/Z-Image-Omni-Base.py b/examples/z_image/model_training/validate_lora/Z-Image-Omni-Base.py
new file mode 100644
index 0000000..77ee72f
--- /dev/null
+++ b/examples/z_image/model_training/validate_lora/Z-Image-Omni-Base.py
@@ -0,0 +1,19 @@
+from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
+import torch
+
+
+pipe = ZImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors"),
+        ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors"),
+        ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
+        ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+    ],
+    tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
+)
+pipe.load_lora(pipe.dit, "./models/train/Z-Image-Omni-Base_lora/epoch-4.safetensors")
+prompt = "a dog"
+image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=40, cfg_scale=4)
+image.save("image.jpg")
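+# Note: checkpoint filenames are zero-indexed, so "epoch-4.safetensors" above is
+# the checkpoint saved after the last of the 5 LoRA epochs (likewise, the full
+# training run uses 2 epochs and validates "epoch-1.safetensors").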