mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 15:48:20 +00:00
block wise controlnet
This commit is contained in:
@@ -10,7 +10,7 @@ from ..models import ModelManager, load_state_dict
|
||||
from ..models.qwen_image_dit import QwenImageDiT
|
||||
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
|
||||
from ..models.qwen_image_vae import QwenImageVAE
|
||||
from ..models.qwen_image_controlnet import QwenImageControlNet
|
||||
from ..models.qwen_image_controlnet import QwenImageControlNet, QwenImageBlockWiseControlNet
|
||||
from ..schedulers import FlowMatchScheduler
|
||||
from ..utils import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
|
||||
from ..lora import GeneralLoRALoader
|
||||
@@ -69,7 +69,7 @@ class QwenImagePipeline(BasePipeline):
|
||||
self.controlnet: QwenImageMultiControlNet = None
|
||||
self.tokenizer: Qwen2Tokenizer = None
|
||||
self.unit_runner = PipelineUnitRunner()
|
||||
self.in_iteration_models = ("dit", "controlnet")
|
||||
self.in_iteration_models = ("dit", "controlnet", "blockwise_controlnet")
|
||||
self.units = [
|
||||
QwenImageUnit_ShapeChecker(),
|
||||
QwenImageUnit_NoiseInitializer(),
|
||||
@@ -226,6 +226,7 @@ class QwenImagePipeline(BasePipeline):
|
||||
pipe.dit = model_manager.fetch_model("qwen_image_dit")
|
||||
pipe.vae = model_manager.fetch_model("qwen_image_vae")
|
||||
pipe.controlnet = QwenImageMultiControlNet(model_manager.fetch_model("qwen_image_controlnet", index="all"))
|
||||
pipe.blockwise_controlnet = model_manager.fetch_model("qwen_image_blockwise_controlnet")
|
||||
if tokenizer_config is not None and pipe.text_encoder is not None:
|
||||
tokenizer_config.download_if_necessary()
|
||||
from transformers import Qwen2Tokenizer
|
||||
@@ -499,6 +500,7 @@ class QwenImageUnit_ControlNet(PipelineUnit):
|
||||
def process(self, pipe: QwenImagePipeline, controlnet_inputs: list[ControlNetInput], tiled, tile_size, tile_stride):
|
||||
if controlnet_inputs is None:
|
||||
return {}
|
||||
return_key = "blockwise_controlnet_conditioning" if pipe.blockwise_controlnet is not None else "controlnet_conditionings"
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
conditionings = []
|
||||
for controlnet_input in controlnet_inputs:
|
||||
@@ -512,12 +514,13 @@ class QwenImageUnit_ControlNet(PipelineUnit):
|
||||
if controlnet_input.inpaint_mask is not None:
|
||||
image = self.apply_controlnet_mask_on_latents(pipe, image, controlnet_input.inpaint_mask)
|
||||
conditionings.append(image)
|
||||
return {"controlnet_conditionings": conditionings}
|
||||
return {return_key: conditionings}
|
||||
|
||||
|
||||
def model_fn_qwen_image(
|
||||
dit: QwenImageDiT = None,
|
||||
controlnet: QwenImageMultiControlNet = None,
|
||||
blockwise_controlnet: QwenImageBlockWiseControlNet = None,
|
||||
latents=None,
|
||||
timestep=None,
|
||||
prompt_emb=None,
|
||||
@@ -526,6 +529,7 @@ def model_fn_qwen_image(
|
||||
width=None,
|
||||
controlnet_inputs=None,
|
||||
controlnet_conditionings=None,
|
||||
blockwise_controlnet_conditioning=None,
|
||||
progress_id=0,
|
||||
num_inference_steps=1,
|
||||
entity_prompt_emb=None,
|
||||
@@ -572,6 +576,13 @@ def model_fn_qwen_image(
|
||||
image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
|
||||
attention_mask = None
|
||||
|
||||
if blockwise_controlnet_conditioning is not None:
|
||||
blockwise_controlnet_conditioning = rearrange(
|
||||
blockwise_controlnet_conditioning[0], "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2
|
||||
)
|
||||
blockwise_controlnet_conditioning = blockwise_controlnet.process_controlnet_conditioning(blockwise_controlnet_conditioning)
|
||||
|
||||
# blockwise_controlnet_conditioning =
|
||||
for block_id, block in enumerate(dit.transformer_blocks):
|
||||
text, image = gradient_checkpoint_forward(
|
||||
block,
|
||||
@@ -584,9 +595,11 @@ def model_fn_qwen_image(
|
||||
attention_mask=attention_mask,
|
||||
enable_fp8_attention=enable_fp8_attention,
|
||||
)
|
||||
if controlnet_inputs is not None:
|
||||
if blockwise_controlnet is not None:
|
||||
image = image + blockwise_controlnet.blockwise_forward(image, blockwise_controlnet_conditioning, block_id)
|
||||
if controlnet_conditionings is not None:
|
||||
image = image + res_stack[block_id]
|
||||
|
||||
|
||||
image = dit.norm_out(image, conditioning)
|
||||
image = dit.proj_out(image)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user