qwen-image controlnet

This commit is contained in:
Artiprocher
2025-08-08 11:29:23 +08:00
parent 32cf5d32ce
commit 6e13deb6de
6 changed files with 284 additions and 3 deletions

View File

@@ -4,19 +4,55 @@ from typing import Union
from PIL import Image
from tqdm import tqdm
from einops import rearrange
import numpy as np
from ..models import ModelManager, load_state_dict
from ..models.qwen_image_dit import QwenImageDiT
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
from ..models.qwen_image_vae import QwenImageVAE
from ..models.qwen_image_controlnet import QwenImageControlNet
from ..schedulers import FlowMatchScheduler
from ..utils import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
from ..lora import GeneralLoRALoader
from .flux_image_new import ControlNetInput
from ..vram_management import gradient_checkpoint_forward, enable_vram_management, AutoWrappedModule, AutoWrappedLinear
class QwenImageMultiControlNet(torch.nn.Module):
def __init__(self, models: list[QwenImageControlNet]):
super().__init__()
if not isinstance(models, list):
models = [models]
self.models = torch.nn.ModuleList(models)
def process_single_controlnet(self, controlnet_input: ControlNetInput, conditioning: torch.Tensor, **kwargs):
model = self.models[controlnet_input.controlnet_id]
res_stack = model(
controlnet_conditioning=conditioning,
processor_id=controlnet_input.processor_id,
**kwargs
)
res_stack = [res * controlnet_input.scale for res in res_stack]
return res_stack
def forward(self, conditionings: list[torch.Tensor], controlnet_inputs: list[ControlNetInput], progress_id, num_inference_steps, **kwargs):
res_stack = None
for controlnet_input, conditioning in zip(controlnet_inputs, conditionings):
progress = (num_inference_steps - 1 - progress_id) / max(num_inference_steps - 1, 1)
if progress > controlnet_input.start + (1e-4) or progress < controlnet_input.end - (1e-4):
continue
res_stack_ = self.process_single_controlnet(controlnet_input, conditioning, **kwargs)
if res_stack is None:
res_stack = res_stack_
else:
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
return res_stack
class QwenImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
@@ -30,15 +66,17 @@ class QwenImagePipeline(BasePipeline):
self.text_encoder: QwenImageTextEncoder = None
self.dit: QwenImageDiT = None
self.vae: QwenImageVAE = None
self.controlnet: QwenImageMultiControlNet = None
self.tokenizer: Qwen2Tokenizer = None
self.unit_runner = PipelineUnitRunner()
self.in_iteration_models = ("dit",)
self.in_iteration_models = ("dit", "controlnet")
self.units = [
QwenImageUnit_ShapeChecker(),
QwenImageUnit_NoiseInitializer(),
QwenImageUnit_InputImageEmbedder(),
QwenImageUnit_PromptEmbedder(),
QwenImageUnit_EntityControl(),
QwenImageUnit_ControlNet(),
]
self.model_fn = model_fn_qwen_image
@@ -187,6 +225,7 @@ class QwenImagePipeline(BasePipeline):
pipe.text_encoder = model_manager.fetch_model("qwen_image_text_encoder")
pipe.dit = model_manager.fetch_model("qwen_image_dit")
pipe.vae = model_manager.fetch_model("qwen_image_vae")
pipe.controlnet = QwenImageMultiControlNet(model_manager.fetch_model("qwen_image_controlnet", index="all"))
if tokenizer_config is not None and pipe.text_encoder is not None:
tokenizer_config.download_if_necessary()
from transformers import Qwen2Tokenizer
@@ -212,6 +251,8 @@ class QwenImagePipeline(BasePipeline):
rand_device: str = "cpu",
# Steps
num_inference_steps: int = 30,
# ControlNet
controlnet_inputs: list[ControlNetInput] = None,
# EliGen
eligen_entity_prompts: list[str] = None,
eligen_entity_masks: list[Image.Image] = None,
@@ -241,6 +282,8 @@ class QwenImagePipeline(BasePipeline):
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"enable_fp8_attention": enable_fp8_attention,
"num_inference_steps": num_inference_steps,
"controlnet_inputs": controlnet_inputs,
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative,
}
@@ -431,14 +474,60 @@ class QwenImageUnit_EntityControl(PipelineUnit):
return inputs_shared, inputs_posi, inputs_nega
class QwenImageUnit_ControlNet(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("controlnet_inputs", "tiled", "tile_size", "tile_stride"),
onload_model_names=("vae",)
)
def apply_controlnet_mask_on_latents(self, pipe, latents, mask):
mask = (pipe.preprocess_image(mask) + 1) / 2
mask = mask.mean(dim=1, keepdim=True)
mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:])
latents = torch.concat([latents, mask], dim=1)
return latents
def apply_controlnet_mask_on_image(self, pipe, image, mask):
mask = mask.resize(image.size)
mask = pipe.preprocess_image(mask).mean(dim=[0, 1]).cpu()
image = np.array(image)
image[mask > 0] = 0
image = Image.fromarray(image)
return image
def process(self, pipe: QwenImagePipeline, controlnet_inputs: list[ControlNetInput], tiled, tile_size, tile_stride):
if controlnet_inputs is None:
return {}
pipe.load_models_to_device(self.onload_model_names)
conditionings = []
for controlnet_input in controlnet_inputs:
image = controlnet_input.image
if controlnet_input.inpaint_mask is not None:
image = self.apply_controlnet_mask_on_image(pipe, image, controlnet_input.inpaint_mask)
image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
image = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
if controlnet_input.inpaint_mask is not None:
image = self.apply_controlnet_mask_on_latents(pipe, image, controlnet_input.inpaint_mask)
conditionings.append(image)
return {"controlnet_conditionings": conditionings}
def model_fn_qwen_image(
dit: QwenImageDiT = None,
controlnet: QwenImageMultiControlNet = None,
latents=None,
timestep=None,
prompt_emb=None,
prompt_emb_mask=None,
height=None,
width=None,
controlnet_inputs=None,
controlnet_conditionings=None,
progress_id=0,
num_inference_steps=1,
entity_prompt_emb=None,
entity_prompt_emb_mask=None,
entity_masks=None,
@@ -447,6 +536,23 @@ def model_fn_qwen_image(
use_gradient_checkpointing_offload=False,
**kwargs
):
# ControlNet
if controlnet_conditionings is not None:
controlnet_extra_kwargs = {
"latents": latents,
"timestep": timestep,
"prompt_emb": prompt_emb,
"prompt_emb_mask": prompt_emb_mask,
"height": height,
"width": width,
"use_gradient_checkpointing": use_gradient_checkpointing,
"use_gradient_checkpointing_offload": use_gradient_checkpointing_offload,
}
res_stack = controlnet(
controlnet_conditionings, controlnet_inputs, progress_id, num_inference_steps,
**controlnet_extra_kwargs
)
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
timestep = timestep / 1000
@@ -466,7 +572,7 @@ def model_fn_qwen_image(
image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
attention_mask = None
for block in dit.transformer_blocks:
for block_id, block in enumerate(dit.transformer_blocks):
text, image = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
@@ -478,6 +584,8 @@ def model_fn_qwen_image(
attention_mask=attention_mask,
enable_fp8_attention=enable_fp8_attention,
)
if controlnet_inputs is not None:
image = image + res_stack[block_id]
image = dit.norm_out(image, conditioning)
image = dit.proj_out(image)