diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index 47cfbea..7e05b5a 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -76,6 +76,7 @@ from ..models.qwen_image_dit import QwenImageDiT from ..models.qwen_image_text_encoder import QwenImageTextEncoder from ..models.qwen_image_vae import QwenImageVAE from ..models.qwen_image_controlnet import QwenImageControlNet +from ..models.qwen_image_controlnet import QwenImageBlockWiseControlNet model_loader_configs = [ # These configs are provided for detecting model type automatically. @@ -169,6 +170,7 @@ model_loader_configs = [ (None, "8004730443f55db63092006dd9f7110e", ["qwen_image_text_encoder"], [QwenImageTextEncoder], "diffusers"), (None, "ed4ea5824d55ec3107b09815e318123a", ["qwen_image_vae"], [QwenImageVAE], "diffusers"), (None, "be2500a62936a43d5367a70ea001e25d", ["qwen_image_controlnet"], [QwenImageControlNet], "civitai"), + (None, "073bce9cf969e317e5662cd570c3e79c", ["qwen_image_blockwise_controlnet"], [QwenImageBlockWiseControlNet], "civitai"), ] huggingface_model_loader_configs = [ # These configs are provided for detecting model type automatically. diff --git a/diffsynth/models/qwen_image_controlnet.py b/diffsynth/models/qwen_image_controlnet.py index 7c18011..b86dda5 100644 --- a/diffsynth/models/qwen_image_controlnet.py +++ b/diffsynth/models/qwen_image_controlnet.py @@ -93,3 +93,67 @@ class QwenImageControlNetStateDictConverter(): def from_civitai(self, state_dict): return state_dict + + +class BlockWiseControlBlock(torch.nn.Module): + # [linear, gelu, linear] + def __init__(self, dim: int = 3072): + super().__init__() + self.x_rms = RMSNorm(dim, eps=1e-6) + self.y_rms = RMSNorm(dim, eps=1e-6) + self.input_proj = nn.Linear(dim, dim) + self.act = nn.GELU() + self.output_proj = nn.Linear(dim, dim) + + def forward(self, x, y): + x, y = self.x_rms(x), self.y_rms(y) + x = self.input_proj(x + y) + x = self.act(x) + x = self.output_proj(x) + return x + + def init_weights(self): + # zero initialize output_proj + nn.init.zeros_(self.output_proj.weight) + nn.init.zeros_(self.output_proj.bias) + + +class QwenImageBlockWiseControlNet(torch.nn.Module): + def __init__( + self, + num_layers: int = 60, + in_dim: int = 64, + dim: int = 3072, + ): + super().__init__() + self.img_in = nn.Linear(in_dim, dim) + self.controlnet_blocks = nn.ModuleList( + [ + BlockWiseControlBlock(dim) + for _ in range(num_layers) + ] + ) + + def init_weight(self): + nn.init.zeros_(self.img_in.weight) + nn.init.zeros_(self.img_in.bias) + for block in self.controlnet_blocks: + block.init_weights() + + def process_controlnet_conditioning(self, controlnet_conditioning): + return self.img_in(controlnet_conditioning) + + def blockwise_forward(self, img, controlnet_conditioning, block_id): + return self.controlnet_blocks[block_id](img, controlnet_conditioning) + + @staticmethod + def state_dict_converter(): + return QwenImageBlockWiseControlNetStateDictConverter() + + +class QwenImageBlockWiseControlNetStateDictConverter(): + def __init__(self): + pass + + def from_civitai(self, state_dict): + return state_dict diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py index 2289e1f..ccbb321 100644 --- a/diffsynth/pipelines/qwen_image.py +++ b/diffsynth/pipelines/qwen_image.py @@ -10,7 +10,7 @@ from ..models import ModelManager, load_state_dict from ..models.qwen_image_dit import QwenImageDiT from ..models.qwen_image_text_encoder import QwenImageTextEncoder from ..models.qwen_image_vae import QwenImageVAE -from ..models.qwen_image_controlnet import QwenImageControlNet +from ..models.qwen_image_controlnet import QwenImageControlNet, QwenImageBlockWiseControlNet from ..schedulers import FlowMatchScheduler from ..utils import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit from ..lora import GeneralLoRALoader @@ -69,7 +69,7 @@ class QwenImagePipeline(BasePipeline): self.controlnet: QwenImageMultiControlNet = None self.tokenizer: Qwen2Tokenizer = None self.unit_runner = PipelineUnitRunner() - self.in_iteration_models = ("dit", "controlnet") + self.in_iteration_models = ("dit", "controlnet", "blockwise_controlnet") self.units = [ QwenImageUnit_ShapeChecker(), QwenImageUnit_NoiseInitializer(), @@ -226,6 +226,7 @@ class QwenImagePipeline(BasePipeline): pipe.dit = model_manager.fetch_model("qwen_image_dit") pipe.vae = model_manager.fetch_model("qwen_image_vae") pipe.controlnet = QwenImageMultiControlNet(model_manager.fetch_model("qwen_image_controlnet", index="all")) + pipe.blockwise_controlnet = model_manager.fetch_model("qwen_image_blockwise_controlnet") if tokenizer_config is not None and pipe.text_encoder is not None: tokenizer_config.download_if_necessary() from transformers import Qwen2Tokenizer @@ -499,6 +500,7 @@ class QwenImageUnit_ControlNet(PipelineUnit): def process(self, pipe: QwenImagePipeline, controlnet_inputs: list[ControlNetInput], tiled, tile_size, tile_stride): if controlnet_inputs is None: return {} + return_key = "blockwise_controlnet_conditioning" if pipe.blockwise_controlnet is not None else "controlnet_conditionings" pipe.load_models_to_device(self.onload_model_names) conditionings = [] for controlnet_input in controlnet_inputs: @@ -512,12 +514,13 @@ class QwenImageUnit_ControlNet(PipelineUnit): if controlnet_input.inpaint_mask is not None: image = self.apply_controlnet_mask_on_latents(pipe, image, controlnet_input.inpaint_mask) conditionings.append(image) - return {"controlnet_conditionings": conditionings} + return {return_key: conditionings} def model_fn_qwen_image( dit: QwenImageDiT = None, controlnet: QwenImageMultiControlNet = None, + blockwise_controlnet: QwenImageBlockWiseControlNet = None, latents=None, timestep=None, prompt_emb=None, @@ -526,6 +529,7 @@ def model_fn_qwen_image( width=None, controlnet_inputs=None, controlnet_conditionings=None, + blockwise_controlnet_conditioning=None, progress_id=0, num_inference_steps=1, entity_prompt_emb=None, @@ -572,6 +576,13 @@ def model_fn_qwen_image( image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device) attention_mask = None + if blockwise_controlnet_conditioning is not None: + blockwise_controlnet_conditioning = rearrange( + blockwise_controlnet_conditioning[0], "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2 + ) + blockwise_controlnet_conditioning = blockwise_controlnet.process_controlnet_conditioning(blockwise_controlnet_conditioning) + + # blockwise_controlnet_conditioning = for block_id, block in enumerate(dit.transformer_blocks): text, image = gradient_checkpoint_forward( block, @@ -584,9 +595,11 @@ def model_fn_qwen_image( attention_mask=attention_mask, enable_fp8_attention=enable_fp8_attention, ) - if controlnet_inputs is not None: + if blockwise_controlnet is not None: + image = image + blockwise_controlnet.blockwise_forward(image, blockwise_controlnet_conditioning, block_id) + if controlnet_conditionings is not None: image = image + res_stack[block_id] - + image = dit.norm_out(image, conditioning) image = dit.proj_out(image) diff --git a/examples/qwen_image/model_training/full/Qwen-Image-BlockWiseControlNet.sh b/examples/qwen_image/model_training/full/Qwen-Image-BlockWiseControlNet.sh new file mode 100644 index 0000000..22a499e --- /dev/null +++ b/examples/qwen_image/model_training/full/Qwen-Image-BlockWiseControlNet.sh @@ -0,0 +1,36 @@ +accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config.yaml examples/qwen_image/model_training/train.py \ + --dataset_base_path "" \ + --dataset_metadata_path data/t2i_dataset_annotations/blip3o/blip3o_control_images_train_for_diffsynth.jsonl \ + --data_file_keys "image,controlnet_image" \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_paths '[ + [ + "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00001-of-00009.safetensors", + "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00002-of-00009.safetensors", + "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00003-of-00009.safetensors", + "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00004-of-00009.safetensors", + "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00005-of-00009.safetensors", + "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00006-of-00009.safetensors", + "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00007-of-00009.safetensors", + "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00008-of-00009.safetensors", + "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00009-of-00009.safetensors" + ], + [ + "models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors" + ], + "models/Qwen/Qwen-Image/vae/diffusion_pytorch_model.safetensors", + "models/DiffSynth-Studio/BlockWiseControlnet/model_init.safetensors" +]' \ + --learning_rate 1e-3 \ + --num_epochs 1000000 \ + --remove_prefix_in_ckpt "pipe.blockwise_controlnet." \ + --output_path "./models/train/Qwen-Image-BlockWiseControlNet_full_lr1e-3_wd1e-6" \ + --trainable_models "blockwise_controlnet" \ + --extra_inputs "controlnet_image" \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --save_steps 2000 diff --git a/examples/qwen_image/model_training/full/accelerate_config.yaml b/examples/qwen_image/model_training/full/accelerate_config.yaml new file mode 100644 index 0000000..83280f7 --- /dev/null +++ b/examples/qwen_image/model_training/full/accelerate_config.yaml @@ -0,0 +1,22 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 1 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/qwen_image/model_training/full/others/initialize_blockwise_controlnet.py b/examples/qwen_image/model_training/full/others/initialize_blockwise_controlnet.py new file mode 100644 index 0000000..9625a20 --- /dev/null +++ b/examples/qwen_image/model_training/full/others/initialize_blockwise_controlnet.py @@ -0,0 +1,13 @@ +# This script is for initializing a Qwen-Image-ControlNet +from diffsynth import load_state_dict, hash_state_dict_keys +from diffsynth.models.qwen_image_controlnet import QwenImageBlockWiseControlNet +import torch +from safetensors.torch import save_file + + +controlnet = QwenImageBlockWiseControlNet().to(dtype=torch.bfloat16, device="cuda") +controlnet.init_weight() +state_dict_controlnet = controlnet.state_dict() + +print(hash_state_dict_keys(state_dict_controlnet)) +save_file(state_dict_controlnet, "models/DiffSynth-Studio/BlockWiseControlnet/model_init.safetensors") diff --git a/examples/qwen_image/model_training/train.py b/examples/qwen_image/model_training/train.py index 04c9d29..61a9b78 100644 --- a/examples/qwen_image/model_training/train.py +++ b/examples/qwen_image/model_training/train.py @@ -118,7 +118,7 @@ if __name__ == "__main__": remove_prefix_in_ckpt=args.remove_prefix_in_ckpt, state_dict_converter=QwenImageLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x:x, ) - optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate) + optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=0.000001) scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer) launch_training_task( dataset, model, model_logger, optimizer, scheduler, diff --git a/examples/qwen_image/model_training/validate_full/Qwen-Image-BlockWise-Controlnet.py b/examples/qwen_image/model_training/validate_full/Qwen-Image-BlockWise-Controlnet.py new file mode 100644 index 0000000..72b1ae7 --- /dev/null +++ b/examples/qwen_image/model_training/validate_full/Qwen-Image-BlockWise-Controlnet.py @@ -0,0 +1,38 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput +from diffsynth import load_state_dict +import torch +from PIL import Image +from diffsynth.controlnets.processors import Annotator +import os + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ModelConfig(path="models/DiffSynth-Studio/BlockWiseControlnet/model_init.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) + +state_dict = load_state_dict("models/train/Qwen-Image-BlockWiseControlNet_full_lr1e-3_wd1e-6/step-26000.safetensors") +pipe.blockwise_controlnet.load_state_dict(state_dict) + +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = Image.open("test_image.jpg").convert("RGB").resize((1024, 1024)) +canny_image = Annotator("canny")(image) +canny_image.save("canny_image_test.jpg") + +controlnet_input = ControlNetInput( + image=canny_image, + scale=1.0, + processor_id="canny", +) + +for seed in range(100, 200): + image = pipe(prompt, seed=seed, height=1024, width=1024, controlnet_inputs=[controlnet_input], num_inference_steps=30, cfg_scale=4.0) + image.save(f"test_image_controlnet_step2k_1_{seed}.jpg")