import torch, math
from PIL import Image
from tqdm import tqdm
from einops import rearrange
import numpy as np
from typing import Union, List, Optional, Tuple, Iterable, Dict

from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..core.data.operators import ImageCropAndResize
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
from ..utils.lora import merge_lora

from transformers import AutoTokenizer
from ..models.z_image_text_encoder import ZImageTextEncoder
from ..models.z_image_dit import ZImageDiT
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
from ..models.siglip2_image_encoder import Siglip2ImageEncoder428M, Siglip2ImageEncoder
from ..models.z_image_controlnet import ZImageControlNet
from ..models.dinov3_image_encoder import DINOv3ImageEncoder
from ..models.z_image_image2lora import ZImageImage2LoRAModel


class ZImagePipeline(BasePipeline):

    def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
        super().__init__(
            device=device, torch_dtype=torch_dtype,
            height_division_factor=16, width_division_factor=16,
        )
        self.scheduler = FlowMatchScheduler("Z-Image")
        self.text_encoder: ZImageTextEncoder = None
        self.dit: ZImageDiT = None
        self.vae_encoder: FluxVAEEncoder = None
        self.vae_decoder: FluxVAEDecoder = None
        self.image_encoder: Siglip2ImageEncoder428M = None
        self.controlnet: ZImageControlNet = None
        self.siglip2_image_encoder: Siglip2ImageEncoder = None
        self.dinov3_image_encoder: DINOv3ImageEncoder = None
        self.image2lora_style: ZImageImage2LoRAModel = None
        self.tokenizer: AutoTokenizer = None
        self.in_iteration_models = ("dit", "controlnet")
        self.units = [
            ZImageUnit_ShapeChecker(),
            ZImageUnit_PromptEmbedder(),
            ZImageUnit_NoiseInitializer(),
            ZImageUnit_InputImageEmbedder(),
            ZImageUnit_EditImageAutoResize(),
            ZImageUnit_EditImageEmbedderVAE(),
            ZImageUnit_EditImageEmbedderSiglip(),
            ZImageUnit_PAIControlNet(),
        ]
        self.model_fn = model_fn_z_image

    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = "cuda",
        model_configs: list[ModelConfig] = [],
        tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
        vram_limit: float = None,
    ):
        # Initialize pipeline
        pipe = ZImagePipeline(device=device, torch_dtype=torch_dtype)
        model_pool = pipe.download_and_load_models(model_configs, vram_limit)

        # Fetch models
        pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
        pipe.dit = model_pool.fetch_model("z_image_dit")
        pipe.vae_encoder = model_pool.fetch_model("flux_vae_encoder")
        pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder")
        pipe.image_encoder = model_pool.fetch_model("siglip_vision_model_428m")
        pipe.controlnet = model_pool.fetch_model("z_image_controlnet")
        pipe.siglip2_image_encoder = model_pool.fetch_model("siglip2_image_encoder")
        pipe.dinov3_image_encoder = model_pool.fetch_model("dinov3_image_encoder")
        pipe.image2lora_style = model_pool.fetch_model("z_image_image2lora_style")
        if tokenizer_config is not None:
            tokenizer_config.download_if_necessary()
            pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)

        # VRAM Management
        pipe.vram_management_enabled = pipe.check_vram_management_state()
        return pipe

    @torch.no_grad()
    def __call__(
        self,
        # Prompt
        prompt: str,
        negative_prompt: str = "",
        cfg_scale: float = 1.0,
        # Image
        input_image: Image.Image = None,
        denoising_strength: float = 1.0,
        # Edit
        edit_image: Image.Image = None,
        edit_image_auto_resize: bool = True,
        # Shape
        height: int = 1024,
        width: int = 1024,
        # Randomness
        seed: int = None,
        rand_device: str = "cpu",
        # Steps
        num_inference_steps: int = 8,
        sigma_shift: float = None,
        # ControlNet
        controlnet_inputs: List[ControlNetInput] = None,
        # Image to LoRA
        image2lora_images: List[Image.Image] = None,
        positive_only_lora: Dict[str, torch.Tensor] = None,
        # Progress bar
        progress_bar_cmd=tqdm,
    ):
        # Scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)

        # Parameters
        inputs_posi = {
            "prompt": prompt,
        }
        inputs_nega = {
            "negative_prompt": negative_prompt,
        }
        inputs_shared = {
            "cfg_scale": cfg_scale,
            "input_image": input_image, "denoising_strength": denoising_strength,
            "height": height, "width": width,
            "seed": seed, "rand_device": rand_device,
            "num_inference_steps": num_inference_steps,
            "edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize,
            "controlnet_inputs": controlnet_inputs,
            "image2lora_images": image2lora_images, "positive_only_lora": positive_only_lora,
        }
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)

        # Denoise
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
            noise_pred = self.cfg_guided_model_fn(
                self.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)

        # Decode
        self.load_models_to_device(['vae_decoder'])
        image = self.vae_decoder(inputs_shared["latents"])
        image = self.vae_output_to_image(image)
        self.load_models_to_device([])

        return image


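# Example usage (illustrative sketch; the exact ModelConfig entries depend on which
# Z-Image checkpoints you download, shown here with the default "Tongyi-MAI/Z-Image-Turbo" id):
#
#   pipe = ZImagePipeline.from_pretrained(
#       torch_dtype=torch.bfloat16,
#       device="cuda",
#       model_configs=[ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="...")],
#   )
#   image = pipe(prompt="a photo of a cat", height=1024, width=1024, num_inference_steps=8, seed=0)
#   image.save("image.jpg")

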
class ZImageUnit_ShapeChecker(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("height", "width"),
            output_params=("height", "width"),
        )

    def process(self, pipe: ZImagePipeline, height, width):
        height, width = pipe.check_resize_height_width(height, width)
        return {"height": height, "width": width}


class ZImageUnit_PromptEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params=("edit_image",),
            input_params_posi={"prompt": "prompt"},
            input_params_nega={"prompt": "negative_prompt"},
            output_params=("prompt_embeds",),
            onload_model_names=("text_encoder",)
        )

    def encode_prompt(
        self,
        pipe,
        prompt: Union[str, List[str]],
        device: Optional[torch.device] = None,
        max_sequence_length: int = 512,
    ) -> List[torch.FloatTensor]:
        if isinstance(prompt, str):
            prompt = [prompt]

        # Wrap each prompt with the tokenizer's chat template (user turn + generation prompt).
        for i, prompt_item in enumerate(prompt):
            messages = [
                {"role": "user", "content": prompt_item},
            ]
            prompt_item = pipe.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=True,
            )
            prompt[i] = prompt_item

        text_inputs = pipe.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids.to(device)
        prompt_masks = text_inputs.attention_mask.to(device).bool()

        # Use the penultimate hidden states of the text encoder as prompt embeddings.
        prompt_embeds = pipe.text_encoder(
            input_ids=text_input_ids,
            attention_mask=prompt_masks,
            output_hidden_states=True,
        ).hidden_states[-2]

        embeddings_list = []

        # Keep only the non-padded token embeddings for each prompt.
        for i in range(len(prompt_embeds)):
            embeddings_list.append(prompt_embeds[i][prompt_masks[i]])

        return embeddings_list

    def encode_prompt_omni(
        self,
        pipe,
        prompt: Union[str, List[str]],
        edit_image=None,
        device: Optional[torch.device] = None,
        max_sequence_length: int = 512,
    ) -> List[torch.FloatTensor]:
        if isinstance(prompt, str):
            prompt = [prompt]

        if edit_image is None:
            num_condition_images = 0
        elif isinstance(edit_image, list):
            num_condition_images = len(edit_image)
        else:
            num_condition_images = 1

        # Build the chat-style prompt manually, with one <|vision_start|>/<|vision_end|>
        # placeholder pair per condition image and one for the generated image.
        for i, prompt_item in enumerate(prompt):
            if num_condition_images == 0:
                prompt[i] = ["<|im_start|>user\n" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n"]
            elif num_condition_images > 0:
                prompt_list = ["<|im_start|>user\n<|vision_start|>"]
                prompt_list += ["<|vision_end|><|vision_start|>"] * (num_condition_images - 1)
                prompt_list += ["<|vision_end|>" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"]
                prompt_list += ["<|vision_end|><|im_end|>"]
                prompt[i] = prompt_list

        # Flatten the per-sample segment lists so they can be tokenized in one batch,
        # remembering how many segments belong to each sample.
        flattened_prompt = []
        prompt_list_lengths = []

        for i in range(len(prompt)):
            prompt_list_lengths.append(len(prompt[i]))
            flattened_prompt.extend(prompt[i])

        text_inputs = pipe.tokenizer(
            flattened_prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids.to(device)
        prompt_masks = text_inputs.attention_mask.to(device).bool()

        prompt_embeds = pipe.text_encoder(
            input_ids=text_input_ids,
            attention_mask=prompt_masks,
            output_hidden_states=True,
        ).hidden_states[-2]

        # Regroup the segment embeddings per sample, keeping only non-padded tokens.
        embeddings_list = []
        start_idx = 0
        for i in range(len(prompt_list_lengths)):
            batch_embeddings = []
            end_idx = start_idx + prompt_list_lengths[i]
            for j in range(start_idx, end_idx):
                batch_embeddings.append(prompt_embeds[j][prompt_masks[j]])
            embeddings_list.append(batch_embeddings)
            start_idx = end_idx

        return embeddings_list

    def process(self, pipe: ZImagePipeline, prompt, edit_image):
        pipe.load_models_to_device(self.onload_model_names)
        if hasattr(pipe, "dit") and pipe.dit.siglip_embedder is not None:
            # Z-Image-Turbo and Z-Image-Omni-Base use different prompt encoding methods.
            # We determine which encoding method to use based on the model architecture.
            # If you are using two-stage split training,
            # please use `--offload_models` instead of skipping the DiT model loading.
            prompt_embeds = self.encode_prompt_omni(pipe, prompt, edit_image, pipe.device)
        else:
            prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device)
        return {"prompt_embeds": prompt_embeds}


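# Note (illustrative): for a single edit image, encode_prompt_omni above builds a
# per-sample segment list of the form
#   ["<|im_start|>user\n<|vision_start|>",
#    "<|vision_end|>" + prompt + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>",
#    "<|vision_end|><|im_end|>"]
# Each segment is tokenized and encoded separately, then the segments are regrouped per sample.

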
class ZImageUnit_NoiseInitializer(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("height", "width", "seed", "rand_device"),
            output_params=("noise",),
        )

    def process(self, pipe: ZImagePipeline, height, width, seed, rand_device):
        # 16-channel latent grid at 1/8 of the image resolution (Flux-style VAE).
        noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
        return {"noise": noise}


class ZImageUnit_InputImageEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("input_image", "noise"),
            output_params=("latents", "input_latents"),
            onload_model_names=("vae_encoder",)
        )

    def process(self, pipe: ZImagePipeline, input_image, noise):
        if input_image is None:
            return {"latents": noise, "input_latents": None}
        pipe.load_models_to_device(self.onload_model_names)
        image = pipe.preprocess_image(input_image)
        input_latents = pipe.vae_encoder(image)
        if pipe.scheduler.training:
            return {"latents": noise, "input_latents": input_latents}
        else:
            # Image-to-image: start denoising from the input latents noised to the first timestep.
            latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
            return {"latents": latents, "input_latents": input_latents}


class ZImageUnit_EditImageAutoResize(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("edit_image", "edit_image_auto_resize"),
            output_params=("edit_image",),
        )

    def process(self, pipe: ZImagePipeline, edit_image, edit_image_auto_resize):
        if edit_image is None:
            return {}
        if edit_image_auto_resize is None or not edit_image_auto_resize:
            return {}
        # Cap edit images at roughly 1 megapixel, aligned to multiples of 16.
        operator = ImageCropAndResize(max_pixels=1024*1024, height_division_factor=16, width_division_factor=16)
        if not isinstance(edit_image, list):
            edit_image = [edit_image]
        edit_image = [operator(i) for i in edit_image]
        return {"edit_image": edit_image}


class ZImageUnit_EditImageEmbedderSiglip(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("edit_image",),
            output_params=("image_embeds",),
            onload_model_names=("image_encoder",)
        )

    def process(self, pipe: ZImagePipeline, edit_image):
        if edit_image is None:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        if not isinstance(edit_image, list):
            edit_image = [edit_image]
        image_emb = []
        for image_ in edit_image:
            image_emb.append(pipe.image_encoder(image_, device=pipe.device))
        return {"image_embeds": image_emb}


class ZImageUnit_EditImageEmbedderVAE(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("edit_image",),
            output_params=("image_latents",),
            onload_model_names=("vae_encoder",)
        )

    def process(self, pipe: ZImagePipeline, edit_image):
        if edit_image is None:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        if not isinstance(edit_image, list):
            edit_image = [edit_image]
        image_latents = []
        for image_ in edit_image:
            image_ = pipe.preprocess_image(image_)
            image_latents.append(pipe.vae_encoder(image_))
        return {"image_latents": image_latents}


class ZImageUnit_PAIControlNet(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("controlnet_inputs", "height", "width"),
            output_params=("control_context", "control_scale"),
            onload_model_names=("vae_encoder",)
        )

    def process(self, pipe: ZImagePipeline, controlnet_inputs: List[ControlNetInput], height, width):
        if controlnet_inputs is None:
            return {}
        if len(controlnet_inputs) != 1:
            print("Z-Image ControlNet doesn't support multi-ControlNet. Only one image will be used.")
        controlnet_input = controlnet_inputs[0]
        pipe.load_models_to_device(self.onload_model_names)

        # Control image branch: encode the control image, or fall back to a -1-filled placeholder latent.
        control_image = controlnet_input.image
        if control_image is not None:
            control_image = pipe.preprocess_image(control_image)
            control_latents = pipe.vae_encoder(control_image)
        else:
            control_latents = torch.ones((1, 16, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device) * -1

        # Inpainting branch: zero out the masked region of the inpaint image and downsample the mask to latent resolution.
        inpaint_mask = controlnet_input.inpaint_mask
        if inpaint_mask is not None:
            inpaint_mask = pipe.preprocess_image(inpaint_mask, min_value=0, max_value=1)
            inpaint_image = controlnet_input.inpaint_image
            inpaint_image = pipe.preprocess_image(inpaint_image)
            inpaint_image = inpaint_image * (inpaint_mask < 0.5)
            inpaint_mask = torch.nn.functional.interpolate(1 - inpaint_mask, (height // 8, width // 8), mode='nearest')[:, :1]
        else:
            inpaint_mask = torch.zeros((1, 1, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device)
            inpaint_image = torch.zeros((1, 3, height, width), dtype=pipe.torch_dtype, device=pipe.device)
        inpaint_latent = pipe.vae_encoder(inpaint_image)

        control_context = torch.concat([control_latents, inpaint_mask, inpaint_latent], dim=1)
        control_context = rearrange(control_context, "B C H W -> B C 1 H W")
        return {"control_context": control_context, "control_scale": controlnet_input.scale}


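# Note: the control_context assembled above stacks the control latents (16 channels),
# the inpaint mask downsampled to latent resolution (1 channel), and the inpaint latents
# (16 channels) into a 33-channel map, then adds a singleton frame axis (B, C, 1, H, W)
# before it is consumed by the ControlNet in model_fn_z_image_turbo.

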
def model_fn_z_image(
    dit: ZImageDiT,
    controlnet: ZImageControlNet = None,
    latents=None,
    timestep=None,
    prompt_embeds=None,
    image_embeds=None,
    image_latents=None,
    use_gradient_checkpointing=False,
    use_gradient_checkpointing_offload=False,
    **kwargs,
):
    # Due to the complex and verbose codebase of Z-Image,
    # we are temporarily using this inelegant structure.
    # We will refactor this part in the future (if time permits).
    if dit.siglip_embedder is None:
        # Z-Image-Turbo: no SigLIP embedder, use the dedicated forward path below.
        return model_fn_z_image_turbo(
            dit,
            controlnet=controlnet,
            latents=latents,
            timestep=timestep,
            prompt_embeds=prompt_embeds,
            image_embeds=image_embeds,
            image_latents=image_latents,
            use_gradient_checkpointing=use_gradient_checkpointing,
            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
            **kwargs,
        )
    latents = [rearrange(latents, "B C H W -> C B H W")]
    if dit.siglip_embedder is not None:
        if image_latents is not None:
            # Pack the condition-image latents in front of the noisy latent.
            image_latents = [rearrange(image_latent, "B C H W -> C B H W") for image_latent in image_latents]
            latents = [image_latents + latents]
            image_noise_mask = [[0] * len(image_latents) + [1]]
        else:
            latents = [latents]
            image_noise_mask = [[1]]
        image_embeds = [image_embeds]
    else:
        image_noise_mask = None
    # The DiT expects an inverted, normalized timestep; its prediction is negated before
    # being returned to the scheduler.
    timestep = (1000 - timestep) / 1000
    model_output = dit(
        latents,
        timestep,
        prompt_embeds,
        siglip_feats=image_embeds,
        image_noise_mask=image_noise_mask,
        use_gradient_checkpointing=use_gradient_checkpointing,
        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
    )[0]
    model_output = -model_output
    model_output = rearrange(model_output, "C B H W -> B C H W")
    return model_output


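# Note (illustrative): in the Omni path of model_fn_z_image above, two edit images produce
#   latents          = [[edit_latent_1, edit_latent_2, noisy_latent]]
#   image_noise_mask = [[0, 0, 1]]
# i.e. condition latents are packed in front of the noisy latent and flagged 0 in
# image_noise_mask, while the noisy latent being denoised is flagged 1.

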
class ZImageUnit_Image2LoRAEncode(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("image2lora_images",),
            output_params=("image2lora_x",),
            onload_model_names=("siglip2_image_encoder", "dinov3_image_encoder",),
        )
        self.processor_highres = ImageCropAndResize(height=1024, width=1024)

    def encode_images_using_siglip2(self, pipe: ZImagePipeline, images: list[Image.Image]):
        pipe.load_models_to_device(["siglip2_image_encoder"])
        embs = []
        for image in images:
            image = self.processor_highres(image)
            embs.append(pipe.siglip2_image_encoder(image).to(pipe.torch_dtype))
        embs = torch.stack(embs)
        return embs

    def encode_images_using_dinov3(self, pipe: ZImagePipeline, images: list[Image.Image]):
        pipe.load_models_to_device(["dinov3_image_encoder"])
        embs = []
        for image in images:
            image = self.processor_highres(image)
            embs.append(pipe.dinov3_image_encoder(image).to(pipe.torch_dtype))
        embs = torch.stack(embs)
        return embs

    def encode_images(self, pipe: ZImagePipeline, images: list[Image.Image]):
        if images is None:
            return {}
        if not isinstance(images, list):
            images = [images]
        # Concatenate SigLIP2 and DINOv3 features along the channel dimension.
        embs_siglip2 = self.encode_images_using_siglip2(pipe, images)
        embs_dinov3 = self.encode_images_using_dinov3(pipe, images)
        x = torch.concat([embs_siglip2, embs_dinov3], dim=-1)
        return x

    def process(self, pipe: ZImagePipeline, image2lora_images):
        if image2lora_images is None:
            return {}
        x = self.encode_images(pipe, image2lora_images)
        return {"image2lora_x": x}


class ZImageUnit_Image2LoRADecode(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("image2lora_x",),
            output_params=("lora",),
            onload_model_names=("image2lora_style",),
        )

    def process(self, pipe: ZImagePipeline, image2lora_x):
        if image2lora_x is None:
            return {}
        loras = []
        if pipe.image2lora_style is not None:
            pipe.load_models_to_device(["image2lora_style"])
            for x in image2lora_x:
                loras.append(pipe.image2lora_style(x=x, residual=None))
        lora = merge_lora(loras, alpha=1 / len(image2lora_x))
        return {"lora": lora}


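# Note: the two Image2LoRA units above map each reference image to a LoRA via
# pipe.image2lora_style and average the per-image LoRAs (merge_lora with alpha = 1 / N).
# They are not registered in ZImagePipeline.units by default, so they are presumably
# wired in by feature-specific code elsewhere.

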
def model_fn_z_image_turbo(
    dit: ZImageDiT,
    controlnet: ZImageControlNet = None,
    latents=None,
    timestep=None,
    prompt_embeds=None,
    image_embeds=None,
    image_latents=None,
    control_context=None,
    control_scale=None,
    use_gradient_checkpointing=False,
    use_gradient_checkpointing_offload=False,
    **kwargs,
):
    # Unwrap possibly nested lists produced by the prompt/image embedding units.
    while isinstance(prompt_embeds, list):
        prompt_embeds = prompt_embeds[0]
    while isinstance(latents, list):
        latents = latents[0]
    while isinstance(image_embeds, list):
        image_embeds = image_embeds[0]

    # Timestep
    timestep = 1000 - timestep
    t_noisy = dit.t_embedder(timestep)
    t_clean = dit.t_embedder(torch.ones_like(timestep) * 1000)

    # Patchify
    latents = rearrange(latents, "B C H W -> C B H W")
    x, cap_feats, patch_metadata = dit.patchify_and_embed([latents], [prompt_embeds])
    x = x[0]
    cap_feats = cap_feats[0]

    # Noise refine
    x = dit.all_x_embedder["2-1"](x)
    x[torch.cat(patch_metadata.get("x_pad_mask"))] = dit.x_pad_token.to(dtype=x.dtype, device=x.device)
    x_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("x_pos_ids"), dim=0))
    x = rearrange(x, "L C -> 1 L C")
    x_freqs_cis = rearrange(x_freqs_cis, "L C -> 1 L C")

    if control_context is not None:
        kwargs = dict(attn_mask=None, freqs_cis=x_freqs_cis, adaln_input=t_noisy)
        refiner_hints, control_context, control_context_item_seqlens = controlnet.forward_refiner(
            dit, x, [cap_feats], control_context, kwargs, t=t_noisy, patch_size=2, f_patch_size=1)

    for layer_id, layer in enumerate(dit.noise_refiner):
        x = gradient_checkpoint_forward(
            layer,
            use_gradient_checkpointing=use_gradient_checkpointing,
            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
            x=x,
            attn_mask=None,
            freqs_cis=x_freqs_cis,
            adaln_input=t_noisy,
        )
        if control_context is not None:
            x = x + refiner_hints[layer_id] * control_scale

    # Prompt refine
    cap_feats = dit.cap_embedder(cap_feats)
    cap_feats[torch.cat(patch_metadata.get("cap_pad_mask"))] = dit.cap_pad_token.to(dtype=x.dtype, device=x.device)
    cap_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("cap_pos_ids"), dim=0))
    cap_feats = rearrange(cap_feats, "L C -> 1 L C")
    cap_freqs_cis = rearrange(cap_freqs_cis, "L C -> 1 L C")

    for layer in dit.context_refiner:
        cap_feats = gradient_checkpoint_forward(
            layer,
            use_gradient_checkpointing=use_gradient_checkpointing,
            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
            x=cap_feats,
            attn_mask=None,
            freqs_cis=cap_freqs_cis,
        )

    # Unified
    unified = torch.cat([x, cap_feats], dim=1)
    unified_freqs_cis = torch.cat([x_freqs_cis, cap_freqs_cis], dim=1)

    if control_context is not None:
        kwargs = dict(attn_mask=None, freqs_cis=unified_freqs_cis, adaln_input=t_noisy)
        hints = controlnet.forward_layers(
            unified, cap_feats, control_context, control_context_item_seqlens, kwargs)

    for layer_id, layer in enumerate(dit.layers):
        unified = gradient_checkpoint_forward(
            layer,
            use_gradient_checkpointing=use_gradient_checkpointing,
            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
            x=unified,
            attn_mask=None,
            freqs_cis=unified_freqs_cis,
            adaln_input=t_noisy,
        )
        if control_context is not None:
            if layer_id in controlnet.control_layers_mapping:
                unified = unified + hints[controlnet.control_layers_mapping[layer_id]] * control_scale

    # Output
    unified = dit.all_final_layer["2-1"](unified, t_noisy)
    x = dit.unpatchify([unified[0]], patch_metadata.get("x_size"))[0]
    x = rearrange(x, "C B H W -> B C H W")
    x = -x
    return x