support z-image controlnet

This commit is contained in:
Artiprocher
2026-01-07 15:56:53 +08:00
parent 32449a6aa0
commit bac39b1cd2
28 changed files with 868 additions and 11 deletions

View File

@@ -540,6 +540,12 @@ z_image_series = [
"model_name": "siglip_vision_model_428m",
"model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M",
},
{
# ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors")
"model_hash": "1677708d40029ab380a95f6c731a57d7",
"model_name": "z_image_controlnet",
"model_class": "diffsynth.models.z_image_controlnet.ZImageControlNet",
}
]
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series

View File

@@ -195,4 +195,8 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.z_image_controlnet.ZImageControlNet": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
}

View File

@@ -0,0 +1,154 @@
from .z_image_dit import ZImageTransformerBlock
from ..core.gradient import gradient_checkpoint_forward
from torch.nn.utils.rnn import pad_sequence
import torch
from torch import nn
class ZImageControlTransformerBlock(ZImageTransformerBlock):
def __init__(
self,
layer_id: int = 1000,
dim: int = 3840,
n_heads: int = 30,
n_kv_heads: int = 30,
norm_eps: float = 1e-5,
qk_norm: bool = True,
modulation = True,
block_id = 0
):
super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation)
self.block_id = block_id
if block_id == 0:
self.before_proj = nn.Linear(self.dim, self.dim)
self.after_proj = nn.Linear(self.dim, self.dim)
def forward(self, c, x, **kwargs):
if self.block_id == 0:
c = self.before_proj(c) + x
all_c = []
else:
all_c = list(torch.unbind(c))
c = all_c.pop(-1)
c = super().forward(c, **kwargs)
c_skip = self.after_proj(c)
all_c += [c_skip, c]
c = torch.stack(all_c)
return c
class ZImageControlNet(torch.nn.Module):
def __init__(
self,
control_layers_places=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28),
control_in_dim=33,
dim=3840,
n_refiner_layers=2,
):
super().__init__()
self.control_layers = nn.ModuleList([ZImageControlTransformerBlock(layer_id=i, block_id=i) for i in control_layers_places])
self.control_all_x_embedder = nn.ModuleDict({"2-1": nn.Linear(1 * 2 * 2 * control_in_dim, dim, bias=True)})
self.control_noise_refiner = nn.ModuleList([ZImageControlTransformerBlock(block_id=layer_id) for layer_id in range(n_refiner_layers)])
self.control_layers_mapping = {0: 0, 2: 1, 4: 2, 6: 3, 8: 4, 10: 5, 12: 6, 14: 7, 16: 8, 18: 9, 20: 10, 22: 11, 24: 12, 26: 13, 28: 14}
def forward_layers(
self,
x,
cap_feats,
control_context,
control_context_item_seqlens,
kwargs,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
):
bsz = len(control_context)
# unified
cap_item_seqlens = [len(_) for _ in cap_feats]
control_context_unified = []
for i in range(bsz):
control_context_len = control_context_item_seqlens[i]
cap_len = cap_item_seqlens[i]
control_context_unified.append(torch.cat([control_context[i][:control_context_len], cap_feats[i][:cap_len]]))
c = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0)
# arguments
new_kwargs = dict(x=x)
new_kwargs.update(kwargs)
for layer in self.control_layers:
c = gradient_checkpoint_forward(
layer,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
c=c, **new_kwargs
)
hints = torch.unbind(c)[:-1]
return hints
def forward_refiner(
self,
dit,
x,
cap_feats,
control_context,
kwargs,
t=None,
patch_size=2,
f_patch_size=1,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
):
# embeddings
bsz = len(control_context)
device = control_context[0].device
(
control_context,
control_context_size,
control_context_pos_ids,
control_context_inner_pad_mask,
) = dit.patchify_controlnet(control_context, patch_size, f_patch_size, cap_feats[0].size(0))
# control_context embed & refine
control_context_item_seqlens = [len(_) for _ in control_context]
assert all(_ % 2 == 0 for _ in control_context_item_seqlens)
control_context_max_item_seqlen = max(control_context_item_seqlens)
control_context = torch.cat(control_context, dim=0)
control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context)
# Match t_embedder output dtype to control_context for layerwise casting compatibility
adaln_input = t.type_as(control_context)
control_context[torch.cat(control_context_inner_pad_mask)] = dit.x_pad_token.to(dtype=control_context.dtype, device=control_context.device)
control_context = list(control_context.split(control_context_item_seqlens, dim=0))
control_context_freqs_cis = list(dit.rope_embedder(torch.cat(control_context_pos_ids, dim=0)).split(control_context_item_seqlens, dim=0))
control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0)
control_context_freqs_cis = pad_sequence(control_context_freqs_cis, batch_first=True, padding_value=0.0)
control_context_attn_mask = torch.zeros((bsz, control_context_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(control_context_item_seqlens):
control_context_attn_mask[i, :seq_len] = 1
c = control_context
# arguments
new_kwargs = dict(
x=x,
attn_mask=control_context_attn_mask,
freqs_cis=control_context_freqs_cis,
adaln_input=adaln_input,
)
new_kwargs.update(kwargs)
for layer in self.control_noise_refiner:
c = gradient_checkpoint_forward(
layer,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
c=c, **new_kwargs
)
hints = torch.unbind(c)[:-1]
control_context = torch.unbind(c)[-1]
return hints, control_context, control_context_item_seqlens

View File

@@ -609,6 +609,72 @@ class ZImageDiT(nn.Module):
# all_img_pad_mask,
# all_cap_pad_mask,
# )
def patchify_controlnet(
self,
all_image: List[torch.Tensor],
patch_size: int = 2,
f_patch_size: int = 1,
cap_padding_len: int = None,
):
pH = pW = patch_size
pF = f_patch_size
device = all_image[0].device
all_image_out = []
all_image_size = []
all_image_pos_ids = []
all_image_pad_mask = []
for i, image in enumerate(all_image):
### Process Image
C, F, H, W = image.size()
all_image_size.append((F, H, W))
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
# "c f pf h ph w pw -> (f h w) (pf ph pw c)"
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
image_ori_len = len(image)
image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
image_ori_pos_ids = self.create_coordinate_grid(
size=(F_tokens, H_tokens, W_tokens),
start=(cap_padding_len + 1, 0, 0),
device=device,
).flatten(0, 2)
image_padding_pos_ids = (
self.create_coordinate_grid(
size=(1, 1, 1),
start=(0, 0, 0),
device=device,
)
.flatten(0, 2)
.repeat(image_padding_len, 1)
)
image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)
all_image_pos_ids.append(image_padded_pos_ids)
# pad mask
all_image_pad_mask.append(
torch.cat(
[
torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
torch.ones((image_padding_len,), dtype=torch.bool, device=device),
],
dim=0,
)
)
# padded feature
image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
all_image_out.append(image_padded_feat)
return (
all_image_out,
all_image_size,
all_image_pos_ids,
all_image_pad_mask,
)
def _prepare_sequence(
self,

View File

@@ -16,6 +16,7 @@ from ..models.z_image_text_encoder import ZImageTextEncoder
from ..models.z_image_dit import ZImageDiT
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
from ..models.siglip2_image_encoder import Siglip2ImageEncoder428M
from ..models.z_image_controlnet import ZImageControlNet
class ZImagePipeline(BasePipeline):
@@ -31,8 +32,9 @@ class ZImagePipeline(BasePipeline):
self.vae_encoder: FluxVAEEncoder = None
self.vae_decoder: FluxVAEDecoder = None
self.image_encoder: Siglip2ImageEncoder428M = None
self.controlnet: ZImageControlNet = None
self.tokenizer: AutoTokenizer = None
self.in_iteration_models = ("dit",)
self.in_iteration_models = ("dit", "controlnet")
self.units = [
ZImageUnit_ShapeChecker(),
ZImageUnit_PromptEmbedder(),
@@ -41,6 +43,7 @@ class ZImagePipeline(BasePipeline):
ZImageUnit_EditImageAutoResize(),
ZImageUnit_EditImageEmbedderVAE(),
ZImageUnit_EditImageEmbedderSiglip(),
ZImageUnit_PAIControlNet(),
]
self.model_fn = model_fn_z_image
@@ -63,6 +66,7 @@ class ZImagePipeline(BasePipeline):
pipe.vae_encoder = model_pool.fetch_model("flux_vae_encoder")
pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder")
pipe.image_encoder = model_pool.fetch_model("siglip_vision_model_428m")
pipe.controlnet = model_pool.fetch_model("z_image_controlnet")
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
@@ -94,6 +98,8 @@ class ZImagePipeline(BasePipeline):
# Steps
num_inference_steps: int = 8,
sigma_shift: float = None,
# ControlNet
controlnet_inputs: List[ControlNetInput] = None,
# Progress bar
progress_bar_cmd = tqdm,
):
@@ -114,6 +120,7 @@ class ZImagePipeline(BasePipeline):
"seed": seed, "rand_device": rand_device,
"num_inference_steps": num_inference_steps,
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize,
"controlnet_inputs": controlnet_inputs,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -331,7 +338,9 @@ class ZImageUnit_EditImageAutoResize(PipelineUnit):
if edit_image_auto_resize is None or not edit_image_auto_resize:
return {}
operator = ImageCropAndResize(max_pixels=1024*1024, height_division_factor=16, width_division_factor=16)
edit_image = operator(edit_image)
if not isinstance(edit_image, list):
edit_image = [edit_image]
edit_image = [operator(i) for i in edit_image]
return {"edit_image": edit_image}
@@ -376,8 +385,49 @@ class ZImageUnit_EditImageEmbedderVAE(PipelineUnit):
return {"image_latents": image_latents}
class ZImageUnit_PAIControlNet(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("controlnet_inputs", "height", "width"),
output_params=("control_context", "control_scale"),
onload_model_names=("vae_encoder",)
)
def process(self, pipe: ZImagePipeline, controlnet_inputs: List[ControlNetInput], height, width):
if controlnet_inputs is None:
return {}
if len(controlnet_inputs) != 1:
print("Z-Image ControlNet doesn't support multi-ControlNet. Only one image will be used.")
controlnet_input = controlnet_inputs[0]
pipe.load_models_to_device(self.onload_model_names)
control_image = controlnet_input.image
if control_image is not None:
control_image = pipe.preprocess_image(control_image)
control_latents = pipe.vae_encoder(control_image)
else:
control_latents = torch.ones((1, 16, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device) * -1
inpaint_mask = controlnet_input.inpaint_mask
if inpaint_mask is not None:
inpaint_mask = pipe.preprocess_image(inpaint_mask, min_value=0, max_value=1)
inpaint_image = controlnet_input.inpaint_image
inpaint_image = pipe.preprocess_image(inpaint_image)
inpaint_image = inpaint_image * (inpaint_mask < 0.5)
inpaint_mask = torch.nn.functional.interpolate(1 - inpaint_mask, (height // 8, width // 8), mode='nearest')[:, :1]
else:
inpaint_mask = torch.zeros((1, 1, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device)
inpaint_image = torch.zeros((1, 3, height, width), dtype=pipe.torch_dtype, device=pipe.device)
inpaint_latent = pipe.vae_encoder(inpaint_image)
control_context = torch.concat([control_latents, inpaint_mask, inpaint_latent], dim=1)
control_context = rearrange(control_context, "B C H W -> B C 1 H W")
return {"control_context": control_context, "control_scale": controlnet_input.scale}
def model_fn_z_image(
dit: ZImageDiT,
controlnet: ZImageControlNet = None,
latents=None,
timestep=None,
prompt_embeds=None,
@@ -393,13 +443,14 @@ def model_fn_z_image(
if dit.siglip_embedder is None:
return model_fn_z_image_turbo(
dit,
latents,
timestep,
prompt_embeds,
image_embeds,
image_latents,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
controlnet=controlnet,
latents=latents,
timestep=timestep,
prompt_embeds=prompt_embeds,
image_embeds=image_embeds,
image_latents=image_latents,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
**kwargs,
)
latents = [rearrange(latents, "B C H W -> C B H W")]
@@ -431,11 +482,14 @@ def model_fn_z_image(
def model_fn_z_image_turbo(
dit: ZImageDiT,
controlnet: ZImageControlNet = None,
latents=None,
timestep=None,
prompt_embeds=None,
image_embeds=None,
image_latents=None,
control_context=None,
control_scale=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
@@ -460,11 +514,17 @@ def model_fn_z_image_turbo(
# Noise refine
x = dit.all_x_embedder["2-1"](x)
x[torch.cat(patch_metadata.get("x_pad_mask"))] = dit.x_pad_token.to(dtype=x.dtype, device=x.device)
x_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("x_pos_ids"), dim=0))
x = rearrange(x, "L C -> 1 L C")
x_freqs_cis = rearrange(x_freqs_cis, "L C -> 1 L C")
if control_context is not None:
kwargs = dict(attn_mask=None, freqs_cis=x_freqs_cis, adaln_input=t_noisy)
refiner_hints, control_context, control_context_item_seqlens = controlnet.forward_refiner(
dit, x, [cap_feats], control_context, kwargs, t=t_noisy, patch_size=2, f_patch_size=1)
for layer in dit.noise_refiner:
for layer_id, layer in enumerate(dit.noise_refiner):
x = gradient_checkpoint_forward(
layer,
use_gradient_checkpointing=use_gradient_checkpointing,
@@ -474,6 +534,8 @@ def model_fn_z_image_turbo(
freqs_cis=x_freqs_cis,
adaln_input=t_noisy,
)
if control_context is not None:
x = x + refiner_hints[layer_id] * control_scale
# Prompt refine
cap_feats = dit.cap_embedder(cap_feats)
@@ -495,7 +557,13 @@ def model_fn_z_image_turbo(
# Unified
unified = torch.cat([x, cap_feats], dim=1)
unified_freqs_cis = torch.cat([x_freqs_cis, cap_freqs_cis], dim=1)
for layer in dit.layers:
if control_context is not None:
kwargs = dict(attn_mask=None, freqs_cis=unified_freqs_cis, adaln_input=t_noisy)
hints = controlnet.forward_layers(
unified, cap_feats, control_context, control_context_item_seqlens, kwargs)
for layer_id, layer in enumerate(dit.layers):
unified = gradient_checkpoint_forward(
layer,
use_gradient_checkpointing=use_gradient_checkpointing,
@@ -505,6 +573,9 @@ def model_fn_z_image_turbo(
freqs_cis=unified_freqs_cis,
adaln_input=t_noisy,
)
if control_context is not None:
if layer_id in controlnet.control_layers_mapping:
unified = unified + hints[controlnet.control_layers_mapping[layer_id]] * control_scale
# Output
unified = dit.all_final_layer["2-1"](unified, t_noisy)

View File

@@ -9,5 +9,6 @@ class ControlNetInput:
start: float = 1.0
end: float = 0.0
image: Image.Image = None
inpaint_image: Image.Image = None
inpaint_mask: Image.Image = None
processor_id: str = None