mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support z-image-omni-base-i2L
This commit is contained in:
@@ -4,12 +4,13 @@ from typing import Union
|
||||
from tqdm import tqdm
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
from typing import Union, List, Optional, Tuple, Iterable
|
||||
from typing import Union, List, Optional, Tuple, Iterable, Dict
|
||||
|
||||
from ..diffusion import FlowMatchScheduler
|
||||
from ..core import ModelConfig, gradient_checkpoint_forward
|
||||
from ..core.data.operators import ImageCropAndResize
|
||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
|
||||
from ..utils.lora import merge_lora
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
from ..models.z_image_text_encoder import ZImageTextEncoder
|
||||
@@ -17,6 +18,9 @@ from ..models.z_image_dit import ZImageDiT
|
||||
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
||||
from ..models.siglip2_image_encoder import Siglip2ImageEncoder428M
|
||||
from ..models.z_image_controlnet import ZImageControlNet
|
||||
from ..models.siglip2_image_encoder import Siglip2ImageEncoder
|
||||
from ..models.dinov3_image_encoder import DINOv3ImageEncoder
|
||||
from ..models.z_image_image2lora import ZImageImage2LoRAModel
|
||||
|
||||
|
||||
class ZImagePipeline(BasePipeline):
|
||||
@@ -33,6 +37,9 @@ class ZImagePipeline(BasePipeline):
|
||||
self.vae_decoder: FluxVAEDecoder = None
|
||||
self.image_encoder: Siglip2ImageEncoder428M = None
|
||||
self.controlnet: ZImageControlNet = None
|
||||
self.siglip2_image_encoder: Siglip2ImageEncoder = None
|
||||
self.dinov3_image_encoder: DINOv3ImageEncoder = None
|
||||
self.image2lora_style: ZImageImage2LoRAModel = None
|
||||
self.tokenizer: AutoTokenizer = None
|
||||
self.in_iteration_models = ("dit", "controlnet")
|
||||
self.units = [
|
||||
@@ -67,6 +74,9 @@ class ZImagePipeline(BasePipeline):
|
||||
pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder")
|
||||
pipe.image_encoder = model_pool.fetch_model("siglip_vision_model_428m")
|
||||
pipe.controlnet = model_pool.fetch_model("z_image_controlnet")
|
||||
pipe.siglip2_image_encoder = model_pool.fetch_model("siglip2_image_encoder")
|
||||
pipe.dinov3_image_encoder = model_pool.fetch_model("dinov3_image_encoder")
|
||||
pipe.image2lora_style = model_pool.fetch_model("z_image_image2lora_style")
|
||||
if tokenizer_config is not None:
|
||||
tokenizer_config.download_if_necessary()
|
||||
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
|
||||
@@ -100,6 +110,9 @@ class ZImagePipeline(BasePipeline):
|
||||
sigma_shift: float = None,
|
||||
# ControlNet
|
||||
controlnet_inputs: List[ControlNetInput] = None,
|
||||
# Image to LoRA
|
||||
image2lora_images: List[Image.Image] = None,
|
||||
positive_only_lora: Dict[str, torch.Tensor] = None,
|
||||
# Progress bar
|
||||
progress_bar_cmd = tqdm,
|
||||
):
|
||||
@@ -121,6 +134,7 @@ class ZImagePipeline(BasePipeline):
|
||||
"num_inference_steps": num_inference_steps,
|
||||
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize,
|
||||
"controlnet_inputs": controlnet_inputs,
|
||||
"image2lora_images": image2lora_images, "positive_only_lora": positive_only_lora,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
@@ -480,6 +494,71 @@ def model_fn_z_image(
|
||||
return model_output
|
||||
|
||||
|
||||
class ZImageUnit_Image2LoRAEncode(PipelineUnit):
    """Pipeline unit that turns `image2lora_images` into `image2lora_x`.

    Every input image is passed through `pipe.siglip2_image_encoder` and
    `pipe.dinov3_image_encoder`; the two sets of embeddings are concatenated
    along the last dimension and emitted under the `image2lora_x` key.
    """

    def __init__(self):
        super().__init__(
            input_params=("image2lora_images",),
            output_params=("image2lora_x",),
            onload_model_names=("siglip2_image_encoder", "dinov3_image_encoder",),
        )
        # Function-scope import kept exactly as in the original implementation.
        from ..core.data.operators import ImageCropAndResize
        # Preprocessor applied to every image before either encoder sees it.
        # NOTE(review): presumably crops/resizes to 1024x1024 - confirm against
        # ImageCropAndResize.
        self.processor_highres = ImageCropAndResize(height=1024, width=1024)

    def encode_images_using_siglip2(self, pipe: ZImagePipeline, images: list[Image.Image]):
        """Encode each image with `pipe.siglip2_image_encoder`.

        The model is moved onto the device first. Each per-image embedding is
        cast to `pipe.torch_dtype`, and the embeddings are stacked along a new
        leading dimension (one entry per image).
        """
        pipe.load_models_to_device(["siglip2_image_encoder"])
        per_image = [
            pipe.siglip2_image_encoder(self.processor_highres(img)).to(pipe.torch_dtype)
            for img in images
        ]
        return torch.stack(per_image)

    def encode_images_using_dinov3(self, pipe: ZImagePipeline, images: list[Image.Image]):
        """Encode each image with `pipe.dinov3_image_encoder`.

        Mirrors `encode_images_using_siglip2`: load the model, preprocess and
        encode every image, cast to `pipe.torch_dtype`, then stack.
        """
        pipe.load_models_to_device(["dinov3_image_encoder"])
        per_image = [
            pipe.dinov3_image_encoder(self.processor_highres(img)).to(pipe.torch_dtype)
            for img in images
        ]
        return torch.stack(per_image)

    def encode_images(self, pipe: ZImagePipeline, images: list[Image.Image]):
        """Return SigLIP2 and DINOv3 embeddings concatenated on the last dim.

        `images` may be a list or a single non-list object; the latter is
        wrapped into a one-element list. When `images` is None an empty dict
        is returned instead of a tensor.
        """
        if images is None:
            return {}
        image_list = images if isinstance(images, list) else [images]
        # SigLIP2 runs before DINOv3, so the models are loaded in that order.
        siglip2_features = self.encode_images_using_siglip2(pipe, image_list)
        dinov3_features = self.encode_images_using_dinov3(pipe, image_list)
        return torch.concat([siglip2_features, dinov3_features], dim=-1)

    def process(self, pipe: ZImagePipeline, image2lora_images):
        """Unit entry point: emit `image2lora_x`, or nothing if no images were given."""
        if image2lora_images is None:
            return {}
        return {"image2lora_x": self.encode_images(pipe, image2lora_images)}
|
||||
|
||||
|
||||
class ZImageUnit_Image2LoRADecode(PipelineUnit):
    """Pipeline unit that turns `image2lora_x` embeddings into a merged `lora`.

    Each entry of `image2lora_x` is fed to `pipe.image2lora_style`, and the
    per-entry results are combined with `merge_lora`.
    """

    def __init__(self):
        super().__init__(
            input_params=("image2lora_x",),
            output_params=("lora",),
            onload_model_names=("image2lora_style",),
        )

    def process(self, pipe: ZImagePipeline, image2lora_x):
        """Unit entry point: emit `lora`, or nothing if `image2lora_x` is None."""
        if image2lora_x is None:
            return {}
        if pipe.image2lora_style is None:
            # NOTE(review): with no style model loaded, merge_lora receives an
            # empty list - confirm that merge_lora tolerates this.
            loras = []
        else:
            pipe.load_models_to_device(["image2lora_style"])
            loras = [pipe.image2lora_style(x=x, residual=None) for x in image2lora_x]
        # alpha is the reciprocal of the number of entries in image2lora_x.
        # NOTE(review): an empty image2lora_x would raise ZeroDivisionError here.
        alpha = 1 / len(image2lora_x)
        return {"lora": merge_lora(loras, alpha=alpha)}
|
||||
|
||||
|
||||
def model_fn_z_image_turbo(
|
||||
dit: ZImageDiT,
|
||||
controlnet: ZImageControlNet = None,
|
||||
|
||||
Reference in New Issue
Block a user