mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support z-image controlnet
This commit is contained in:
@@ -540,6 +540,12 @@ z_image_series = [
|
||||
"model_name": "siglip_vision_model_428m",
|
||||
"model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M",
|
||||
},
|
||||
{
|
||||
# ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors")
|
||||
"model_hash": "1677708d40029ab380a95f6c731a57d7",
|
||||
"model_name": "z_image_controlnet",
|
||||
"model_class": "diffsynth.models.z_image_controlnet.ZImageControlNet",
|
||||
}
|
||||
]
|
||||
|
||||
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series
|
||||
|
||||
@@ -195,4 +195,8 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.z_image_controlnet.ZImageControlNet": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
}
|
||||
|
||||
154
diffsynth/models/z_image_controlnet.py
Normal file
154
diffsynth/models/z_image_controlnet.py
Normal file
@@ -0,0 +1,154 @@
|
||||
from .z_image_dit import ZImageTransformerBlock
|
||||
from ..core.gradient import gradient_checkpoint_forward
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class ZImageControlTransformerBlock(ZImageTransformerBlock):
|
||||
def __init__(
|
||||
self,
|
||||
layer_id: int = 1000,
|
||||
dim: int = 3840,
|
||||
n_heads: int = 30,
|
||||
n_kv_heads: int = 30,
|
||||
norm_eps: float = 1e-5,
|
||||
qk_norm: bool = True,
|
||||
modulation = True,
|
||||
block_id = 0
|
||||
):
|
||||
super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation)
|
||||
self.block_id = block_id
|
||||
if block_id == 0:
|
||||
self.before_proj = nn.Linear(self.dim, self.dim)
|
||||
self.after_proj = nn.Linear(self.dim, self.dim)
|
||||
|
||||
def forward(self, c, x, **kwargs):
|
||||
if self.block_id == 0:
|
||||
c = self.before_proj(c) + x
|
||||
all_c = []
|
||||
else:
|
||||
all_c = list(torch.unbind(c))
|
||||
c = all_c.pop(-1)
|
||||
|
||||
c = super().forward(c, **kwargs)
|
||||
c_skip = self.after_proj(c)
|
||||
all_c += [c_skip, c]
|
||||
c = torch.stack(all_c)
|
||||
return c
|
||||
|
||||
|
||||
class ZImageControlNet(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
control_layers_places=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28),
|
||||
control_in_dim=33,
|
||||
dim=3840,
|
||||
n_refiner_layers=2,
|
||||
):
|
||||
super().__init__()
|
||||
self.control_layers = nn.ModuleList([ZImageControlTransformerBlock(layer_id=i, block_id=i) for i in control_layers_places])
|
||||
self.control_all_x_embedder = nn.ModuleDict({"2-1": nn.Linear(1 * 2 * 2 * control_in_dim, dim, bias=True)})
|
||||
self.control_noise_refiner = nn.ModuleList([ZImageControlTransformerBlock(block_id=layer_id) for layer_id in range(n_refiner_layers)])
|
||||
self.control_layers_mapping = {0: 0, 2: 1, 4: 2, 6: 3, 8: 4, 10: 5, 12: 6, 14: 7, 16: 8, 18: 9, 20: 10, 22: 11, 24: 12, 26: 13, 28: 14}
|
||||
|
||||
def forward_layers(
|
||||
self,
|
||||
x,
|
||||
cap_feats,
|
||||
control_context,
|
||||
control_context_item_seqlens,
|
||||
kwargs,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
):
|
||||
bsz = len(control_context)
|
||||
# unified
|
||||
cap_item_seqlens = [len(_) for _ in cap_feats]
|
||||
control_context_unified = []
|
||||
for i in range(bsz):
|
||||
control_context_len = control_context_item_seqlens[i]
|
||||
cap_len = cap_item_seqlens[i]
|
||||
control_context_unified.append(torch.cat([control_context[i][:control_context_len], cap_feats[i][:cap_len]]))
|
||||
c = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0)
|
||||
|
||||
# arguments
|
||||
new_kwargs = dict(x=x)
|
||||
new_kwargs.update(kwargs)
|
||||
|
||||
for layer in self.control_layers:
|
||||
c = gradient_checkpoint_forward(
|
||||
layer,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
c=c, **new_kwargs
|
||||
)
|
||||
|
||||
hints = torch.unbind(c)[:-1]
|
||||
return hints
|
||||
|
||||
def forward_refiner(
|
||||
self,
|
||||
dit,
|
||||
x,
|
||||
cap_feats,
|
||||
control_context,
|
||||
kwargs,
|
||||
t=None,
|
||||
patch_size=2,
|
||||
f_patch_size=1,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
):
|
||||
# embeddings
|
||||
bsz = len(control_context)
|
||||
device = control_context[0].device
|
||||
(
|
||||
control_context,
|
||||
control_context_size,
|
||||
control_context_pos_ids,
|
||||
control_context_inner_pad_mask,
|
||||
) = dit.patchify_controlnet(control_context, patch_size, f_patch_size, cap_feats[0].size(0))
|
||||
|
||||
# control_context embed & refine
|
||||
control_context_item_seqlens = [len(_) for _ in control_context]
|
||||
assert all(_ % 2 == 0 for _ in control_context_item_seqlens)
|
||||
control_context_max_item_seqlen = max(control_context_item_seqlens)
|
||||
|
||||
control_context = torch.cat(control_context, dim=0)
|
||||
control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context)
|
||||
|
||||
# Match t_embedder output dtype to control_context for layerwise casting compatibility
|
||||
adaln_input = t.type_as(control_context)
|
||||
control_context[torch.cat(control_context_inner_pad_mask)] = dit.x_pad_token.to(dtype=control_context.dtype, device=control_context.device)
|
||||
control_context = list(control_context.split(control_context_item_seqlens, dim=0))
|
||||
control_context_freqs_cis = list(dit.rope_embedder(torch.cat(control_context_pos_ids, dim=0)).split(control_context_item_seqlens, dim=0))
|
||||
|
||||
control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0)
|
||||
control_context_freqs_cis = pad_sequence(control_context_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
control_context_attn_mask = torch.zeros((bsz, control_context_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(control_context_item_seqlens):
|
||||
control_context_attn_mask[i, :seq_len] = 1
|
||||
c = control_context
|
||||
|
||||
# arguments
|
||||
new_kwargs = dict(
|
||||
x=x,
|
||||
attn_mask=control_context_attn_mask,
|
||||
freqs_cis=control_context_freqs_cis,
|
||||
adaln_input=adaln_input,
|
||||
)
|
||||
new_kwargs.update(kwargs)
|
||||
|
||||
for layer in self.control_noise_refiner:
|
||||
c = gradient_checkpoint_forward(
|
||||
layer,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
c=c, **new_kwargs
|
||||
)
|
||||
|
||||
hints = torch.unbind(c)[:-1]
|
||||
control_context = torch.unbind(c)[-1]
|
||||
|
||||
return hints, control_context, control_context_item_seqlens
|
||||
@@ -609,6 +609,72 @@ class ZImageDiT(nn.Module):
|
||||
# all_img_pad_mask,
|
||||
# all_cap_pad_mask,
|
||||
# )
|
||||
|
||||
def patchify_controlnet(
|
||||
self,
|
||||
all_image: List[torch.Tensor],
|
||||
patch_size: int = 2,
|
||||
f_patch_size: int = 1,
|
||||
cap_padding_len: int = None,
|
||||
):
|
||||
pH = pW = patch_size
|
||||
pF = f_patch_size
|
||||
device = all_image[0].device
|
||||
|
||||
all_image_out = []
|
||||
all_image_size = []
|
||||
all_image_pos_ids = []
|
||||
all_image_pad_mask = []
|
||||
|
||||
for i, image in enumerate(all_image):
|
||||
### Process Image
|
||||
C, F, H, W = image.size()
|
||||
all_image_size.append((F, H, W))
|
||||
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
|
||||
|
||||
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
|
||||
# "c f pf h ph w pw -> (f h w) (pf ph pw c)"
|
||||
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
|
||||
|
||||
image_ori_len = len(image)
|
||||
image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
|
||||
|
||||
image_ori_pos_ids = self.create_coordinate_grid(
|
||||
size=(F_tokens, H_tokens, W_tokens),
|
||||
start=(cap_padding_len + 1, 0, 0),
|
||||
device=device,
|
||||
).flatten(0, 2)
|
||||
image_padding_pos_ids = (
|
||||
self.create_coordinate_grid(
|
||||
size=(1, 1, 1),
|
||||
start=(0, 0, 0),
|
||||
device=device,
|
||||
)
|
||||
.flatten(0, 2)
|
||||
.repeat(image_padding_len, 1)
|
||||
)
|
||||
image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)
|
||||
all_image_pos_ids.append(image_padded_pos_ids)
|
||||
# pad mask
|
||||
all_image_pad_mask.append(
|
||||
torch.cat(
|
||||
[
|
||||
torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
|
||||
torch.ones((image_padding_len,), dtype=torch.bool, device=device),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
)
|
||||
# padded feature
|
||||
image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
|
||||
all_image_out.append(image_padded_feat)
|
||||
|
||||
return (
|
||||
all_image_out,
|
||||
all_image_size,
|
||||
all_image_pos_ids,
|
||||
all_image_pad_mask,
|
||||
)
|
||||
|
||||
def _prepare_sequence(
|
||||
self,
|
||||
|
||||
@@ -16,6 +16,7 @@ from ..models.z_image_text_encoder import ZImageTextEncoder
|
||||
from ..models.z_image_dit import ZImageDiT
|
||||
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
||||
from ..models.siglip2_image_encoder import Siglip2ImageEncoder428M
|
||||
from ..models.z_image_controlnet import ZImageControlNet
|
||||
|
||||
|
||||
class ZImagePipeline(BasePipeline):
|
||||
@@ -31,8 +32,9 @@ class ZImagePipeline(BasePipeline):
|
||||
self.vae_encoder: FluxVAEEncoder = None
|
||||
self.vae_decoder: FluxVAEDecoder = None
|
||||
self.image_encoder: Siglip2ImageEncoder428M = None
|
||||
self.controlnet: ZImageControlNet = None
|
||||
self.tokenizer: AutoTokenizer = None
|
||||
self.in_iteration_models = ("dit",)
|
||||
self.in_iteration_models = ("dit", "controlnet")
|
||||
self.units = [
|
||||
ZImageUnit_ShapeChecker(),
|
||||
ZImageUnit_PromptEmbedder(),
|
||||
@@ -41,6 +43,7 @@ class ZImagePipeline(BasePipeline):
|
||||
ZImageUnit_EditImageAutoResize(),
|
||||
ZImageUnit_EditImageEmbedderVAE(),
|
||||
ZImageUnit_EditImageEmbedderSiglip(),
|
||||
ZImageUnit_PAIControlNet(),
|
||||
]
|
||||
self.model_fn = model_fn_z_image
|
||||
|
||||
@@ -63,6 +66,7 @@ class ZImagePipeline(BasePipeline):
|
||||
pipe.vae_encoder = model_pool.fetch_model("flux_vae_encoder")
|
||||
pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder")
|
||||
pipe.image_encoder = model_pool.fetch_model("siglip_vision_model_428m")
|
||||
pipe.controlnet = model_pool.fetch_model("z_image_controlnet")
|
||||
if tokenizer_config is not None:
|
||||
tokenizer_config.download_if_necessary()
|
||||
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
|
||||
@@ -94,6 +98,8 @@ class ZImagePipeline(BasePipeline):
|
||||
# Steps
|
||||
num_inference_steps: int = 8,
|
||||
sigma_shift: float = None,
|
||||
# ControlNet
|
||||
controlnet_inputs: List[ControlNetInput] = None,
|
||||
# Progress bar
|
||||
progress_bar_cmd = tqdm,
|
||||
):
|
||||
@@ -114,6 +120,7 @@ class ZImagePipeline(BasePipeline):
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize,
|
||||
"controlnet_inputs": controlnet_inputs,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
@@ -331,7 +338,9 @@ class ZImageUnit_EditImageAutoResize(PipelineUnit):
|
||||
if edit_image_auto_resize is None or not edit_image_auto_resize:
|
||||
return {}
|
||||
operator = ImageCropAndResize(max_pixels=1024*1024, height_division_factor=16, width_division_factor=16)
|
||||
edit_image = operator(edit_image)
|
||||
if not isinstance(edit_image, list):
|
||||
edit_image = [edit_image]
|
||||
edit_image = [operator(i) for i in edit_image]
|
||||
return {"edit_image": edit_image}
|
||||
|
||||
|
||||
@@ -376,8 +385,49 @@ class ZImageUnit_EditImageEmbedderVAE(PipelineUnit):
|
||||
return {"image_latents": image_latents}
|
||||
|
||||
|
||||
class ZImageUnit_PAIControlNet(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("controlnet_inputs", "height", "width"),
|
||||
output_params=("control_context", "control_scale"),
|
||||
onload_model_names=("vae_encoder",)
|
||||
)
|
||||
|
||||
def process(self, pipe: ZImagePipeline, controlnet_inputs: List[ControlNetInput], height, width):
|
||||
if controlnet_inputs is None:
|
||||
return {}
|
||||
if len(controlnet_inputs) != 1:
|
||||
print("Z-Image ControlNet doesn't support multi-ControlNet. Only one image will be used.")
|
||||
controlnet_input = controlnet_inputs[0]
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
|
||||
control_image = controlnet_input.image
|
||||
if control_image is not None:
|
||||
control_image = pipe.preprocess_image(control_image)
|
||||
control_latents = pipe.vae_encoder(control_image)
|
||||
else:
|
||||
control_latents = torch.ones((1, 16, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device) * -1
|
||||
|
||||
inpaint_mask = controlnet_input.inpaint_mask
|
||||
if inpaint_mask is not None:
|
||||
inpaint_mask = pipe.preprocess_image(inpaint_mask, min_value=0, max_value=1)
|
||||
inpaint_image = controlnet_input.inpaint_image
|
||||
inpaint_image = pipe.preprocess_image(inpaint_image)
|
||||
inpaint_image = inpaint_image * (inpaint_mask < 0.5)
|
||||
inpaint_mask = torch.nn.functional.interpolate(1 - inpaint_mask, (height // 8, width // 8), mode='nearest')[:, :1]
|
||||
else:
|
||||
inpaint_mask = torch.zeros((1, 1, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device)
|
||||
inpaint_image = torch.zeros((1, 3, height, width), dtype=pipe.torch_dtype, device=pipe.device)
|
||||
inpaint_latent = pipe.vae_encoder(inpaint_image)
|
||||
|
||||
control_context = torch.concat([control_latents, inpaint_mask, inpaint_latent], dim=1)
|
||||
control_context = rearrange(control_context, "B C H W -> B C 1 H W")
|
||||
return {"control_context": control_context, "control_scale": controlnet_input.scale}
|
||||
|
||||
|
||||
def model_fn_z_image(
|
||||
dit: ZImageDiT,
|
||||
controlnet: ZImageControlNet = None,
|
||||
latents=None,
|
||||
timestep=None,
|
||||
prompt_embeds=None,
|
||||
@@ -393,13 +443,14 @@ def model_fn_z_image(
|
||||
if dit.siglip_embedder is None:
|
||||
return model_fn_z_image_turbo(
|
||||
dit,
|
||||
latents,
|
||||
timestep,
|
||||
prompt_embeds,
|
||||
image_embeds,
|
||||
image_latents,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
controlnet=controlnet,
|
||||
latents=latents,
|
||||
timestep=timestep,
|
||||
prompt_embeds=prompt_embeds,
|
||||
image_embeds=image_embeds,
|
||||
image_latents=image_latents,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
**kwargs,
|
||||
)
|
||||
latents = [rearrange(latents, "B C H W -> C B H W")]
|
||||
@@ -431,11 +482,14 @@ def model_fn_z_image(
|
||||
|
||||
def model_fn_z_image_turbo(
|
||||
dit: ZImageDiT,
|
||||
controlnet: ZImageControlNet = None,
|
||||
latents=None,
|
||||
timestep=None,
|
||||
prompt_embeds=None,
|
||||
image_embeds=None,
|
||||
image_latents=None,
|
||||
control_context=None,
|
||||
control_scale=None,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
**kwargs,
|
||||
@@ -460,11 +514,17 @@ def model_fn_z_image_turbo(
|
||||
|
||||
# Noise refine
|
||||
x = dit.all_x_embedder["2-1"](x)
|
||||
x[torch.cat(patch_metadata.get("x_pad_mask"))] = dit.x_pad_token.to(dtype=x.dtype, device=x.device)
|
||||
x_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("x_pos_ids"), dim=0))
|
||||
x = rearrange(x, "L C -> 1 L C")
|
||||
x_freqs_cis = rearrange(x_freqs_cis, "L C -> 1 L C")
|
||||
|
||||
if control_context is not None:
|
||||
kwargs = dict(attn_mask=None, freqs_cis=x_freqs_cis, adaln_input=t_noisy)
|
||||
refiner_hints, control_context, control_context_item_seqlens = controlnet.forward_refiner(
|
||||
dit, x, [cap_feats], control_context, kwargs, t=t_noisy, patch_size=2, f_patch_size=1)
|
||||
|
||||
for layer in dit.noise_refiner:
|
||||
for layer_id, layer in enumerate(dit.noise_refiner):
|
||||
x = gradient_checkpoint_forward(
|
||||
layer,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
@@ -474,6 +534,8 @@ def model_fn_z_image_turbo(
|
||||
freqs_cis=x_freqs_cis,
|
||||
adaln_input=t_noisy,
|
||||
)
|
||||
if control_context is not None:
|
||||
x = x + refiner_hints[layer_id] * control_scale
|
||||
|
||||
# Prompt refine
|
||||
cap_feats = dit.cap_embedder(cap_feats)
|
||||
@@ -495,7 +557,13 @@ def model_fn_z_image_turbo(
|
||||
# Unified
|
||||
unified = torch.cat([x, cap_feats], dim=1)
|
||||
unified_freqs_cis = torch.cat([x_freqs_cis, cap_freqs_cis], dim=1)
|
||||
for layer in dit.layers:
|
||||
|
||||
if control_context is not None:
|
||||
kwargs = dict(attn_mask=None, freqs_cis=unified_freqs_cis, adaln_input=t_noisy)
|
||||
hints = controlnet.forward_layers(
|
||||
unified, cap_feats, control_context, control_context_item_seqlens, kwargs)
|
||||
|
||||
for layer_id, layer in enumerate(dit.layers):
|
||||
unified = gradient_checkpoint_forward(
|
||||
layer,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
@@ -505,6 +573,9 @@ def model_fn_z_image_turbo(
|
||||
freqs_cis=unified_freqs_cis,
|
||||
adaln_input=t_noisy,
|
||||
)
|
||||
if control_context is not None:
|
||||
if layer_id in controlnet.control_layers_mapping:
|
||||
unified = unified + hints[controlnet.control_layers_mapping[layer_id]] * control_scale
|
||||
|
||||
# Output
|
||||
unified = dit.all_final_layer["2-1"](unified, t_noisy)
|
||||
|
||||
@@ -9,5 +9,6 @@ class ControlNetInput:
|
||||
start: float = 1.0
|
||||
end: float = 0.0
|
||||
image: Image.Image = None
|
||||
inpaint_image: Image.Image = None
|
||||
inpaint_mask: Image.Image = None
|
||||
processor_id: str = None
|
||||
|
||||
Reference in New Issue
Block a user