mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-21 08:08:13 +00:00
qwen-image controlnet
This commit is contained in:
@@ -0,0 +1,35 @@
|
||||
# Full fine-tuning of a Qwen-Image ControlNet (upscale control) with DiffSynth-Studio.
# The ControlNet branch is the only trainable module; base DiT / text encoder / VAE
# weights are loaded frozen from the sharded safetensors checkpoints below.
accelerate launch examples/qwen_image/model_training/train.py \
  --dataset_base_path data/example_image_dataset \
  --dataset_metadata_path data/example_image_dataset/metadata_controlnet_upscale.csv \
  --data_file_keys "image,controlnet_image" \
  --max_pixels 1048576 \
  --dataset_repeat 80000 \
  --model_paths '[
    [
      "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00001-of-00009.safetensors",
      "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00002-of-00009.safetensors",
      "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00003-of-00009.safetensors",
      "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00004-of-00009.safetensors",
      "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00005-of-00009.safetensors",
      "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00006-of-00009.safetensors",
      "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00007-of-00009.safetensors",
      "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00008-of-00009.safetensors",
      "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00009-of-00009.safetensors"
    ],
    [
      "models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
      "models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
      "models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
      "models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors"
    ],
    "models/Qwen/Qwen-Image/vae/diffusion_pytorch_model.safetensors",
    "models/controlnet.safetensors"
  ]' \
  --learning_rate 1e-5 \
  --num_epochs 1000000 \
  --remove_prefix_in_ckpt "pipe.controlnet.models.0." \
  --output_path "./models/train/Qwen-Image-ControlNet_full" \
  --trainable_models "controlnet" \
  --extra_inputs "controlnet_image" \
  --use_gradient_checkpointing \
  --save_steps 100
|
||||
@@ -0,0 +1,34 @@
|
||||
# Initialize a Qwen-Image-ControlNet checkpoint from the pretrained Qwen-Image DiT.
#
# Strategy:
#   * parameters whose name AND shape match the base transformer are copied verbatim;
#   * `img_in.weight` is widened by tiling the pretrained projection along its last
#     (input-channel) axis, since the ControlNet input is the latent concatenated
#     with the control latent — twice the channels of the base model's input;
#   * every remaining (ControlNet-specific) parameter is zero-initialized.
from diffsynth import load_state_dict, hash_state_dict_keys
from diffsynth.pipelines.qwen_image import QwenImageControlNet
import torch
from safetensors.torch import save_file


# Merge all nine transformer shards into a single bf16 state dict on the GPU.
state_dict_dit = {}
for shard_index in range(1, 10):
    shard_path = (
        "models/Qwen/Qwen-Image/transformer/"
        f"diffusion_pytorch_model-{shard_index:05d}-of-00009.safetensors"
    )
    state_dict_dit.update(
        load_state_dict(shard_path, torch_dtype=torch.bfloat16, device="cuda")
    )

controlnet = QwenImageControlNet().to(dtype=torch.bfloat16, device="cuda")
state_dict_controlnet = controlnet.state_dict()

state_dict_init = {}
for name, target in state_dict_controlnet.items():
    pretrained = state_dict_dit.get(name)
    if pretrained is not None and pretrained.shape == target.shape:
        # Shape-compatible parameter: reuse the pretrained weight directly.
        state_dict_init[name] = pretrained
    elif name == "img_in.weight":
        # Duplicate the input projection along the channel axis so the widened
        # ControlNet input layer starts from the pretrained mapping.
        state_dict_init[name] = torch.concat([pretrained, pretrained], dim=-1)
    else:
        # ControlNet-only parameter with no pretrained counterpart: start at zero.
        print("Zero Initialized:", name)
        state_dict_init[name] = torch.zeros_like(target)
controlnet.load_state_dict(state_dict_init)

# Hash of the key set, useful for registering/identifying this checkpoint layout.
print(hash_state_dict_keys(state_dict_init))
save_file(state_dict_init, "models/controlnet.safetensors")
|
||||
@@ -1,5 +1,5 @@
|
||||
import torch, os, json
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
|
||||
from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, qwen_image_parser
|
||||
from diffsynth.models.lora import QwenImageLoRAConverter
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
@@ -73,8 +73,15 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
|
||||
}
|
||||
|
||||
# Extra inputs
|
||||
controlnet_input = {}
|
||||
for extra_input in self.extra_inputs:
|
||||
inputs_shared[extra_input] = data[extra_input]
|
||||
if extra_input.startswith("controlnet_"):
|
||||
controlnet_input[extra_input.replace("controlnet_", "")] = data[extra_input]
|
||||
else:
|
||||
inputs_shared[extra_input] = data[extra_input]
|
||||
if len(controlnet_input) > 0:
|
||||
inputs_shared["controlnet_inputs"] = [ControlNetInput(**controlnet_input)]
|
||||
|
||||
# Pipeline units will automatically process the input parameters.
|
||||
for unit in self.pipe.units:
|
||||
|
||||
Reference in New Issue
Block a user