# DiffSynth-Studio/diffsynth/pipelines/z_image.py
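# Z-Image pipeline: text-to-image, image-to-image, and image editing with
# optional ControlNet support, assembled from modular PipelineUnits that
# prepare all inputs ahead of the flow-matching denoising loop.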
import torch
from PIL import Image
from tqdm import tqdm
from einops import rearrange
from typing import Union, List, Optional
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..core.data.operators import ImageCropAndResize
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
from transformers import AutoTokenizer
from ..models.z_image_text_encoder import ZImageTextEncoder
from ..models.z_image_dit import ZImageDiT
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
from ..models.siglip2_image_encoder import Siglip2ImageEncoder428M
from ..models.z_image_controlnet import ZImageControlNet


class ZImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
)
self.scheduler = FlowMatchScheduler("Z-Image")
self.text_encoder: ZImageTextEncoder = None
self.dit: ZImageDiT = None
self.vae_encoder: FluxVAEEncoder = None
self.vae_decoder: FluxVAEDecoder = None
self.image_encoder: Siglip2ImageEncoder428M = None
self.controlnet: ZImageControlNet = None
self.tokenizer: AutoTokenizer = None
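        # Models that stay on the device for the whole denoising loop;
        # everything else is onloaded and offloaded per unit.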
self.in_iteration_models = ("dit", "controlnet")
self.units = [
ZImageUnit_ShapeChecker(),
ZImageUnit_PromptEmbedder(),
ZImageUnit_NoiseInitializer(),
ZImageUnit_InputImageEmbedder(),
ZImageUnit_EditImageAutoResize(),
ZImageUnit_EditImageEmbedderVAE(),
ZImageUnit_EditImageEmbedderSiglip(),
ZImageUnit_PAIControlNet(),
]
self.model_fn = model_fn_z_image
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
        model_configs: Optional[List[ModelConfig]] = None,
tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
vram_limit: float = None,
):
# Initialize pipeline
pipe = ZImagePipeline(device=device, torch_dtype=torch_dtype)
        model_pool = pipe.download_and_load_models(model_configs or [], vram_limit)
# Fetch models
pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
pipe.dit = model_pool.fetch_model("z_image_dit")
pipe.vae_encoder = model_pool.fetch_model("flux_vae_encoder")
pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder")
pipe.image_encoder = model_pool.fetch_model("siglip_vision_model_428m")
pipe.controlnet = model_pool.fetch_model("z_image_controlnet")
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
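
    # A minimal usage sketch. The `origin_file_pattern` values below are
    # illustrative assumptions; check the file layout of the
    # Tongyi-MAI/Z-Image-Turbo repository before relying on them:
    #
    #     pipe = ZImagePipeline.from_pretrained(
    #         torch_dtype=torch.bfloat16,
    #         device="cuda",
    #         model_configs=[
    #             ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/"),
    #             ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/"),
    #             ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/"),
    #         ],
    #     )
    #     image = pipe(prompt="a cat", seed=0)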
@torch.no_grad()
def __call__(
self,
# Prompt
prompt: str,
negative_prompt: str = "",
cfg_scale: float = 1.0,
# Image
input_image: Image.Image = None,
denoising_strength: float = 1.0,
# Edit
edit_image: Image.Image = None,
edit_image_auto_resize: bool = True,
# Shape
height: int = 1024,
width: int = 1024,
# Randomness
seed: int = None,
rand_device: str = "cpu",
# Steps
num_inference_steps: int = 8,
sigma_shift: float = None,
# ControlNet
controlnet_inputs: List[ControlNetInput] = None,
# Progress bar
        progress_bar_cmd=tqdm,
):
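        """Generate an image from a text prompt.

        Key arguments:
        - cfg_scale: classifier-free guidance scale; at 1.0 the negative
          prompt branch is typically skipped.
        - input_image / denoising_strength: image-to-image; the input is
          noised up to a strength-dependent starting timestep.
        - edit_image: one or more condition images for editing checkpoints.
        - controlnet_inputs: ControlNet conditions (only the first is used).
        """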
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
# Parameters
inputs_posi = {
"prompt": prompt,
}
inputs_nega = {
"negative_prompt": negative_prompt,
}
inputs_shared = {
"cfg_scale": cfg_scale,
"input_image": input_image, "denoising_strength": denoising_strength,
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"num_inference_steps": num_inference_steps,
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize,
"controlnet_inputs": controlnet_inputs,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
# Decode
self.load_models_to_device(['vae_decoder'])
image = self.vae_decoder(inputs_shared["latents"])
image = self.vae_output_to_image(image)
self.load_models_to_device([])
return image
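

# Each unit below prepares one group of inputs for the denoising loop:
# output shape, prompt embeddings, initial noise, VAE and SigLIP image
# embeddings, and ControlNet context.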
class ZImageUnit_ShapeChecker(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width"),
output_params=("height", "width"),
)
def process(self, pipe: ZImagePipeline, height, width):
height, width = pipe.check_resize_height_width(height, width)
return {"height": height, "width": width}


class ZImageUnit_PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params=("edit_image",),
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"},
output_params=("prompt_embeds",),
onload_model_names=("text_encoder",)
)
def encode_prompt(
self,
pipe,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
max_sequence_length: int = 512,
) -> List[torch.FloatTensor]:
if isinstance(prompt, str):
prompt = [prompt]
for i, prompt_item in enumerate(prompt):
messages = [
{"role": "user", "content": prompt_item},
]
prompt_item = pipe.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=True,
)
prompt[i] = prompt_item
text_inputs = pipe.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(device)
prompt_masks = text_inputs.attention_mask.to(device).bool()
prompt_embeds = pipe.text_encoder(
input_ids=text_input_ids,
attention_mask=prompt_masks,
output_hidden_states=True,
).hidden_states[-2]
embeddings_list = []
for i in range(len(prompt_embeds)):
embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
return embeddings_list
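    # encode_prompt_omni wraps the prompt in a raw Qwen-style chat template and
    # splits it around <|vision_start|>/<|vision_end|> markers, one segment per
    # condition image. With two edit images, for example, the per-sample
    # segment list becomes:
    #   ["<|im_start|>user\n<|vision_start|>",
    #    "<|vision_end|><|vision_start|>",
    #    "<|vision_end|>" + prompt + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>",
    #    "<|vision_end|><|im_end|>"]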
def encode_prompt_omni(
self,
pipe,
prompt: Union[str, List[str]],
edit_image=None,
device: Optional[torch.device] = None,
max_sequence_length: int = 512,
) -> List[torch.FloatTensor]:
if isinstance(prompt, str):
prompt = [prompt]
if edit_image is None:
num_condition_images = 0
elif isinstance(edit_image, list):
num_condition_images = len(edit_image)
else:
num_condition_images = 1
for i, prompt_item in enumerate(prompt):
if num_condition_images == 0:
prompt[i] = ["<|im_start|>user\n" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n"]
            else:
prompt_list = ["<|im_start|>user\n<|vision_start|>"]
prompt_list += ["<|vision_end|><|vision_start|>"] * (num_condition_images - 1)
prompt_list += ["<|vision_end|>" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"]
prompt_list += ["<|vision_end|><|im_end|>"]
prompt[i] = prompt_list
flattened_prompt = []
prompt_list_lengths = []
for i in range(len(prompt)):
prompt_list_lengths.append(len(prompt[i]))
flattened_prompt.extend(prompt[i])
text_inputs = pipe.tokenizer(
flattened_prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(device)
prompt_masks = text_inputs.attention_mask.to(device).bool()
prompt_embeds = pipe.text_encoder(
input_ids=text_input_ids,
attention_mask=prompt_masks,
output_hidden_states=True,
).hidden_states[-2]
embeddings_list = []
start_idx = 0
for i in range(len(prompt_list_lengths)):
batch_embeddings = []
end_idx = start_idx + prompt_list_lengths[i]
for j in range(start_idx, end_idx):
batch_embeddings.append(prompt_embeds[j][prompt_masks[j]])
embeddings_list.append(batch_embeddings)
start_idx = end_idx
return embeddings_list
def process(self, pipe: ZImagePipeline, prompt, edit_image):
pipe.load_models_to_device(self.onload_model_names)
        if pipe.dit is not None and pipe.dit.siglip_embedder is not None:
# Z-Image-Turbo and Z-Image-Omni-Base use different prompt encoding methods.
# We determine which encoding method to use based on the model architecture.
# If you are using two-stage split training,
# please use `--offload_models` instead of skipping the DiT model loading.
prompt_embeds = self.encode_prompt_omni(pipe, prompt, edit_image, pipe.device)
else:
prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device)
return {"prompt_embeds": prompt_embeds}


class ZImageUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "seed", "rand_device"),
output_params=("noise",),
)
def process(self, pipe: ZImagePipeline, height, width, seed, rand_device):
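        # The Flux VAE downsamples 8x spatially and produces 16 latent
        # channels, hence the (1, 16, H/8, W/8) noise shape.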
noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
return {"noise": noise}


class ZImageUnit_InputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "noise"),
output_params=("latents", "input_latents"),
onload_model_names=("vae_encoder",)
)
def process(self, pipe: ZImagePipeline, input_image, noise):
if input_image is None:
return {"latents": noise, "input_latents": None}
        pipe.load_models_to_device(self.onload_model_names)
image = pipe.preprocess_image(input_image)
input_latents = pipe.vae_encoder(image)
if pipe.scheduler.training:
return {"latents": noise, "input_latents": input_latents}
else:
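            # Image-to-image: noise the clean latents to the first scheduled
            # timestep so denoising starts partway along the flow.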
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
return {"latents": latents, "input_latents": input_latents}


class ZImageUnit_EditImageAutoResize(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("edit_image", "edit_image_auto_resize"),
output_params=("edit_image",),
)
def process(self, pipe: ZImagePipeline, edit_image, edit_image_auto_resize):
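        # Crop and resize edit images to at most ~1 megapixel, snapped to the
        # 16-pixel grid the DiT patchifier expects.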
        if edit_image is None or not edit_image_auto_resize:
            return {}
operator = ImageCropAndResize(max_pixels=1024*1024, height_division_factor=16, width_division_factor=16)
if not isinstance(edit_image, list):
edit_image = [edit_image]
edit_image = [operator(i) for i in edit_image]
return {"edit_image": edit_image}


class ZImageUnit_EditImageEmbedderSiglip(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("edit_image",),
output_params=("image_embeds",),
onload_model_names=("image_encoder",)
)
def process(self, pipe: ZImagePipeline, edit_image):
if edit_image is None:
return {}
pipe.load_models_to_device(self.onload_model_names)
if not isinstance(edit_image, list):
edit_image = [edit_image]
image_emb = []
for image_ in edit_image:
image_emb.append(pipe.image_encoder(image_, device=pipe.device))
return {"image_embeds": image_emb}


class ZImageUnit_EditImageEmbedderVAE(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("edit_image",),
output_params=("image_latents",),
onload_model_names=("vae_encoder",)
)
def process(self, pipe: ZImagePipeline, edit_image):
if edit_image is None:
return {}
pipe.load_models_to_device(self.onload_model_names)
if not isinstance(edit_image, list):
edit_image = [edit_image]
image_latents = []
for image_ in edit_image:
image_ = pipe.preprocess_image(image_)
image_latents.append(pipe.vae_encoder(image_))
return {"image_latents": image_latents}


class ZImageUnit_PAIControlNet(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("controlnet_inputs", "height", "width"),
output_params=("control_context", "control_scale"),
onload_model_names=("vae_encoder",)
)
def process(self, pipe: ZImagePipeline, controlnet_inputs: List[ControlNetInput], height, width):
if controlnet_inputs is None:
return {}
if len(controlnet_inputs) != 1:
print("Z-Image ControlNet doesn't support multi-ControlNet. Only one image will be used.")
controlnet_input = controlnet_inputs[0]
pipe.load_models_to_device(self.onload_model_names)
control_image = controlnet_input.image
if control_image is not None:
control_image = pipe.preprocess_image(control_image)
control_latents = pipe.vae_encoder(control_image)
else:
control_latents = torch.ones((1, 16, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device) * -1
inpaint_mask = controlnet_input.inpaint_mask
if inpaint_mask is not None:
inpaint_mask = pipe.preprocess_image(inpaint_mask, min_value=0, max_value=1)
inpaint_image = controlnet_input.inpaint_image
inpaint_image = pipe.preprocess_image(inpaint_image)
inpaint_image = inpaint_image * (inpaint_mask < 0.5)
inpaint_mask = torch.nn.functional.interpolate(1 - inpaint_mask, (height // 8, width // 8), mode='nearest')[:, :1]
else:
inpaint_mask = torch.zeros((1, 1, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device)
inpaint_image = torch.zeros((1, 3, height, width), dtype=pipe.torch_dtype, device=pipe.device)
inpaint_latent = pipe.vae_encoder(inpaint_image)
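        # Control context channels: 16 (control latents) + 1 (downsampled,
        # inverted inpaint mask) + 16 (inpaint latents) = 33, reshaped below
        # into a single-frame video layout.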
        control_context = torch.cat([control_latents, inpaint_mask, inpaint_latent], dim=1)
control_context = rearrange(control_context, "B C H W -> B C 1 H W")
return {"control_context": control_context, "control_scale": controlnet_input.scale}
def model_fn_z_image(
dit: ZImageDiT,
controlnet: ZImageControlNet = None,
latents=None,
timestep=None,
prompt_embeds=None,
image_embeds=None,
image_latents=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
# Due to the complex and verbose codebase of Z-Image,
# we are temporarily using this inelegant structure.
# We will refactor this part in the future (if time permits).
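    # Turbo checkpoints carry no SigLIP embedder, so they take the Turbo path.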
if dit.siglip_embedder is None:
return model_fn_z_image_turbo(
dit,
controlnet=controlnet,
latents=latents,
timestep=timestep,
prompt_embeds=prompt_embeds,
image_embeds=image_embeds,
image_latents=image_latents,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
**kwargs,
)
    # The Turbo path returned above, so dit.siglip_embedder is guaranteed to be
    # non-None here and the Omni input layout can be built directly.
    latents = [rearrange(latents, "B C H W -> C B H W")]
    if image_latents is not None:
        image_latents = [rearrange(image_latent, "B C H W -> C B H W") for image_latent in image_latents]
        latents = [image_latents + latents]
        image_noise_mask = [[0] * len(image_latents) + [1]]
    else:
        latents = [latents]
        image_noise_mask = [[1]]
    image_embeds = [image_embeds]
    # The scheduler's timesteps count down from 1000 while this DiT expects
    # t in [0, 1]; hence the remap below and the sign flip on the output.
timestep = (1000 - timestep) / 1000
model_output = dit(
latents,
timestep,
prompt_embeds,
siglip_feats=image_embeds,
image_noise_mask=image_noise_mask,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)[0]
model_output = -model_output
model_output = rearrange(model_output, "C B H W -> B C H W")
return model_output


def model_fn_z_image_turbo(
dit: ZImageDiT,
controlnet: ZImageControlNet = None,
latents=None,
timestep=None,
prompt_embeds=None,
image_embeds=None,
image_latents=None,
control_context=None,
control_scale=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
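    # Inputs may arrive wrapped in nested lists from the Omni-style units;
    # unwrap them to plain tensors for the Turbo path.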
while isinstance(prompt_embeds, list):
prompt_embeds = prompt_embeds[0]
while isinstance(latents, list):
latents = latents[0]
while isinstance(image_embeds, list):
image_embeds = image_embeds[0]
# Timestep
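    # The scheduler counts timesteps down from 1000; flip so that t=1000 marks
    # fully clean tokens (see t_clean below). The predicted velocity is negated
    # at the end for the same convention mismatch.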
timestep = 1000 - timestep
t_noisy = dit.t_embedder(timestep)
t_clean = dit.t_embedder(torch.ones_like(timestep) * 1000)
# Patchify
latents = rearrange(latents, "B C H W -> C B H W")
x, cap_feats, patch_metadata = dit.patchify_and_embed([latents], [prompt_embeds])
x = x[0]
cap_feats = cap_feats[0]
# Noise refine
x = dit.all_x_embedder["2-1"](x)
x[torch.cat(patch_metadata.get("x_pad_mask"))] = dit.x_pad_token.to(dtype=x.dtype, device=x.device)
x_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("x_pos_ids"), dim=0))
x = rearrange(x, "L C -> 1 L C")
x_freqs_cis = rearrange(x_freqs_cis, "L C -> 1 L C")
if control_context is not None:
        block_kwargs = dict(attn_mask=None, freqs_cis=x_freqs_cis, adaln_input=t_noisy)
        refiner_hints, control_context, control_context_item_seqlens = controlnet.forward_refiner(
            dit, x, [cap_feats], control_context, block_kwargs, t=t_noisy, patch_size=2, f_patch_size=1)
for layer_id, layer in enumerate(dit.noise_refiner):
x = gradient_checkpoint_forward(
layer,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
x=x,
attn_mask=None,
freqs_cis=x_freqs_cis,
adaln_input=t_noisy,
)
if control_context is not None:
x = x + refiner_hints[layer_id] * control_scale
# Prompt refine
cap_feats = dit.cap_embedder(cap_feats)
cap_feats[torch.cat(patch_metadata.get("cap_pad_mask"))] = dit.cap_pad_token.to(dtype=x.dtype, device=x.device)
cap_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("cap_pos_ids"), dim=0))
cap_feats = rearrange(cap_feats, "L C -> 1 L C")
cap_freqs_cis = rearrange(cap_freqs_cis, "L C -> 1 L C")
for layer in dit.context_refiner:
cap_feats = gradient_checkpoint_forward(
layer,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
x=cap_feats,
attn_mask=None,
freqs_cis=cap_freqs_cis,
)
# Unified
unified = torch.cat([x, cap_feats], dim=1)
unified_freqs_cis = torch.cat([x_freqs_cis, cap_freqs_cis], dim=1)
if control_context is not None:
        block_kwargs = dict(attn_mask=None, freqs_cis=unified_freqs_cis, adaln_input=t_noisy)
        hints = controlnet.forward_layers(
            unified, cap_feats, control_context, control_context_item_seqlens, block_kwargs)
for layer_id, layer in enumerate(dit.layers):
unified = gradient_checkpoint_forward(
layer,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
x=unified,
attn_mask=None,
freqs_cis=unified_freqs_cis,
adaln_input=t_noisy,
)
        if control_context is not None and layer_id in controlnet.control_layers_mapping:
            unified = unified + hints[controlnet.control_layers_mapping[layer_id]] * control_scale
# Output
unified = dit.all_final_layer["2-1"](unified, t_noisy)
x = dit.unpatchify([unified[0]], patch_metadata.get("x_size"))[0]
x = rearrange(x, "C B H W -> B C H W")
x = -x
return x