diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py
index e328593..6555cfd 100644
--- a/diffsynth/configs/model_config.py
+++ b/diffsynth/configs/model_config.py
@@ -75,6 +75,7 @@ from ..models.nexus_gen import NexusGenAutoregressiveModel
 from ..models.qwen_image_dit import QwenImageDiT
 from ..models.qwen_image_text_encoder import QwenImageTextEncoder
 from ..models.qwen_image_vae import QwenImageVAE
+from ..models.qwen_image_controlnet import QwenImageControlNet
 
 model_loader_configs = [
     # These configs are provided for detecting model type automatically.
@@ -167,6 +168,7 @@ model_loader_configs = [
     (None, "0319a1cb19835fb510907dd3367c95ff", ["qwen_image_dit"], [QwenImageDiT], "civitai"),
     (None, "8004730443f55db63092006dd9f7110e", ["qwen_image_text_encoder"], [QwenImageTextEncoder], "diffusers"),
     (None, "ed4ea5824d55ec3107b09815e318123a", ["qwen_image_vae"], [QwenImageVAE], "diffusers"),
+    (None, "6834bf9ef1a6723291d82719fb02e953", ["qwen_image_controlnet"], [QwenImageControlNet], "civitai"),
 ]
 huggingface_model_loader_configs = [
     # These configs are provided for detecting model type automatically.
diff --git a/diffsynth/models/qwen_image_controlnet.py b/diffsynth/models/qwen_image_controlnet.py
new file mode 100644
index 0000000..590c84b
--- /dev/null
+++ b/diffsynth/models/qwen_image_controlnet.py
@@ -0,0 +1,96 @@
+import torch
+import torch.nn as nn
+from .qwen_image_dit import QwenEmbedRope, QwenImageTransformerBlock
+from ..vram_management import gradient_checkpoint_forward
+from einops import rearrange
+from .sd3_dit import TimestepEmbeddings, RMSNorm
+
+
+class QwenImageControlNet(torch.nn.Module):
+    def __init__(
+        self,
+        num_layers: int = 60,
+        num_controlnet_layers: int = 6,
+    ):
+        super().__init__()
+
+        self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True)
+
+        self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=True)
+        self.txt_norm = RMSNorm(3584, eps=1e-6)
+
+        self.img_in = nn.Linear(64 * 2, 3072)
+        self.txt_in = nn.Linear(3584, 3072)
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                QwenImageTransformerBlock(
+                    dim=3072,
+                    num_attention_heads=24,
+                    attention_head_dim=128,
+                )
+                for _ in range(num_controlnet_layers)
+            ]
+        )
+        self.alpha = torch.nn.Parameter(torch.zeros((num_layers,)))
+        self.num_layers = num_layers
+        self.num_controlnet_layers = num_controlnet_layers
+        self.align_map = {i: i // (num_layers // num_controlnet_layers) for i in range(num_layers)}
+
+
+    def forward(
+        self,
+        latents=None,
+        timestep=None,
+        prompt_emb=None,
+        prompt_emb_mask=None,
+        height=None,
+        width=None,
+        controlnet_conditioning=None,
+        use_gradient_checkpointing=False,
+        use_gradient_checkpointing_offload=False,
+        **kwargs,
+    ):
+        img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
+        txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
+
+        image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (P Q C)", H=height//16, W=width//16, P=2, Q=2)
+        controlnet_conditioning = rearrange(controlnet_conditioning, "B C (H P) (W Q) -> B (H W) (P Q C)", H=height//16, W=width//16, P=2, Q=2)
+        image = torch.concat([image, controlnet_conditioning], dim=-1)
+
+        image = self.img_in(image)
+        text = self.txt_in(self.txt_norm(prompt_emb))
+
+        conditioning = self.time_text_embed(timestep, image.dtype)
+
+        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
+
+        outputs = []
+        for block in self.transformer_blocks:
+            text, image = gradient_checkpoint_forward(
+                block,
+                use_gradient_checkpointing,
+                use_gradient_checkpointing_offload,
+                image=image,
+                text=text,
+                temb=conditioning,
+                image_rotary_emb=image_rotary_emb,
+            )
+            outputs.append(image)
+
+        alpha = self.alpha.to(dtype=image.dtype, device=image.device)
+        outputs_aligned = [outputs[self.align_map[i]] * alpha[i] for i in range(self.num_layers)]
+        return outputs_aligned
+
+    @staticmethod
+    def state_dict_converter():
+        return QwenImageControlNetStateDictConverter()
+
+
+
+class QwenImageControlNetStateDictConverter():
+    def __init__(self):
+        pass
+
+    def from_civitai(self, state_dict):
+        return state_dict
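The block-to-layer alignment above is the core of the design: the ControlNet carries only `num_controlnet_layers` (6) transformer blocks but must supply a residual to each of the base DiT's `num_layers` (60) layers, so `align_map` fans each block's output out to a contiguous group of ten layers, and the zero-initialized per-layer `alpha` gates make an untrained ControlNet an exact no-op. A minimal sketch of the arithmetic, with the default sizes copied from the constructor:

```python
# Sketch of the block-to-layer fan-out used by QwenImageControlNet.
num_layers, num_controlnet_layers = 60, 6
align_map = {i: i // (num_layers // num_controlnet_layers) for i in range(num_layers)}

assert align_map[0] == 0 and align_map[9] == 0    # DiT layers 0-9   <- ControlNet block 0
assert align_map[10] == 1 and align_map[19] == 1  # DiT layers 10-19 <- ControlNet block 1
assert align_map[59] == 5                         # DiT layers 50-59 <- ControlNet block 5

# Each residual is outputs[align_map[i]] * alpha[i] with alpha zero-initialized,
# so at the start of training the ControlNet leaves the base model unchanged.
```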
diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py
index deccd62..52f4466 100644
--- a/diffsynth/pipelines/qwen_image.py
+++ b/diffsynth/pipelines/qwen_image.py
@@ -4,19 +4,54 @@ from typing import Union
 from PIL import Image
 from tqdm import tqdm
 from einops import rearrange
+import numpy as np
 from ..models import ModelManager, load_state_dict
 from ..models.qwen_image_dit import QwenImageDiT
 from ..models.qwen_image_text_encoder import QwenImageTextEncoder
 from ..models.qwen_image_vae import QwenImageVAE
+from ..models.qwen_image_controlnet import QwenImageControlNet
 from ..schedulers import FlowMatchScheduler
 from ..utils import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
 from ..lora import GeneralLoRALoader
+from .flux_image_new import ControlNetInput
 from ..vram_management import gradient_checkpoint_forward, enable_vram_management, AutoWrappedModule, AutoWrappedLinear
 
 
+class QwenImageMultiControlNet(torch.nn.Module):
+    def __init__(self, models: list[QwenImageControlNet]):
+        super().__init__()
+        if not isinstance(models, list):
+            models = [models]
+        self.models = torch.nn.ModuleList(models)
+
+    def process_single_controlnet(self, controlnet_input: ControlNetInput, conditioning: torch.Tensor, **kwargs):
+        model = self.models[controlnet_input.controlnet_id]
+        res_stack = model(
+            controlnet_conditioning=conditioning,
+            processor_id=controlnet_input.processor_id,
+            **kwargs
+        )
+        res_stack = [res * controlnet_input.scale for res in res_stack]
+        return res_stack
+
+    def forward(self, conditionings: list[torch.Tensor], controlnet_inputs: list[ControlNetInput], progress_id, num_inference_steps, **kwargs):
+        res_stack = None
+        for controlnet_input, conditioning in zip(controlnet_inputs, conditionings):
+            progress = (num_inference_steps - 1 - progress_id) / max(num_inference_steps - 1, 1)
+            if progress > controlnet_input.start or progress < controlnet_input.end:
+                continue
+            res_stack_ = self.process_single_controlnet(controlnet_input, conditioning, **kwargs)
+            if res_stack is None:
+                res_stack = res_stack_
+            else:
+                res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
+        return res_stack
+
+
 class QwenImagePipeline(BasePipeline):
 
     def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
@@ -30,14 +65,16 @@ class QwenImagePipeline(BasePipeline):
         self.text_encoder: QwenImageTextEncoder = None
         self.dit: QwenImageDiT = None
         self.vae: QwenImageVAE = None
+        self.controlnet: QwenImageMultiControlNet = None
         self.tokenizer: Qwen2Tokenizer = None
         self.unit_runner = PipelineUnitRunner()
-        self.in_iteration_models = ("dit",)
+        self.in_iteration_models = ("dit", "controlnet")
         self.units = [
             QwenImageUnit_ShapeChecker(),
             QwenImageUnit_NoiseInitializer(),
             QwenImageUnit_InputImageEmbedder(),
             QwenImageUnit_PromptEmbedder(),
+            QwenImageUnit_ControlNet(),
         ]
         self.model_fn = model_fn_qwen_image
@@ -165,6 +202,7 @@ class QwenImagePipeline(BasePipeline):
         pipe.text_encoder = model_manager.fetch_model("qwen_image_text_encoder")
         pipe.dit = model_manager.fetch_model("qwen_image_dit")
         pipe.vae = model_manager.fetch_model("qwen_image_vae")
+        pipe.controlnet = QwenImageMultiControlNet(model_manager.fetch_model("qwen_image_controlnet", index="all"))
         if tokenizer_config is not None and pipe.text_encoder is not None:
             tokenizer_config.download_if_necessary()
             from transformers import Qwen2Tokenizer
@@ -190,6 +228,8 @@ class QwenImagePipeline(BasePipeline):
         rand_device: str = "cpu",
         # Steps
         num_inference_steps: int = 30,
+        # ControlNet
+        controlnet_inputs: list[ControlNetInput] = None,
         # Tile
         tiled: bool = False,
         tile_size: int = 128,
@@ -212,6 +252,8 @@ class QwenImagePipeline(BasePipeline):
             "input_image": input_image, "denoising_strength": denoising_strength,
             "height": height, "width": width, "seed": seed, "rand_device": rand_device,
+            "num_inference_steps": num_inference_steps,
+            "controlnet_inputs": controlnet_inputs,
             "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
         }
         for unit in self.units:
@@ -323,18 +365,82 @@ class QwenImageUnit_PromptEmbedder(PipelineUnit):
 
 
+class QwenImageUnit_ControlNet(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("controlnet_inputs", "tiled", "tile_size", "tile_stride"),
+            onload_model_names=("vae",)
+        )
+
+    def apply_controlnet_mask_on_latents(self, pipe, latents, mask):
+        mask = (pipe.preprocess_image(mask) + 1) / 2
+        mask = mask.mean(dim=1, keepdim=True)
+        mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:])
+        latents = torch.concat([latents, mask], dim=1)
+        return latents
+
+    def apply_controlnet_mask_on_image(self, pipe, image, mask):
+        mask = mask.resize(image.size)
+        mask = pipe.preprocess_image(mask).mean(dim=[0, 1]).cpu()
+        image = np.array(image)
+        image[mask > 0] = 0
+        image = Image.fromarray(image)
+        return image
+
+    def process(self, pipe: QwenImagePipeline, controlnet_inputs: list[ControlNetInput], tiled, tile_size, tile_stride):
+        if controlnet_inputs is None:
+            return {}
+        pipe.load_models_to_device(self.onload_model_names)
+        conditionings = []
+        for controlnet_input in controlnet_inputs:
+            image = controlnet_input.image
+            if controlnet_input.inpaint_mask is not None:
+                image = self.apply_controlnet_mask_on_image(pipe, image, controlnet_input.inpaint_mask)
+
+            image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
+            image = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+
+            if controlnet_input.inpaint_mask is not None:
+                image = self.apply_controlnet_mask_on_latents(pipe, image, controlnet_input.inpaint_mask)
+            conditionings.append(image)
+        return {"controlnet_conditionings": conditionings}
+
+
 def model_fn_qwen_image(
     dit: QwenImageDiT = None,
+    controlnet: QwenImageMultiControlNet = None,
     latents=None,
     timestep=None,
     prompt_emb=None,
     prompt_emb_mask=None,
     height=None,
     width=None,
+    controlnet_inputs=None,
+    controlnet_conditionings=None,
+    progress_id=0,
+    num_inference_steps=1,
     use_gradient_checkpointing=False,
     use_gradient_checkpointing_offload=False,
     **kwargs
 ):
+    # ControlNet
+    if controlnet_conditionings is not None:
+        controlnet_extra_kwargs = {
+            "latents": latents,
+            "timestep": timestep,
+            "prompt_emb": prompt_emb,
+            "prompt_emb_mask": prompt_emb_mask,
+            "height": height,
+            "width": width,
+            "use_gradient_checkpointing": use_gradient_checkpointing,
+            "use_gradient_checkpointing_offload": use_gradient_checkpointing_offload,
+        }
+        res_stack = controlnet(
+            controlnet_conditionings, controlnet_inputs, progress_id, num_inference_steps,
+            **controlnet_extra_kwargs
+        )
+
     img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
     txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
     timestep = timestep / 1000
@@ -346,7 +452,7 @@ def model_fn_qwen_image(
     conditioning = dit.time_text_embed(timestep, image.dtype)
     image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
 
-    for block in dit.transformer_blocks:
+    for block_id, block in enumerate(dit.transformer_blocks):
         text, image = gradient_checkpoint_forward(
             block,
             use_gradient_checkpointing,
@@ -356,6 +462,8 @@ def model_fn_qwen_image(
             temb=conditioning,
             image_rotary_emb=image_rotary_emb,
         )
+        if controlnet_conditionings is not None and res_stack is not None:
+            image = image + res_stack[block_id]
 
     image = dit.norm_out(image, conditioning)
     image = dit.proj_out(image)
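Two usage notes on the pipeline changes above. First, once a ControlNet checkpoint is loaded alongside the base weights, conditioned sampling only needs the new `controlnet_inputs` argument. A hypothetical inference sketch: the paths, prompt, and control image are placeholders, and the `ModelConfig` patterns are assumed to follow the repository's existing Qwen-Image examples:

```python
import torch
from PIL import Image
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput

pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
        ModelConfig(path="models/controlnet.safetensors"),  # initialized or trained ControlNet
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)

image = pipe(
    prompt="a placeholder prompt",
    num_inference_steps=30,
    controlnet_inputs=[ControlNetInput(image=Image.open("control.jpg"))],
)
image.save("output.jpg")
```

Second, the `start`/`end` guard in `QwenImageMultiControlNet.forward` operates on a `progress` value that counts down from 1.0 at the first denoising step to 0.0 at the last, so a ControlNet fires only while `end <= progress <= start`; `model_fn_qwen_image` guards the residual addition with `res_stack is not None`, so steps that fall outside every window simply run the base DiT. A minimal check of the window arithmetic:

```python
# Mirrors the guard in QwenImageMultiControlNet.forward: the ControlNet is
# skipped whenever progress > start or progress < end.
def controlnet_active(progress_id, num_inference_steps, start, end):
    progress = (num_inference_steps - 1 - progress_id) / max(num_inference_steps - 1, 1)
    return end <= progress <= start

assert controlnet_active(0, 30, start=1.0, end=0.4)       # first step: progress = 1.0
assert controlnet_active(17, 30, start=1.0, end=0.4)      # progress ~ 0.41, still active
assert not controlnet_active(29, 30, start=1.0, end=0.4)  # last step: progress = 0.0
```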
diff --git a/examples/qwen_image/model_training/full/Qwen-Image-ControlNet.sh b/examples/qwen_image/model_training/full/Qwen-Image-ControlNet.sh
new file mode 100644
index 0000000..f5b948e
--- /dev/null
+++ b/examples/qwen_image/model_training/full/Qwen-Image-ControlNet.sh
@@ -0,0 +1,34 @@
+accelerate launch examples/qwen_image/model_training/train.py \
+  --dataset_base_path data/example_image_dataset \
+  --dataset_metadata_path data/example_image_dataset/metadata_controlnet_upscale.csv \
+  --data_file_keys "image,controlnet_image" \
+  --max_pixels 1048576 \
+  --dataset_repeat 8000 \
+  --model_paths '[
+    [
+        "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00001-of-00009.safetensors",
+        "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00002-of-00009.safetensors",
+        "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00003-of-00009.safetensors",
+        "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00004-of-00009.safetensors",
+        "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00005-of-00009.safetensors",
+        "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00006-of-00009.safetensors",
+        "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00007-of-00009.safetensors",
+        "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00008-of-00009.safetensors",
+        "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00009-of-00009.safetensors"
+    ],
+    [
+        "models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
+        "models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
+        "models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
+        "models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors"
+    ],
+    "models/Qwen/Qwen-Image/vae/diffusion_pytorch_model.safetensors",
+    "models/controlnet.safetensors"
+]' \
+  --learning_rate 1e-5 \
+  --num_epochs 2 \
+  --remove_prefix_in_ckpt "pipe.controlnet.models.0." \
+  --output_path "./models/train/Qwen-Image-ControlNet_full" \
+  --trainable_models "controlnet" \
+  --extra_inputs "controlnet_image" \
+  --use_gradient_checkpointing
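In the script above, `--data_file_keys "image,controlnet_image"` tells the dataset loader which metadata columns hold file paths, and `--extra_inputs "controlnet_image"` forwards the control image into the pipeline. A hypothetical sketch of the expected `metadata_controlnet_upscale.csv` layout; the exact columns of the example dataset are an assumption, with `prompt` following the other example metadata files:

```csv
image,controlnet_image,prompt
image_00001.jpg,controlnet_image_00001.jpg,"a placeholder caption"
image_00002.jpg,controlnet_image_00002.jpg,"another placeholder caption"
```

Note also `--remove_prefix_in_ckpt "pipe.controlnet.models.0."`: during training the ControlNet lives inside the `QwenImageMultiControlNet` wrapper, and stripping that prefix makes the saved checkpoint loadable as a standalone `QwenImageControlNet`.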
diff --git a/examples/qwen_image/model_training/full/others/initialize_controlnet.py b/examples/qwen_image/model_training/full/others/initialize_controlnet.py
new file mode 100644
index 0000000..a8eadce
--- /dev/null
+++ b/examples/qwen_image/model_training/full/others/initialize_controlnet.py
@@ -0,0 +1,33 @@
+# This script initializes a Qwen-Image ControlNet from the base DiT weights.
+from diffsynth import load_state_dict, hash_state_dict_keys
+from diffsynth.pipelines.qwen_image import QwenImageControlNet
+import torch
+from safetensors.torch import save_file
+
+
+state_dict_dit = {}
+for i in range(1, 10):
+    state_dict_dit.update(load_state_dict(f"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-0000{i}-of-00009.safetensors", torch_dtype=torch.bfloat16, device="cuda"))
+
+controlnet = QwenImageControlNet().to(dtype=torch.bfloat16, device="cuda")
+state_dict_controlnet = controlnet.state_dict()
+
+state_dict_init = {}
+for k in state_dict_controlnet:
+    if k in state_dict_dit:
+        if state_dict_dit[k].shape == state_dict_controlnet[k].shape:
+            state_dict_init[k] = state_dict_dit[k]
+        elif k == "img_in.weight":
+            state_dict_init[k] = torch.concat(
+                [
+                    state_dict_dit[k],
+                    state_dict_dit[k],
+                ],
+                dim=-1
+            )
+    elif k == "alpha":
+        state_dict_init[k] = torch.zeros_like(state_dict_controlnet[k])
+controlnet.load_state_dict(state_dict_init)
+
+print(hash_state_dict_keys(state_dict_controlnet))
+save_file(state_dict_controlnet, "models/controlnet.safetensors")
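Most ControlNet parameters can be copied from the base DiT verbatim because the module names match; the script handles the two exceptions explicitly. `alpha` has no DiT counterpart and is zeroed, and `img_in.weight` is widened by duplication because the ControlNet consumes the noisy latents and the control conditioning concatenated. A quick shape check of that duplication:

```python
import torch

# The base DiT's img_in is Linear(64, 3072), i.e. weight shape (3072, 64);
# the ControlNet's img_in is Linear(128, 3072), i.e. weight shape (3072, 128).
w_dit = torch.randn(3072, 64)
w_controlnet = torch.concat([w_dit, w_dit], dim=-1)
assert w_controlnet.shape == (3072, 128)
```

With the duplicated weight, the projection initially treats the control latents exactly like the noisy latents; the zero-initialized `alpha` gates keep this from perturbing the base model until training moves them.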
diff --git a/examples/qwen_image/model_training/train.py b/examples/qwen_image/model_training/train.py
index 48d2d1a..ed7b5f0 100644
--- a/examples/qwen_image/model_training/train.py
+++ b/examples/qwen_image/model_training/train.py
@@ -1,5 +1,5 @@
 import torch, os, json
-from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
+from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
 from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, qwen_image_parser
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
@@ -72,8 +72,14 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
         }
 
         # Extra inputs
+        controlnet_input = {}
         for extra_input in self.extra_inputs:
-            inputs_shared[extra_input] = data[extra_input]
+            if extra_input.startswith("controlnet_"):
+                controlnet_input[extra_input.replace("controlnet_", "")] = data[extra_input]
+            else:
+                inputs_shared[extra_input] = data[extra_input]
+        if len(controlnet_input) > 0:
+            inputs_shared["controlnet_inputs"] = [ControlNetInput(**controlnet_input)]
 
         # Pipeline units will automatically process the input parameters.
         for unit in self.pipe.units:
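The `controlnet_` prefix is the contract between the training CLI and the pipeline: any extra input named `controlnet_<field>` is stripped of its prefix and collected into a single `ControlNetInput`. A small sketch of that routing; the `controlnet_inpaint_mask` entry is hypothetical, included only to show that any `ControlNetInput` field can be supplied this way:

```python
from PIL import Image
from diffsynth.pipelines.qwen_image import ControlNetInput

# Mirrors the prefix-routing loop added to QwenImageTrainingModule above.
extra_inputs = ["controlnet_image", "controlnet_inpaint_mask"]  # hypothetical flag values
data = {
    "controlnet_image": Image.new("RGB", (1024, 1024)),
    "controlnet_inpaint_mask": Image.new("L", (1024, 1024)),
}

controlnet_input = {}
for name in extra_inputs:
    if name.startswith("controlnet_"):
        controlnet_input[name.replace("controlnet_", "")] = data[name]

inputs = [ControlNetInput(**controlnet_input)]
# -> ControlNetInput(image=<control image>, inpaint_mask=<mask>)
```

End to end, the intended workflow appears to be: run `initialize_controlnet.py` once to produce `models/controlnet.safetensors` (the hash it prints should match the `6834bf9ef1a6723291d82719fb02e953` entry registered in `model_config.py`, which is how the loader auto-detects the new model type), then launch `Qwen-Image-ControlNet.sh` to train it.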