Support Z-Image-Omni-Base

This commit is contained in:
Artiprocher
2026-01-05 14:45:01 +08:00
parent ab8580f77e
commit 5745c9f200
4 changed files with 925 additions and 129 deletions

View File

@@ -4,16 +4,18 @@ from typing import Union
from tqdm import tqdm
from einops import rearrange
import numpy as np
from typing import Union, List, Optional, Tuple
from typing import Union, List, Optional, Tuple, Iterable
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..core.data.operators import ImageCropAndResize
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
from transformers import AutoTokenizer
from ..models.z_image_text_encoder import ZImageTextEncoder
from ..models.z_image_dit import ZImageDiT
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
from ..models.siglip2_image_encoder import Siglip2ImageEncoder428M
class ZImagePipeline(BasePipeline):
@@ -28,6 +30,7 @@ class ZImagePipeline(BasePipeline):
self.dit: ZImageDiT = None
self.vae_encoder: FluxVAEEncoder = None
self.vae_decoder: FluxVAEDecoder = None
self.image_encoder: Siglip2ImageEncoder428M = None
self.tokenizer: AutoTokenizer = None
self.in_iteration_models = ("dit",)
self.units = [
@@ -35,6 +38,9 @@ class ZImagePipeline(BasePipeline):
ZImageUnit_PromptEmbedder(),
ZImageUnit_NoiseInitializer(),
ZImageUnit_InputImageEmbedder(),
ZImageUnit_EditImageAutoResize(),
ZImageUnit_EditImageEmbedderVAE(),
ZImageUnit_EditImageEmbedderSiglip(),
]
self.model_fn = model_fn_z_image
@@ -56,6 +62,7 @@ class ZImagePipeline(BasePipeline):
pipe.dit = model_pool.fetch_model("z_image_dit")
pipe.vae_encoder = model_pool.fetch_model("flux_vae_encoder")
pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder")
pipe.image_encoder = model_pool.fetch_model("siglip_vision_model_428m")
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
@@ -75,6 +82,9 @@ class ZImagePipeline(BasePipeline):
# Image
input_image: Image.Image = None,
denoising_strength: float = 1.0,
# Edit
edit_image: Image.Image = None,
edit_image_auto_resize: bool = True,
# Shape
height: int = 1024,
width: int = 1024,
@@ -83,11 +93,12 @@ class ZImagePipeline(BasePipeline):
rand_device: str = "cpu",
# Steps
num_inference_steps: int = 8,
sigma_shift: float = None,
# Progress bar
progress_bar_cmd = tqdm,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
# Parameters
inputs_posi = {
@@ -102,6 +113,7 @@ class ZImagePipeline(BasePipeline):
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"num_inference_steps": num_inference_steps,
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -143,12 +155,13 @@ class ZImageUnit_PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params=("edit_image",),
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"},
output_params=("prompt_embeds",),
onload_model_names=("text_encoder",)
)
def encode_prompt(
self,
pipe,
@@ -194,10 +207,81 @@ class ZImageUnit_PromptEmbedder(PipelineUnit):
embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
return embeddings_list
def encode_prompt_omni(
    self,
    pipe,
    prompt: Union[str, List[str]],
    edit_image=None,
    device: Optional[torch.device] = None,
    max_sequence_length: int = 512,
) -> List[torch.FloatTensor]:
    """Encode prompts for Z-Image-Omni-Base using its chat-style template.

    Each prompt is split into one or more template segments; when condition
    (edit) images are present, `<|vision_start|>`/`<|vision_end|>` markers
    reserve a slot per image. All segments of all prompts are tokenized and
    encoded in a single text-encoder batch, then regrouped per prompt.

    Args:
        pipe: pipeline providing `tokenizer` and `text_encoder`.
        prompt: a single prompt or a list of prompts.
        edit_image: None, a single image, or a list of condition images;
            only its count is used here.
        device: device the token ids / masks are moved to.
        max_sequence_length: tokenizer padding/truncation length.

    Returns:
        One list per prompt, each containing one embedding tensor per template
        segment with padding positions removed (shape: [num_tokens, dim]).
    """
    if isinstance(prompt, str):
        prompt = [prompt]
    # Count condition images: they determine how many vision slots to emit.
    if edit_image is None:
        num_condition_images = 0
    elif isinstance(edit_image, list):
        num_condition_images = len(edit_image)
    else:
        num_condition_images = 1
    # Wrap every prompt in the chat template, one string per segment.
    for i, prompt_item in enumerate(prompt):
        if num_condition_images == 0:
            prompt[i] = ["<|im_start|>user\n" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n"]
        else:
            prompt_list = ["<|im_start|>user\n<|vision_start|>"]
            prompt_list += ["<|vision_end|><|vision_start|>"] * (num_condition_images - 1)
            prompt_list += ["<|vision_end|>" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"]
            prompt_list += ["<|vision_end|><|im_end|>"]
            prompt[i] = prompt_list
    # Flatten all segments so the tokenizer/encoder run once over everything;
    # remember each prompt's segment count so we can regroup afterwards.
    flattened_prompt = []
    prompt_list_lengths = []
    for i in range(len(prompt)):
        prompt_list_lengths.append(len(prompt[i]))
        flattened_prompt.extend(prompt[i])
    text_inputs = pipe.tokenizer(
        flattened_prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids.to(device)
    prompt_masks = text_inputs.attention_mask.to(device).bool()
    # Penultimate hidden state is used as the prompt embedding.
    prompt_embeds = pipe.text_encoder(
        input_ids=text_input_ids,
        attention_mask=prompt_masks,
        output_hidden_states=True,
    ).hidden_states[-2]
    # Regroup: per prompt, per segment, keep only non-padding token embeddings.
    embeddings_list = []
    start_idx = 0
    for i in range(len(prompt_list_lengths)):
        batch_embeddings = []
        end_idx = start_idx + prompt_list_lengths[i]
        for j in range(start_idx, end_idx):
            batch_embeddings.append(prompt_embeds[j][prompt_masks[j]])
        embeddings_list.append(batch_embeddings)
        start_idx = end_idx
    return embeddings_list
def process(self, pipe: ZImagePipeline, prompt, edit_image):
    """Encode the (positive or negative) prompt into per-prompt embeddings.

    Returns:
        dict with key "prompt_embeds".
    """
    pipe.load_models_to_device(self.onload_model_names)
    # Z-Image-Turbo and Z-Image-Omni-Base use different prompt encoding methods.
    # We determine which encoding method to use based on the model architecture:
    # the Omni variant carries a SigLIP embedder on the DiT.
    # If you are using two-stage split training,
    # please use `--offload_models` instead of skipping the DiT model loading.
    if hasattr(pipe, "dit") and pipe.dit.siglip_embedder is not None:
        prompt_embeds = self.encode_prompt_omni(pipe, prompt, edit_image, pipe.device)
    else:
        prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device)
    return {"prompt_embeds": prompt_embeds}
@@ -234,24 +318,197 @@ class ZImageUnit_InputImageEmbedder(PipelineUnit):
return {"latents": latents, "input_latents": input_latents}
class ZImageUnit_EditImageAutoResize(PipelineUnit):
    """Optionally crop-and-resize the edit image(s) so they fit within
    1024*1024 pixels with height and width divisible by 16."""

    def __init__(self):
        super().__init__(
            input_params=("edit_image", "edit_image_auto_resize"),
            output_params=("edit_image",),
        )

    def process(self, pipe: ZImagePipeline, edit_image, edit_image_auto_resize):
        # Skip when there is no edit image or auto-resize is off/unset.
        if edit_image is None:
            return {}
        if not edit_image_auto_resize:
            return {}
        resizer = ImageCropAndResize(max_pixels=1024*1024, height_division_factor=16, width_division_factor=16)
        return {"edit_image": resizer(edit_image)}
class ZImageUnit_EditImageEmbedderSiglip(PipelineUnit):
    """Encode each edit image into SigLIP features via `pipe.image_encoder`."""

    def __init__(self):
        super().__init__(
            input_params=("edit_image",),
            output_params=("image_embeds",),
            onload_model_names=("image_encoder",)
        )

    def process(self, pipe: ZImagePipeline, edit_image):
        if edit_image is None:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        # Normalize to a list so single images and image lists share one path.
        images = edit_image if isinstance(edit_image, list) else [edit_image]
        embeds = [pipe.image_encoder(img, device=pipe.device) for img in images]
        return {"image_embeds": embeds}
class ZImageUnit_EditImageEmbedderVAE(PipelineUnit):
    """Encode each edit image into VAE latents via `pipe.vae_encoder`."""

    def __init__(self):
        super().__init__(
            input_params=("edit_image",),
            output_params=("image_latents",),
            onload_model_names=("vae_encoder",)
        )

    def process(self, pipe: ZImagePipeline, edit_image):
        if edit_image is None:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        # Normalize to a list so single images and image lists share one path.
        images = edit_image if isinstance(edit_image, list) else [edit_image]
        latents = [pipe.vae_encoder(pipe.preprocess_image(img)) for img in images]
        return {"image_latents": latents}
def model_fn_z_image(
    dit: ZImageDiT,
    latents=None,
    timestep=None,
    prompt_embeds=None,
    image_embeds=None,
    image_latents=None,
    use_gradient_checkpointing=False,
    use_gradient_checkpointing_offload=False,
    **kwargs,
):
    """Run one DiT denoising step, dispatching on the model architecture.

    Turbo checkpoints (no SigLIP embedder on the DiT) take the dedicated
    turbo path; Omni checkpoints pack condition-image latents together with
    the noisy latent into a single sample.

    Returns:
        Negated model prediction with shape (B, C, H, W).
    """
    # Due to the complex and verbose codebase of Z-Image,
    # we are temporarily using this inelegant structure.
    # We will refactor this part in the future (if time permits).
    if dit.siglip_embedder is None:
        return model_fn_z_image_turbo(
            dit,
            latents,
            timestep,
            prompt_embeds,
            image_embeds,
            image_latents,
            use_gradient_checkpointing,
            use_gradient_checkpointing_offload,
            **kwargs,
        )
    # Omni path: one batch sample holds [condition latents..., noisy latent].
    latents = [rearrange(latents, "B C H W -> C B H W")]
    if image_latents is not None:
        image_latents = [rearrange(image_latent, "B C H W -> C B H W") for image_latent in image_latents]
        latents = [image_latents + latents]
        # 0 marks clean condition images, 1 marks the latent being denoised.
        image_noise_mask = [[0] * len(image_latents) + [1]]
    else:
        latents = [latents]
        image_noise_mask = [[1]]
    image_embeds = [image_embeds]
    # The DiT counts time in the opposite direction of the scheduler, scaled to [0, 1].
    timestep = (1000 - timestep) / 1000
    # NOTE(review): `[0]` assumes the Omni DiT returns the prediction tensor as
    # the first element of its output — confirm against ZImageDiT.forward.
    model_output = dit(
        latents,
        timestep,
        prompt_embeds,
        siglip_feats=image_embeds,
        image_noise_mask=image_noise_mask,
        use_gradient_checkpointing=use_gradient_checkpointing,
        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
    )[0]
    model_output = -model_output
    model_output = rearrange(model_output, "C B H W -> B C H W")
    return model_output
def model_fn_z_image_turbo(
    dit: ZImageDiT,
    latents=None,
    timestep=None,
    prompt_embeds=None,
    image_embeds=None,
    image_latents=None,
    use_gradient_checkpointing=False,
    use_gradient_checkpointing_offload=False,
    **kwargs,
):
    """Single denoising forward pass for Z-Image-Turbo.

    Parameters mirror `model_fn_z_image`; `image_latents` is accepted for
    signature compatibility and unused on this path.

    Returns:
        Negated model prediction with shape (B, C, H, W).
    """
    # Inputs may arrive wrapped in (possibly nested) single-element lists.
    while isinstance(prompt_embeds, list):
        prompt_embeds = prompt_embeds[0]
    while isinstance(latents, list):
        latents = latents[0]
    while isinstance(image_embeds, list):
        image_embeds = image_embeds[0]
    # Timestep: the DiT counts time in the opposite direction of the scheduler.
    # (A "clean" embedding at t=1000 was computed here previously but never
    # used on this path, so the extra t_embedder forward has been removed.)
    timestep = 1000 - timestep
    t_noisy = dit.t_embedder(timestep)
    # Patchify latents and prompt embeddings into token sequences.
    latents = rearrange(latents, "B C H W -> C B H W")
    x, cap_feats, patch_metadata = dit.patchify_and_embed([latents], [prompt_embeds])
    x = x[0]
    cap_feats = cap_feats[0]
    # Noise refine: image tokens attend with timestep-conditioned AdaLN.
    x = dit.all_x_embedder["2-1"](x)
    x_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("x_pos_ids"), dim=0))
    x = rearrange(x, "L C -> 1 L C")
    x_freqs_cis = rearrange(x_freqs_cis, "L C -> 1 L C")
    for layer in dit.noise_refiner:
        x = gradient_checkpoint_forward(
            layer,
            use_gradient_checkpointing=use_gradient_checkpointing,
            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
            x=x,
            attn_mask=None,
            freqs_cis=x_freqs_cis,
            adaln_input=t_noisy,
        )
    # Prompt refine: padding positions are replaced by the learned pad token.
    cap_feats = dit.cap_embedder(cap_feats)
    cap_feats[torch.cat(patch_metadata.get("cap_pad_mask"))] = dit.cap_pad_token.to(dtype=x.dtype, device=x.device)
    cap_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("cap_pos_ids"), dim=0))
    cap_feats = rearrange(cap_feats, "L C -> 1 L C")
    cap_freqs_cis = rearrange(cap_freqs_cis, "L C -> 1 L C")
    for layer in dit.context_refiner:
        cap_feats = gradient_checkpoint_forward(
            layer,
            use_gradient_checkpointing=use_gradient_checkpointing,
            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
            x=cap_feats,
            attn_mask=None,
            freqs_cis=cap_freqs_cis,
        )
    # Unified transformer over concatenated image + text tokens.
    unified = torch.cat([x, cap_feats], dim=1)
    unified_freqs_cis = torch.cat([x_freqs_cis, cap_freqs_cis], dim=1)
    for layer in dit.layers:
        unified = gradient_checkpoint_forward(
            layer,
            use_gradient_checkpointing=use_gradient_checkpointing,
            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
            x=unified,
            attn_mask=None,
            freqs_cis=unified_freqs_cis,
            adaln_input=t_noisy,
        )
    # Project back to latent space and unpatchify to image layout.
    unified = dit.all_final_layer["2-1"](unified, t_noisy)
    x = dit.unpatchify([unified[0]], patch_metadata.get("x_size"))[0]
    x = rearrange(x, "C B H W -> B C H W")
    x = -x
    return x