From dd479e5bffb7a214183431ca5c6fb3a892bb3e63 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Wed, 7 Jan 2026 20:36:53 +0800 Subject: [PATCH] support z-image-omni-base-i2L --- diffsynth/configs/model_configs.py | 11 +- diffsynth/diffusion/base_pipeline.py | 14 +- diffsynth/models/z_image_image2lora.py | 189 ++++++++++++++++++ diffsynth/pipelines/z_image.py | 81 +++++++- .../model_inference/Z-Image-Omni-Base-i2L.py | 62 ++++++ .../Z-Image-Omni-Base-i2L.py | 62 ++++++ 6 files changed, 413 insertions(+), 6 deletions(-) create mode 100644 diffsynth/models/z_image_image2lora.py create mode 100644 examples/z_image/model_inference/Z-Image-Omni-Base-i2L.py create mode 100644 examples/z_image/model_inference_low_vram/Z-Image-Omni-Base-i2L.py diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index 0b6c61a..1f6e2d1 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -541,11 +541,18 @@ z_image_series = [ "model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M", }, { - # ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors") + # Example: ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors") "model_hash": "1677708d40029ab380a95f6c731a57d7", "model_name": "z_image_controlnet", "model_class": "diffsynth.models.z_image_controlnet.ZImageControlNet", - } + }, + { + # Example: ??? + "model_hash": "9510cb8cd1dd34ee0e4f111c24905510", + "model_name": "z_image_image2lora_style", + "model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel", + "extra_kwargs": {"compress_dim": 128}, + }, ] MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series diff --git a/diffsynth/diffusion/base_pipeline.py b/diffsynth/diffusion/base_pipeline.py index fa355a1..e37f381 100644 --- a/diffsynth/diffusion/base_pipeline.py +++ b/diffsynth/diffusion/base_pipeline.py @@ -235,6 +235,7 @@ class BasePipeline(torch.nn.Module): alpha=1, hotload=None, state_dict=None, + verbose=1, ): if state_dict is None: if isinstance(lora_config, str): @@ -261,12 +262,13 @@ class BasePipeline(torch.nn.Module): updated_num += 1 module.lora_A_weights.append(lora[lora_a_name] * alpha) module.lora_B_weights.append(lora[lora_b_name]) - print(f"{updated_num} tensors are patched by LoRA. You can use `pipe.clear_lora()` to clear all LoRA layers.") + if verbose >= 1: + print(f"{updated_num} tensors are patched by LoRA. You can use `pipe.clear_lora()` to clear all LoRA layers.") else: lora_loader.fuse_lora_to_base_model(module, lora, alpha=alpha) - def clear_lora(self): + def clear_lora(self, verbose=1): cleared_num = 0 for name, module in self.named_modules(): if isinstance(module, AutoWrappedLinear): @@ -276,7 +278,8 @@ class BasePipeline(torch.nn.Module): module.lora_A_weights.clear() if hasattr(module, "lora_B_weights"): module.lora_B_weights.clear() - print(f"{cleared_num} LoRA layers are cleared.") + if verbose >= 1: + print(f"{cleared_num} LoRA layers are cleared.") def download_and_load_models(self, model_configs: list[ModelConfig] = [], vram_limit: float = None): @@ -304,8 +307,13 @@ class BasePipeline(torch.nn.Module): def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others): + if inputs_shared.get("positive_only_lora", None) is not None: + self.clear_lora(verbose=0) + self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0) noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others) if cfg_scale != 1.0: + if inputs_shared.get("positive_only_lora", None) is not None: + self.clear_lora(verbose=0) noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) else: diff --git a/diffsynth/models/z_image_image2lora.py b/diffsynth/models/z_image_image2lora.py new file mode 100644 index 0000000..757f3f6 --- /dev/null +++ b/diffsynth/models/z_image_image2lora.py @@ -0,0 +1,189 @@ +import torch +from .qwen_image_image2lora import ImageEmbeddingToLoraMatrix, SequencialMLP + + +class LoRATrainerBlock(torch.nn.Module): + def __init__(self, lora_patterns, in_dim=1536+4096, compress_dim=128, rank=4, block_id=0, use_residual=True, residual_length=64+7, residual_dim=3584, residual_mid_dim=1024, prefix="transformer_blocks"): + super().__init__() + self.prefix = prefix + self.lora_patterns = lora_patterns + self.block_id = block_id + self.layers = [] + for name, lora_a_dim, lora_b_dim in self.lora_patterns: + self.layers.append(ImageEmbeddingToLoraMatrix(in_dim, compress_dim, lora_a_dim, lora_b_dim, rank)) + self.layers = torch.nn.ModuleList(self.layers) + if use_residual: + self.proj_residual = SequencialMLP(residual_length, residual_dim, residual_mid_dim, compress_dim) + else: + self.proj_residual = None + + def forward(self, x, residual=None): + lora = {} + if self.proj_residual is not None: residual = self.proj_residual(residual) + for lora_pattern, layer in zip(self.lora_patterns, self.layers): + name = lora_pattern[0] + lora_a, lora_b = layer(x, residual=residual) + lora[f"{self.prefix}.{self.block_id}.{name}.lora_A.default.weight"] = lora_a + lora[f"{self.prefix}.{self.block_id}.{name}.lora_B.default.weight"] = lora_b + return lora + + +class ZImageImage2LoRAComponent(torch.nn.Module): + def __init__(self, lora_patterns, prefix, num_blocks=60, use_residual=True, compress_dim=128, rank=4, residual_length=64+7, residual_mid_dim=1024): + super().__init__() + self.lora_patterns = lora_patterns + self.num_blocks = num_blocks + self.blocks = [] + for lora_patterns in self.lora_patterns: + for block_id in range(self.num_blocks): + self.blocks.append(LoRATrainerBlock(lora_patterns, block_id=block_id, use_residual=use_residual, compress_dim=compress_dim, rank=rank, residual_length=residual_length, residual_mid_dim=residual_mid_dim, prefix=prefix)) + self.blocks = torch.nn.ModuleList(self.blocks) + self.residual_scale = 0.05 + self.use_residual = use_residual + + def forward(self, x, residual=None): + if residual is not None: + if self.use_residual: + residual = residual * self.residual_scale + else: + residual = None + lora = {} + for block in self.blocks: + lora.update(block(x, residual)) + return lora + + +class ZImageImage2LoRAModel(torch.nn.Module): + def __init__(self, use_residual=False, compress_dim=64, rank=4, residual_length=64+7, residual_mid_dim=1024): + super().__init__() + lora_patterns = [ + [ + ("attention.to_q", 3840, 3840), + ("attention.to_k", 3840, 3840), + ("attention.to_v", 3840, 3840), + ("attention.to_out.0", 3840, 3840), + ], + [ + ("feed_forward.w1", 3840, 10240), + ("feed_forward.w2", 10240, 3840), + ("feed_forward.w3", 3840, 10240), + ], + ] + config = { + "lora_patterns": lora_patterns, + "use_residual": use_residual, + "compress_dim": compress_dim, + "rank": rank, + "residual_length": residual_length, + "residual_mid_dim": residual_mid_dim, + } + self.layers_lora = ZImageImage2LoRAComponent( + prefix="layers", + num_blocks=30, + **config, + ) + self.context_refiner_lora = ZImageImage2LoRAComponent( + prefix="context_refiner", + num_blocks=2, + **config, + ) + self.noise_refiner_lora = ZImageImage2LoRAComponent( + prefix="noise_refiner", + num_blocks=2, + **config, + ) + + def forward(self, x, residual=None): + lora = {} + lora.update(self.layers_lora(x, residual=residual)) + lora.update(self.context_refiner_lora(x, residual=residual)) + lora.update(self.noise_refiner_lora(x, residual=residual)) + return lora + + def initialize_weights(self): + state_dict = self.state_dict() + for name in state_dict: + if ".proj_a." in name: + state_dict[name] = state_dict[name] * 0.3 + elif ".proj_b.proj_out." in name: + state_dict[name] = state_dict[name] * 0 + elif ".proj_residual.proj_out." in name: + state_dict[name] = state_dict[name] * 0.3 + self.load_state_dict(state_dict) + + +class ImageEmb2LoRAWeightCompressed(torch.nn.Module): + def __init__(self, in_dim, out_dim, emb_dim, rank): + super().__init__() + self.lora_a = torch.nn.Parameter(torch.randn((rank, in_dim))) + self.lora_b = torch.nn.Parameter(torch.randn((out_dim, rank))) + self.proj = torch.nn.Linear(emb_dim, rank * rank, bias=True) + self.rank = rank + + def forward(self, x): + x = self.proj(x).view(self.rank, self.rank) + lora_a = x @ self.lora_a + lora_b = self.lora_b + return lora_a, lora_b + + +class ZImageImage2LoRAModelCompressed(torch.nn.Module): + def __init__(self, emb_dim=1536+4096, rank=32): + super().__init__() + target_layers = [ + ("attention.to_q", 3840, 3840), + ("attention.to_k", 3840, 3840), + ("attention.to_v", 3840, 3840), + ("attention.to_out.0", 3840, 3840), + ("feed_forward.w1", 3840, 10240), + ("feed_forward.w2", 10240, 3840), + ("feed_forward.w3", 3840, 10240), + ] + self.lora_patterns = [ + { + "prefix": "layers", + "num_layers": 30, + "target_layers": target_layers, + }, + { + "prefix": "context_refiner", + "num_layers": 2, + "target_layers": target_layers, + }, + { + "prefix": "noise_refiner", + "num_layers": 2, + "target_layers": target_layers, + }, + ] + module_dict = {} + for lora_pattern in self.lora_patterns: + prefix, num_layers, target_layers = lora_pattern["prefix"], lora_pattern["num_layers"], lora_pattern["target_layers"] + for layer_id in range(num_layers): + for layer_name, in_dim, out_dim in target_layers: + name = f"{prefix}.{layer_id}.{layer_name}".replace(".", "___") + model = ImageEmb2LoRAWeightCompressed(in_dim, out_dim, emb_dim, rank) + module_dict[name] = model + self.module_dict = torch.nn.ModuleDict(module_dict) + + def forward(self, x, residual=None): + lora = {} + for name, module in self.module_dict.items(): + name = name.replace("___", ".") + name_a, name_b = f"{name}.lora_A.default.weight", f"{name}.lora_B.default.weight" + lora_a, lora_b = module(x) + lora[name_a] = lora_a + lora[name_b] = lora_b + return lora + + def initialize_weights(self): + state_dict = self.state_dict() + for name in state_dict: + if "lora_b" in name: + state_dict[name] = state_dict[name] * 0 + elif "lora_a" in name: + state_dict[name] = state_dict[name] * 0.2 + elif "proj.weight" in name: + print(name) + state_dict[name] = state_dict[name] * 0.2 + self.load_state_dict(state_dict) diff --git a/diffsynth/pipelines/z_image.py b/diffsynth/pipelines/z_image.py index 23d94ec..df6d0aa 100644 --- a/diffsynth/pipelines/z_image.py +++ b/diffsynth/pipelines/z_image.py @@ -4,12 +4,13 @@ from typing import Union from tqdm import tqdm from einops import rearrange import numpy as np -from typing import Union, List, Optional, Tuple, Iterable +from typing import Union, List, Optional, Tuple, Iterable, Dict from ..diffusion import FlowMatchScheduler from ..core import ModelConfig, gradient_checkpoint_forward from ..core.data.operators import ImageCropAndResize from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput +from ..utils.lora import merge_lora from transformers import AutoTokenizer from ..models.z_image_text_encoder import ZImageTextEncoder @@ -17,6 +18,9 @@ from ..models.z_image_dit import ZImageDiT from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder from ..models.siglip2_image_encoder import Siglip2ImageEncoder428M from ..models.z_image_controlnet import ZImageControlNet +from ..models.siglip2_image_encoder import Siglip2ImageEncoder +from ..models.dinov3_image_encoder import DINOv3ImageEncoder +from ..models.z_image_image2lora import ZImageImage2LoRAModel class ZImagePipeline(BasePipeline): @@ -33,6 +37,9 @@ class ZImagePipeline(BasePipeline): self.vae_decoder: FluxVAEDecoder = None self.image_encoder: Siglip2ImageEncoder428M = None self.controlnet: ZImageControlNet = None + self.siglip2_image_encoder: Siglip2ImageEncoder = None + self.dinov3_image_encoder: DINOv3ImageEncoder = None + self.image2lora_style: ZImageImage2LoRAModel = None self.tokenizer: AutoTokenizer = None self.in_iteration_models = ("dit", "controlnet") self.units = [ @@ -67,6 +74,9 @@ class ZImagePipeline(BasePipeline): pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder") pipe.image_encoder = model_pool.fetch_model("siglip_vision_model_428m") pipe.controlnet = model_pool.fetch_model("z_image_controlnet") + pipe.siglip2_image_encoder = model_pool.fetch_model("siglip2_image_encoder") + pipe.dinov3_image_encoder = model_pool.fetch_model("dinov3_image_encoder") + pipe.image2lora_style = model_pool.fetch_model("z_image_image2lora_style") if tokenizer_config is not None: tokenizer_config.download_if_necessary() pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path) @@ -100,6 +110,9 @@ class ZImagePipeline(BasePipeline): sigma_shift: float = None, # ControlNet controlnet_inputs: List[ControlNetInput] = None, + # Image to LoRA + image2lora_images: List[Image.Image] = None, + positive_only_lora: Dict[str, torch.Tensor] = None, # Progress bar progress_bar_cmd = tqdm, ): @@ -121,6 +134,7 @@ class ZImagePipeline(BasePipeline): "num_inference_steps": num_inference_steps, "edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize, "controlnet_inputs": controlnet_inputs, + "image2lora_images": image2lora_images, "positive_only_lora": positive_only_lora, } for unit in self.units: inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) @@ -480,6 +494,71 @@ def model_fn_z_image( return model_output +class ZImageUnit_Image2LoRAEncode(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("image2lora_images",), + output_params=("image2lora_x",), + onload_model_names=("siglip2_image_encoder", "dinov3_image_encoder",), + ) + from ..core.data.operators import ImageCropAndResize + self.processor_highres = ImageCropAndResize(height=1024, width=1024) + + def encode_images_using_siglip2(self, pipe: ZImagePipeline, images: list[Image.Image]): + pipe.load_models_to_device(["siglip2_image_encoder"]) + embs = [] + for image in images: + image = self.processor_highres(image) + embs.append(pipe.siglip2_image_encoder(image).to(pipe.torch_dtype)) + embs = torch.stack(embs) + return embs + + def encode_images_using_dinov3(self, pipe: ZImagePipeline, images: list[Image.Image]): + pipe.load_models_to_device(["dinov3_image_encoder"]) + embs = [] + for image in images: + image = self.processor_highres(image) + embs.append(pipe.dinov3_image_encoder(image).to(pipe.torch_dtype)) + embs = torch.stack(embs) + return embs + + def encode_images(self, pipe: ZImagePipeline, images: list[Image.Image]): + if images is None: + return {} + if not isinstance(images, list): + images = [images] + embs_siglip2 = self.encode_images_using_siglip2(pipe, images) + embs_dinov3 = self.encode_images_using_dinov3(pipe, images) + x = torch.concat([embs_siglip2, embs_dinov3], dim=-1) + return x + + def process(self, pipe: ZImagePipeline, image2lora_images): + if image2lora_images is None: + return {} + x = self.encode_images(pipe, image2lora_images) + return {"image2lora_x": x} + + +class ZImageUnit_Image2LoRADecode(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("image2lora_x",), + output_params=("lora",), + onload_model_names=("image2lora_style",), + ) + + def process(self, pipe: ZImagePipeline, image2lora_x): + if image2lora_x is None: + return {} + loras = [] + if pipe.image2lora_style is not None: + pipe.load_models_to_device(["image2lora_style"]) + for x in image2lora_x: + loras.append(pipe.image2lora_style(x=x, residual=None)) + lora = merge_lora(loras, alpha=1 / len(image2lora_x)) + return {"lora": lora} + + def model_fn_z_image_turbo( dit: ZImageDiT, controlnet: ZImageControlNet = None, diff --git a/examples/z_image/model_inference/Z-Image-Omni-Base-i2L.py b/examples/z_image/model_inference/Z-Image-Omni-Base-i2L.py new file mode 100644 index 0000000..73e67d9 --- /dev/null +++ b/examples/z_image/model_inference/Z-Image-Omni-Base-i2L.py @@ -0,0 +1,62 @@ +from diffsynth.pipelines.z_image import ( + ZImagePipeline, ModelConfig, + ZImageUnit_Image2LoRAEncode, ZImageUnit_Image2LoRADecode +) +from modelscope import snapshot_download +from safetensors.torch import save_file +import torch +from PIL import Image + +# Use `vram_config` to enable LoRA hot-loading +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cuda", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} + +# Load models +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Z-Image-Omni-Base-i2L", origin_file_pattern="model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) + +# Load images +snapshot_download( + model_id="DiffSynth-Studio/Qwen-Image-i2L", + allow_file_pattern="assets/style/*", + local_dir="data/examples" +) +images = [Image.open(f"data/style/1/{i}.jpg") for i in range(5)] + +# Image to LoRA +with torch.no_grad(): + embs = ZImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images) + lora = ZImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"] +save_file(lora, "lora.safetensors") + +# Generate images +prompt = "a cat" +negative_prompt = "泛黄,发绿,模糊,低分辨率,低质量图像,扭曲的肢体,诡异的外观,丑陋,AI感,噪点,网格感,JPEG压缩条纹,异常的肢体,水印,乱码,意义不明的字符" +image = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=0, cfg_scale=7, num_inference_steps=50, + positive_only_lora=lora, + sigma_shift=8 +) +image.save("image.jpg") diff --git a/examples/z_image/model_inference_low_vram/Z-Image-Omni-Base-i2L.py b/examples/z_image/model_inference_low_vram/Z-Image-Omni-Base-i2L.py new file mode 100644 index 0000000..62a7b31 --- /dev/null +++ b/examples/z_image/model_inference_low_vram/Z-Image-Omni-Base-i2L.py @@ -0,0 +1,62 @@ +from diffsynth.pipelines.z_image import ( + ZImagePipeline, ModelConfig, + ZImageUnit_Image2LoRAEncode, ZImageUnit_Image2LoRADecode +) +from modelscope import snapshot_download +from safetensors.torch import save_file +import torch +from PIL import Image + +# Use `vram_config` to enable LoRA hot-loading +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} + +# Load models +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/Z-Image-Omni-Base-i2L", origin_file_pattern="model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) + +# Load images +snapshot_download( + model_id="DiffSynth-Studio/Qwen-Image-i2L", + allow_file_pattern="assets/style/*", + local_dir="data/examples" +) +images = [Image.open(f"data/style/1/{i}.jpg") for i in range(5)] + +# Image to LoRA +with torch.no_grad(): + embs = ZImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images) + lora = ZImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"] +save_file(lora, "lora.safetensors") + +# Generate images +prompt = "a cat" +negative_prompt = "泛黄,发绿,模糊,低分辨率,低质量图像,扭曲的肢体,诡异的外观,丑陋,AI感,噪点,网格感,JPEG压缩条纹,异常的肢体,水印,乱码,意义不明的字符" +image = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=0, cfg_scale=7, num_inference_steps=50, + positive_only_lora=lora, + sigma_shift=8 +) +image.save("image.jpg")