mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
block wise controlnet
This commit is contained in:
@@ -76,6 +76,7 @@ from ..models.qwen_image_dit import QwenImageDiT
|
|||||||
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
|
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
|
||||||
from ..models.qwen_image_vae import QwenImageVAE
|
from ..models.qwen_image_vae import QwenImageVAE
|
||||||
from ..models.qwen_image_controlnet import QwenImageControlNet
|
from ..models.qwen_image_controlnet import QwenImageControlNet
|
||||||
|
from ..models.qwen_image_controlnet import QwenImageBlockWiseControlNet
|
||||||
|
|
||||||
model_loader_configs = [
|
model_loader_configs = [
|
||||||
# These configs are provided for detecting model type automatically.
|
# These configs are provided for detecting model type automatically.
|
||||||
@@ -169,6 +170,7 @@ model_loader_configs = [
|
|||||||
(None, "8004730443f55db63092006dd9f7110e", ["qwen_image_text_encoder"], [QwenImageTextEncoder], "diffusers"),
|
(None, "8004730443f55db63092006dd9f7110e", ["qwen_image_text_encoder"], [QwenImageTextEncoder], "diffusers"),
|
||||||
(None, "ed4ea5824d55ec3107b09815e318123a", ["qwen_image_vae"], [QwenImageVAE], "diffusers"),
|
(None, "ed4ea5824d55ec3107b09815e318123a", ["qwen_image_vae"], [QwenImageVAE], "diffusers"),
|
||||||
(None, "be2500a62936a43d5367a70ea001e25d", ["qwen_image_controlnet"], [QwenImageControlNet], "civitai"),
|
(None, "be2500a62936a43d5367a70ea001e25d", ["qwen_image_controlnet"], [QwenImageControlNet], "civitai"),
|
||||||
|
(None, "073bce9cf969e317e5662cd570c3e79c", ["qwen_image_blockwise_controlnet"], [QwenImageBlockWiseControlNet], "civitai"),
|
||||||
]
|
]
|
||||||
huggingface_model_loader_configs = [
|
huggingface_model_loader_configs = [
|
||||||
# These configs are provided for detecting model type automatically.
|
# These configs are provided for detecting model type automatically.
|
||||||
|
|||||||
@@ -93,3 +93,67 @@ class QwenImageControlNetStateDictConverter():
|
|||||||
|
|
||||||
def from_civitai(self, state_dict):
|
def from_civitai(self, state_dict):
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
class BlockWiseControlBlock(torch.nn.Module):
|
||||||
|
# [linear, gelu, linear]
|
||||||
|
def __init__(self, dim: int = 3072):
|
||||||
|
super().__init__()
|
||||||
|
self.x_rms = RMSNorm(dim, eps=1e-6)
|
||||||
|
self.y_rms = RMSNorm(dim, eps=1e-6)
|
||||||
|
self.input_proj = nn.Linear(dim, dim)
|
||||||
|
self.act = nn.GELU()
|
||||||
|
self.output_proj = nn.Linear(dim, dim)
|
||||||
|
|
||||||
|
def forward(self, x, y):
|
||||||
|
x, y = self.x_rms(x), self.y_rms(y)
|
||||||
|
x = self.input_proj(x + y)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.output_proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def init_weights(self):
|
||||||
|
# zero initialize output_proj
|
||||||
|
nn.init.zeros_(self.output_proj.weight)
|
||||||
|
nn.init.zeros_(self.output_proj.bias)
|
||||||
|
|
||||||
|
|
||||||
|
class QwenImageBlockWiseControlNet(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_layers: int = 60,
|
||||||
|
in_dim: int = 64,
|
||||||
|
dim: int = 3072,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.img_in = nn.Linear(in_dim, dim)
|
||||||
|
self.controlnet_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
BlockWiseControlBlock(dim)
|
||||||
|
for _ in range(num_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def init_weight(self):
|
||||||
|
nn.init.zeros_(self.img_in.weight)
|
||||||
|
nn.init.zeros_(self.img_in.bias)
|
||||||
|
for block in self.controlnet_blocks:
|
||||||
|
block.init_weights()
|
||||||
|
|
||||||
|
def process_controlnet_conditioning(self, controlnet_conditioning):
|
||||||
|
return self.img_in(controlnet_conditioning)
|
||||||
|
|
||||||
|
def blockwise_forward(self, img, controlnet_conditioning, block_id):
|
||||||
|
return self.controlnet_blocks[block_id](img, controlnet_conditioning)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return QwenImageBlockWiseControlNetStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
class QwenImageBlockWiseControlNetStateDictConverter():
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
return state_dict
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from ..models import ModelManager, load_state_dict
|
|||||||
from ..models.qwen_image_dit import QwenImageDiT
|
from ..models.qwen_image_dit import QwenImageDiT
|
||||||
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
|
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
|
||||||
from ..models.qwen_image_vae import QwenImageVAE
|
from ..models.qwen_image_vae import QwenImageVAE
|
||||||
from ..models.qwen_image_controlnet import QwenImageControlNet
|
from ..models.qwen_image_controlnet import QwenImageControlNet, QwenImageBlockWiseControlNet
|
||||||
from ..schedulers import FlowMatchScheduler
|
from ..schedulers import FlowMatchScheduler
|
||||||
from ..utils import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
|
from ..utils import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
|
||||||
from ..lora import GeneralLoRALoader
|
from ..lora import GeneralLoRALoader
|
||||||
@@ -69,7 +69,7 @@ class QwenImagePipeline(BasePipeline):
|
|||||||
self.controlnet: QwenImageMultiControlNet = None
|
self.controlnet: QwenImageMultiControlNet = None
|
||||||
self.tokenizer: Qwen2Tokenizer = None
|
self.tokenizer: Qwen2Tokenizer = None
|
||||||
self.unit_runner = PipelineUnitRunner()
|
self.unit_runner = PipelineUnitRunner()
|
||||||
self.in_iteration_models = ("dit", "controlnet")
|
self.in_iteration_models = ("dit", "controlnet", "blockwise_controlnet")
|
||||||
self.units = [
|
self.units = [
|
||||||
QwenImageUnit_ShapeChecker(),
|
QwenImageUnit_ShapeChecker(),
|
||||||
QwenImageUnit_NoiseInitializer(),
|
QwenImageUnit_NoiseInitializer(),
|
||||||
@@ -226,6 +226,7 @@ class QwenImagePipeline(BasePipeline):
|
|||||||
pipe.dit = model_manager.fetch_model("qwen_image_dit")
|
pipe.dit = model_manager.fetch_model("qwen_image_dit")
|
||||||
pipe.vae = model_manager.fetch_model("qwen_image_vae")
|
pipe.vae = model_manager.fetch_model("qwen_image_vae")
|
||||||
pipe.controlnet = QwenImageMultiControlNet(model_manager.fetch_model("qwen_image_controlnet", index="all"))
|
pipe.controlnet = QwenImageMultiControlNet(model_manager.fetch_model("qwen_image_controlnet", index="all"))
|
||||||
|
pipe.blockwise_controlnet = model_manager.fetch_model("qwen_image_blockwise_controlnet")
|
||||||
if tokenizer_config is not None and pipe.text_encoder is not None:
|
if tokenizer_config is not None and pipe.text_encoder is not None:
|
||||||
tokenizer_config.download_if_necessary()
|
tokenizer_config.download_if_necessary()
|
||||||
from transformers import Qwen2Tokenizer
|
from transformers import Qwen2Tokenizer
|
||||||
@@ -499,6 +500,7 @@ class QwenImageUnit_ControlNet(PipelineUnit):
|
|||||||
def process(self, pipe: QwenImagePipeline, controlnet_inputs: list[ControlNetInput], tiled, tile_size, tile_stride):
|
def process(self, pipe: QwenImagePipeline, controlnet_inputs: list[ControlNetInput], tiled, tile_size, tile_stride):
|
||||||
if controlnet_inputs is None:
|
if controlnet_inputs is None:
|
||||||
return {}
|
return {}
|
||||||
|
return_key = "blockwise_controlnet_conditioning" if pipe.blockwise_controlnet is not None else "controlnet_conditionings"
|
||||||
pipe.load_models_to_device(self.onload_model_names)
|
pipe.load_models_to_device(self.onload_model_names)
|
||||||
conditionings = []
|
conditionings = []
|
||||||
for controlnet_input in controlnet_inputs:
|
for controlnet_input in controlnet_inputs:
|
||||||
@@ -512,12 +514,13 @@ class QwenImageUnit_ControlNet(PipelineUnit):
|
|||||||
if controlnet_input.inpaint_mask is not None:
|
if controlnet_input.inpaint_mask is not None:
|
||||||
image = self.apply_controlnet_mask_on_latents(pipe, image, controlnet_input.inpaint_mask)
|
image = self.apply_controlnet_mask_on_latents(pipe, image, controlnet_input.inpaint_mask)
|
||||||
conditionings.append(image)
|
conditionings.append(image)
|
||||||
return {"controlnet_conditionings": conditionings}
|
return {return_key: conditionings}
|
||||||
|
|
||||||
|
|
||||||
def model_fn_qwen_image(
|
def model_fn_qwen_image(
|
||||||
dit: QwenImageDiT = None,
|
dit: QwenImageDiT = None,
|
||||||
controlnet: QwenImageMultiControlNet = None,
|
controlnet: QwenImageMultiControlNet = None,
|
||||||
|
blockwise_controlnet: QwenImageBlockWiseControlNet = None,
|
||||||
latents=None,
|
latents=None,
|
||||||
timestep=None,
|
timestep=None,
|
||||||
prompt_emb=None,
|
prompt_emb=None,
|
||||||
@@ -526,6 +529,7 @@ def model_fn_qwen_image(
|
|||||||
width=None,
|
width=None,
|
||||||
controlnet_inputs=None,
|
controlnet_inputs=None,
|
||||||
controlnet_conditionings=None,
|
controlnet_conditionings=None,
|
||||||
|
blockwise_controlnet_conditioning=None,
|
||||||
progress_id=0,
|
progress_id=0,
|
||||||
num_inference_steps=1,
|
num_inference_steps=1,
|
||||||
entity_prompt_emb=None,
|
entity_prompt_emb=None,
|
||||||
@@ -572,6 +576,13 @@ def model_fn_qwen_image(
|
|||||||
image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
|
image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
|
||||||
attention_mask = None
|
attention_mask = None
|
||||||
|
|
||||||
|
if blockwise_controlnet_conditioning is not None:
|
||||||
|
blockwise_controlnet_conditioning = rearrange(
|
||||||
|
blockwise_controlnet_conditioning[0], "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2
|
||||||
|
)
|
||||||
|
blockwise_controlnet_conditioning = blockwise_controlnet.process_controlnet_conditioning(blockwise_controlnet_conditioning)
|
||||||
|
|
||||||
|
# blockwise_controlnet_conditioning =
|
||||||
for block_id, block in enumerate(dit.transformer_blocks):
|
for block_id, block in enumerate(dit.transformer_blocks):
|
||||||
text, image = gradient_checkpoint_forward(
|
text, image = gradient_checkpoint_forward(
|
||||||
block,
|
block,
|
||||||
@@ -584,7 +595,9 @@ def model_fn_qwen_image(
|
|||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
enable_fp8_attention=enable_fp8_attention,
|
enable_fp8_attention=enable_fp8_attention,
|
||||||
)
|
)
|
||||||
if controlnet_inputs is not None:
|
if blockwise_controlnet is not None:
|
||||||
|
image = image + blockwise_controlnet.blockwise_forward(image, blockwise_controlnet_conditioning, block_id)
|
||||||
|
if controlnet_conditionings is not None:
|
||||||
image = image + res_stack[block_id]
|
image = image + res_stack[block_id]
|
||||||
|
|
||||||
image = dit.norm_out(image, conditioning)
|
image = dit.norm_out(image, conditioning)
|
||||||
|
|||||||
@@ -0,0 +1,36 @@
|
|||||||
|
accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config.yaml examples/qwen_image/model_training/train.py \
|
||||||
|
--dataset_base_path "" \
|
||||||
|
--dataset_metadata_path data/t2i_dataset_annotations/blip3o/blip3o_control_images_train_for_diffsynth.jsonl \
|
||||||
|
--data_file_keys "image,controlnet_image" \
|
||||||
|
--max_pixels 1048576 \
|
||||||
|
--dataset_repeat 50 \
|
||||||
|
--model_paths '[
|
||||||
|
[
|
||||||
|
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00001-of-00009.safetensors",
|
||||||
|
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00002-of-00009.safetensors",
|
||||||
|
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00003-of-00009.safetensors",
|
||||||
|
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00004-of-00009.safetensors",
|
||||||
|
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00005-of-00009.safetensors",
|
||||||
|
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00006-of-00009.safetensors",
|
||||||
|
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00007-of-00009.safetensors",
|
||||||
|
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00008-of-00009.safetensors",
|
||||||
|
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00009-of-00009.safetensors"
|
||||||
|
],
|
||||||
|
[
|
||||||
|
"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
|
||||||
|
"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
|
||||||
|
"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
|
||||||
|
"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors"
|
||||||
|
],
|
||||||
|
"models/Qwen/Qwen-Image/vae/diffusion_pytorch_model.safetensors",
|
||||||
|
"models/DiffSynth-Studio/BlockWiseControlnet/model_init.safetensors"
|
||||||
|
]' \
|
||||||
|
--learning_rate 1e-3 \
|
||||||
|
--num_epochs 1000000 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.blockwise_controlnet." \
|
||||||
|
--output_path "./models/train/Qwen-Image-BlockWiseControlNet_full_lr1e-3_wd1e-6" \
|
||||||
|
--trainable_models "blockwise_controlnet" \
|
||||||
|
--extra_inputs "controlnet_image" \
|
||||||
|
--use_gradient_checkpointing \
|
||||||
|
--dataset_num_workers 8 \
|
||||||
|
--save_steps 2000
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
compute_environment: LOCAL_MACHINE
|
||||||
|
debug: false
|
||||||
|
deepspeed_config:
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
offload_optimizer_device: none
|
||||||
|
offload_param_device: none
|
||||||
|
zero3_init_flag: false
|
||||||
|
zero_stage: 2
|
||||||
|
distributed_type: DEEPSPEED
|
||||||
|
downcast_bf16: 'no'
|
||||||
|
enable_cpu_affinity: false
|
||||||
|
machine_rank: 0
|
||||||
|
main_training_function: main
|
||||||
|
mixed_precision: bf16
|
||||||
|
num_machines: 1
|
||||||
|
num_processes: 8
|
||||||
|
rdzv_backend: static
|
||||||
|
same_network: true
|
||||||
|
tpu_env: []
|
||||||
|
tpu_use_cluster: false
|
||||||
|
tpu_use_sudo: false
|
||||||
|
use_cpu: false
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
# This script is for initializing a Qwen-Image-ControlNet
|
||||||
|
from diffsynth import load_state_dict, hash_state_dict_keys
|
||||||
|
from diffsynth.models.qwen_image_controlnet import QwenImageBlockWiseControlNet
|
||||||
|
import torch
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
|
||||||
|
|
||||||
|
controlnet = QwenImageBlockWiseControlNet().to(dtype=torch.bfloat16, device="cuda")
|
||||||
|
controlnet.init_weight()
|
||||||
|
state_dict_controlnet = controlnet.state_dict()
|
||||||
|
|
||||||
|
print(hash_state_dict_keys(state_dict_controlnet))
|
||||||
|
save_file(state_dict_controlnet, "models/DiffSynth-Studio/BlockWiseControlnet/model_init.safetensors")
|
||||||
@@ -118,7 +118,7 @@ if __name__ == "__main__":
|
|||||||
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
|
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
|
||||||
state_dict_converter=QwenImageLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x:x,
|
state_dict_converter=QwenImageLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x:x,
|
||||||
)
|
)
|
||||||
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate)
|
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=0.000001)
|
||||||
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
|
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
|
||||||
launch_training_task(
|
launch_training_task(
|
||||||
dataset, model, model_logger, optimizer, scheduler,
|
dataset, model, model_logger, optimizer, scheduler,
|
||||||
|
|||||||
@@ -0,0 +1,38 @@
|
|||||||
|
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||||
|
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
|
||||||
|
from diffsynth import load_state_dict
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from diffsynth.controlnets.processors import Annotator
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
pipe = QwenImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||||
|
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||||
|
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||||
|
ModelConfig(path="models/DiffSynth-Studio/BlockWiseControlnet/model_init.safetensors"),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||||
|
)
|
||||||
|
|
||||||
|
state_dict = load_state_dict("models/train/Qwen-Image-BlockWiseControlNet_full_lr1e-3_wd1e-6/step-26000.safetensors")
|
||||||
|
pipe.blockwise_controlnet.load_state_dict(state_dict)
|
||||||
|
|
||||||
|
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||||
|
image = Image.open("test_image.jpg").convert("RGB").resize((1024, 1024))
|
||||||
|
canny_image = Annotator("canny")(image)
|
||||||
|
canny_image.save("canny_image_test.jpg")
|
||||||
|
|
||||||
|
controlnet_input = ControlNetInput(
|
||||||
|
image=canny_image,
|
||||||
|
scale=1.0,
|
||||||
|
processor_id="canny",
|
||||||
|
)
|
||||||
|
|
||||||
|
for seed in range(100, 200):
|
||||||
|
image = pipe(prompt, seed=seed, height=1024, width=1024, controlnet_inputs=[controlnet_input], num_inference_steps=30, cfg_scale=4.0)
|
||||||
|
image.save(f"test_image_controlnet_step2k_1_{seed}.jpg")
|
||||||
Reference in New Issue
Block a user