mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 15:48:20 +00:00
support z-image controlnet
This commit is contained in:
@@ -16,6 +16,7 @@ from ..models.z_image_text_encoder import ZImageTextEncoder
|
||||
from ..models.z_image_dit import ZImageDiT
|
||||
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
||||
from ..models.siglip2_image_encoder import Siglip2ImageEncoder428M
|
||||
from ..models.z_image_controlnet import ZImageControlNet
|
||||
|
||||
|
||||
class ZImagePipeline(BasePipeline):
|
||||
@@ -31,8 +32,9 @@ class ZImagePipeline(BasePipeline):
|
||||
self.vae_encoder: FluxVAEEncoder = None
|
||||
self.vae_decoder: FluxVAEDecoder = None
|
||||
self.image_encoder: Siglip2ImageEncoder428M = None
|
||||
self.controlnet: ZImageControlNet = None
|
||||
self.tokenizer: AutoTokenizer = None
|
||||
self.in_iteration_models = ("dit",)
|
||||
self.in_iteration_models = ("dit", "controlnet")
|
||||
self.units = [
|
||||
ZImageUnit_ShapeChecker(),
|
||||
ZImageUnit_PromptEmbedder(),
|
||||
@@ -41,6 +43,7 @@ class ZImagePipeline(BasePipeline):
|
||||
ZImageUnit_EditImageAutoResize(),
|
||||
ZImageUnit_EditImageEmbedderVAE(),
|
||||
ZImageUnit_EditImageEmbedderSiglip(),
|
||||
ZImageUnit_PAIControlNet(),
|
||||
]
|
||||
self.model_fn = model_fn_z_image
|
||||
|
||||
@@ -63,6 +66,7 @@ class ZImagePipeline(BasePipeline):
|
||||
pipe.vae_encoder = model_pool.fetch_model("flux_vae_encoder")
|
||||
pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder")
|
||||
pipe.image_encoder = model_pool.fetch_model("siglip_vision_model_428m")
|
||||
pipe.controlnet = model_pool.fetch_model("z_image_controlnet")
|
||||
if tokenizer_config is not None:
|
||||
tokenizer_config.download_if_necessary()
|
||||
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
|
||||
@@ -94,6 +98,8 @@ class ZImagePipeline(BasePipeline):
|
||||
# Steps
|
||||
num_inference_steps: int = 8,
|
||||
sigma_shift: float = None,
|
||||
# ControlNet
|
||||
controlnet_inputs: List[ControlNetInput] = None,
|
||||
# Progress bar
|
||||
progress_bar_cmd = tqdm,
|
||||
):
|
||||
@@ -114,6 +120,7 @@ class ZImagePipeline(BasePipeline):
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize,
|
||||
"controlnet_inputs": controlnet_inputs,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
@@ -331,7 +338,9 @@ class ZImageUnit_EditImageAutoResize(PipelineUnit):
|
||||
if edit_image_auto_resize is None or not edit_image_auto_resize:
|
||||
return {}
|
||||
operator = ImageCropAndResize(max_pixels=1024*1024, height_division_factor=16, width_division_factor=16)
|
||||
edit_image = operator(edit_image)
|
||||
if not isinstance(edit_image, list):
|
||||
edit_image = [edit_image]
|
||||
edit_image = [operator(i) for i in edit_image]
|
||||
return {"edit_image": edit_image}
|
||||
|
||||
|
||||
@@ -376,8 +385,49 @@ class ZImageUnit_EditImageEmbedderVAE(PipelineUnit):
|
||||
return {"image_latents": image_latents}
|
||||
|
||||
|
||||
class ZImageUnit_PAIControlNet(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("controlnet_inputs", "height", "width"),
|
||||
output_params=("control_context", "control_scale"),
|
||||
onload_model_names=("vae_encoder",)
|
||||
)
|
||||
|
||||
def process(self, pipe: ZImagePipeline, controlnet_inputs: List[ControlNetInput], height, width):
|
||||
if controlnet_inputs is None:
|
||||
return {}
|
||||
if len(controlnet_inputs) != 1:
|
||||
print("Z-Image ControlNet doesn't support multi-ControlNet. Only one image will be used.")
|
||||
controlnet_input = controlnet_inputs[0]
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
|
||||
control_image = controlnet_input.image
|
||||
if control_image is not None:
|
||||
control_image = pipe.preprocess_image(control_image)
|
||||
control_latents = pipe.vae_encoder(control_image)
|
||||
else:
|
||||
control_latents = torch.ones((1, 16, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device) * -1
|
||||
|
||||
inpaint_mask = controlnet_input.inpaint_mask
|
||||
if inpaint_mask is not None:
|
||||
inpaint_mask = pipe.preprocess_image(inpaint_mask, min_value=0, max_value=1)
|
||||
inpaint_image = controlnet_input.inpaint_image
|
||||
inpaint_image = pipe.preprocess_image(inpaint_image)
|
||||
inpaint_image = inpaint_image * (inpaint_mask < 0.5)
|
||||
inpaint_mask = torch.nn.functional.interpolate(1 - inpaint_mask, (height // 8, width // 8), mode='nearest')[:, :1]
|
||||
else:
|
||||
inpaint_mask = torch.zeros((1, 1, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device)
|
||||
inpaint_image = torch.zeros((1, 3, height, width), dtype=pipe.torch_dtype, device=pipe.device)
|
||||
inpaint_latent = pipe.vae_encoder(inpaint_image)
|
||||
|
||||
control_context = torch.concat([control_latents, inpaint_mask, inpaint_latent], dim=1)
|
||||
control_context = rearrange(control_context, "B C H W -> B C 1 H W")
|
||||
return {"control_context": control_context, "control_scale": controlnet_input.scale}
|
||||
|
||||
|
||||
def model_fn_z_image(
|
||||
dit: ZImageDiT,
|
||||
controlnet: ZImageControlNet = None,
|
||||
latents=None,
|
||||
timestep=None,
|
||||
prompt_embeds=None,
|
||||
@@ -393,13 +443,14 @@ def model_fn_z_image(
|
||||
if dit.siglip_embedder is None:
|
||||
return model_fn_z_image_turbo(
|
||||
dit,
|
||||
latents,
|
||||
timestep,
|
||||
prompt_embeds,
|
||||
image_embeds,
|
||||
image_latents,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
controlnet=controlnet,
|
||||
latents=latents,
|
||||
timestep=timestep,
|
||||
prompt_embeds=prompt_embeds,
|
||||
image_embeds=image_embeds,
|
||||
image_latents=image_latents,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
**kwargs,
|
||||
)
|
||||
latents = [rearrange(latents, "B C H W -> C B H W")]
|
||||
@@ -431,11 +482,14 @@ def model_fn_z_image(
|
||||
|
||||
def model_fn_z_image_turbo(
|
||||
dit: ZImageDiT,
|
||||
controlnet: ZImageControlNet = None,
|
||||
latents=None,
|
||||
timestep=None,
|
||||
prompt_embeds=None,
|
||||
image_embeds=None,
|
||||
image_latents=None,
|
||||
control_context=None,
|
||||
control_scale=None,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
**kwargs,
|
||||
@@ -460,11 +514,17 @@ def model_fn_z_image_turbo(
|
||||
|
||||
# Noise refine
|
||||
x = dit.all_x_embedder["2-1"](x)
|
||||
x[torch.cat(patch_metadata.get("x_pad_mask"))] = dit.x_pad_token.to(dtype=x.dtype, device=x.device)
|
||||
x_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("x_pos_ids"), dim=0))
|
||||
x = rearrange(x, "L C -> 1 L C")
|
||||
x_freqs_cis = rearrange(x_freqs_cis, "L C -> 1 L C")
|
||||
|
||||
if control_context is not None:
|
||||
kwargs = dict(attn_mask=None, freqs_cis=x_freqs_cis, adaln_input=t_noisy)
|
||||
refiner_hints, control_context, control_context_item_seqlens = controlnet.forward_refiner(
|
||||
dit, x, [cap_feats], control_context, kwargs, t=t_noisy, patch_size=2, f_patch_size=1)
|
||||
|
||||
for layer in dit.noise_refiner:
|
||||
for layer_id, layer in enumerate(dit.noise_refiner):
|
||||
x = gradient_checkpoint_forward(
|
||||
layer,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
@@ -474,6 +534,8 @@ def model_fn_z_image_turbo(
|
||||
freqs_cis=x_freqs_cis,
|
||||
adaln_input=t_noisy,
|
||||
)
|
||||
if control_context is not None:
|
||||
x = x + refiner_hints[layer_id] * control_scale
|
||||
|
||||
# Prompt refine
|
||||
cap_feats = dit.cap_embedder(cap_feats)
|
||||
@@ -495,7 +557,13 @@ def model_fn_z_image_turbo(
|
||||
# Unified
|
||||
unified = torch.cat([x, cap_feats], dim=1)
|
||||
unified_freqs_cis = torch.cat([x_freqs_cis, cap_freqs_cis], dim=1)
|
||||
for layer in dit.layers:
|
||||
|
||||
if control_context is not None:
|
||||
kwargs = dict(attn_mask=None, freqs_cis=unified_freqs_cis, adaln_input=t_noisy)
|
||||
hints = controlnet.forward_layers(
|
||||
unified, cap_feats, control_context, control_context_item_seqlens, kwargs)
|
||||
|
||||
for layer_id, layer in enumerate(dit.layers):
|
||||
unified = gradient_checkpoint_forward(
|
||||
layer,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
@@ -505,6 +573,9 @@ def model_fn_z_image_turbo(
|
||||
freqs_cis=unified_freqs_cis,
|
||||
adaln_input=t_noisy,
|
||||
)
|
||||
if control_context is not None:
|
||||
if layer_id in controlnet.control_layers_mapping:
|
||||
unified = unified + hints[controlnet.control_layers_mapping[layer_id]] * control_scale
|
||||
|
||||
# Output
|
||||
unified = dit.all_final_layer["2-1"](unified, t_noisy)
|
||||
|
||||
Reference in New Issue
Block a user