mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
block wise controlnet
This commit is contained in:
@@ -76,6 +76,7 @@ from ..models.qwen_image_dit import QwenImageDiT
|
||||
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
|
||||
from ..models.qwen_image_vae import QwenImageVAE
|
||||
from ..models.qwen_image_controlnet import QwenImageControlNet
|
||||
from ..models.qwen_image_controlnet import QwenImageBlockWiseControlNet
|
||||
|
||||
model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
@@ -169,6 +170,7 @@ model_loader_configs = [
|
||||
(None, "8004730443f55db63092006dd9f7110e", ["qwen_image_text_encoder"], [QwenImageTextEncoder], "diffusers"),
|
||||
(None, "ed4ea5824d55ec3107b09815e318123a", ["qwen_image_vae"], [QwenImageVAE], "diffusers"),
|
||||
(None, "be2500a62936a43d5367a70ea001e25d", ["qwen_image_controlnet"], [QwenImageControlNet], "civitai"),
|
||||
(None, "073bce9cf969e317e5662cd570c3e79c", ["qwen_image_blockwise_controlnet"], [QwenImageBlockWiseControlNet], "civitai"),
|
||||
]
|
||||
huggingface_model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
|
||||
@@ -93,3 +93,67 @@ class QwenImageControlNetStateDictConverter():
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
|
||||
class BlockWiseControlBlock(torch.nn.Module):
|
||||
# [linear, gelu, linear]
|
||||
def __init__(self, dim: int = 3072):
|
||||
super().__init__()
|
||||
self.x_rms = RMSNorm(dim, eps=1e-6)
|
||||
self.y_rms = RMSNorm(dim, eps=1e-6)
|
||||
self.input_proj = nn.Linear(dim, dim)
|
||||
self.act = nn.GELU()
|
||||
self.output_proj = nn.Linear(dim, dim)
|
||||
|
||||
def forward(self, x, y):
|
||||
x, y = self.x_rms(x), self.y_rms(y)
|
||||
x = self.input_proj(x + y)
|
||||
x = self.act(x)
|
||||
x = self.output_proj(x)
|
||||
return x
|
||||
|
||||
def init_weights(self):
|
||||
# zero initialize output_proj
|
||||
nn.init.zeros_(self.output_proj.weight)
|
||||
nn.init.zeros_(self.output_proj.bias)
|
||||
|
||||
|
||||
class QwenImageBlockWiseControlNet(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_layers: int = 60,
|
||||
in_dim: int = 64,
|
||||
dim: int = 3072,
|
||||
):
|
||||
super().__init__()
|
||||
self.img_in = nn.Linear(in_dim, dim)
|
||||
self.controlnet_blocks = nn.ModuleList(
|
||||
[
|
||||
BlockWiseControlBlock(dim)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def init_weight(self):
|
||||
nn.init.zeros_(self.img_in.weight)
|
||||
nn.init.zeros_(self.img_in.bias)
|
||||
for block in self.controlnet_blocks:
|
||||
block.init_weights()
|
||||
|
||||
def process_controlnet_conditioning(self, controlnet_conditioning):
|
||||
return self.img_in(controlnet_conditioning)
|
||||
|
||||
def blockwise_forward(self, img, controlnet_conditioning, block_id):
|
||||
return self.controlnet_blocks[block_id](img, controlnet_conditioning)
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return QwenImageBlockWiseControlNetStateDictConverter()
|
||||
|
||||
|
||||
class QwenImageBlockWiseControlNetStateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
@@ -10,7 +10,7 @@ from ..models import ModelManager, load_state_dict
|
||||
from ..models.qwen_image_dit import QwenImageDiT
|
||||
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
|
||||
from ..models.qwen_image_vae import QwenImageVAE
|
||||
from ..models.qwen_image_controlnet import QwenImageControlNet
|
||||
from ..models.qwen_image_controlnet import QwenImageControlNet, QwenImageBlockWiseControlNet
|
||||
from ..schedulers import FlowMatchScheduler
|
||||
from ..utils import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
|
||||
from ..lora import GeneralLoRALoader
|
||||
@@ -69,7 +69,7 @@ class QwenImagePipeline(BasePipeline):
|
||||
self.controlnet: QwenImageMultiControlNet = None
|
||||
self.tokenizer: Qwen2Tokenizer = None
|
||||
self.unit_runner = PipelineUnitRunner()
|
||||
self.in_iteration_models = ("dit", "controlnet")
|
||||
self.in_iteration_models = ("dit", "controlnet", "blockwise_controlnet")
|
||||
self.units = [
|
||||
QwenImageUnit_ShapeChecker(),
|
||||
QwenImageUnit_NoiseInitializer(),
|
||||
@@ -226,6 +226,7 @@ class QwenImagePipeline(BasePipeline):
|
||||
pipe.dit = model_manager.fetch_model("qwen_image_dit")
|
||||
pipe.vae = model_manager.fetch_model("qwen_image_vae")
|
||||
pipe.controlnet = QwenImageMultiControlNet(model_manager.fetch_model("qwen_image_controlnet", index="all"))
|
||||
pipe.blockwise_controlnet = model_manager.fetch_model("qwen_image_blockwise_controlnet")
|
||||
if tokenizer_config is not None and pipe.text_encoder is not None:
|
||||
tokenizer_config.download_if_necessary()
|
||||
from transformers import Qwen2Tokenizer
|
||||
@@ -499,6 +500,7 @@ class QwenImageUnit_ControlNet(PipelineUnit):
|
||||
def process(self, pipe: QwenImagePipeline, controlnet_inputs: list[ControlNetInput], tiled, tile_size, tile_stride):
|
||||
if controlnet_inputs is None:
|
||||
return {}
|
||||
return_key = "blockwise_controlnet_conditioning" if pipe.blockwise_controlnet is not None else "controlnet_conditionings"
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
conditionings = []
|
||||
for controlnet_input in controlnet_inputs:
|
||||
@@ -512,12 +514,13 @@ class QwenImageUnit_ControlNet(PipelineUnit):
|
||||
if controlnet_input.inpaint_mask is not None:
|
||||
image = self.apply_controlnet_mask_on_latents(pipe, image, controlnet_input.inpaint_mask)
|
||||
conditionings.append(image)
|
||||
return {"controlnet_conditionings": conditionings}
|
||||
return {return_key: conditionings}
|
||||
|
||||
|
||||
def model_fn_qwen_image(
|
||||
dit: QwenImageDiT = None,
|
||||
controlnet: QwenImageMultiControlNet = None,
|
||||
blockwise_controlnet: QwenImageBlockWiseControlNet = None,
|
||||
latents=None,
|
||||
timestep=None,
|
||||
prompt_emb=None,
|
||||
@@ -526,6 +529,7 @@ def model_fn_qwen_image(
|
||||
width=None,
|
||||
controlnet_inputs=None,
|
||||
controlnet_conditionings=None,
|
||||
blockwise_controlnet_conditioning=None,
|
||||
progress_id=0,
|
||||
num_inference_steps=1,
|
||||
entity_prompt_emb=None,
|
||||
@@ -572,6 +576,13 @@ def model_fn_qwen_image(
|
||||
image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
|
||||
attention_mask = None
|
||||
|
||||
if blockwise_controlnet_conditioning is not None:
|
||||
blockwise_controlnet_conditioning = rearrange(
|
||||
blockwise_controlnet_conditioning[0], "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2
|
||||
)
|
||||
blockwise_controlnet_conditioning = blockwise_controlnet.process_controlnet_conditioning(blockwise_controlnet_conditioning)
|
||||
|
||||
# blockwise_controlnet_conditioning =
|
||||
for block_id, block in enumerate(dit.transformer_blocks):
|
||||
text, image = gradient_checkpoint_forward(
|
||||
block,
|
||||
@@ -584,9 +595,11 @@ def model_fn_qwen_image(
|
||||
attention_mask=attention_mask,
|
||||
enable_fp8_attention=enable_fp8_attention,
|
||||
)
|
||||
if controlnet_inputs is not None:
|
||||
if blockwise_controlnet is not None:
|
||||
image = image + blockwise_controlnet.blockwise_forward(image, blockwise_controlnet_conditioning, block_id)
|
||||
if controlnet_conditionings is not None:
|
||||
image = image + res_stack[block_id]
|
||||
|
||||
|
||||
image = dit.norm_out(image, conditioning)
|
||||
image = dit.proj_out(image)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user