lora_checkpoint & weight_decay & qwen_image_controlnet_train

This commit is contained in:
mi804
2025-08-14 13:50:04 +08:00
parent fa36739f01
commit 49f9a11eb3
10 changed files with 70 additions and 8 deletions

View File

@@ -237,6 +237,7 @@ The script includes the following parameters:
* `--tokenizer_path`: Tokenizer path. Leave empty to auto-download.
* Training
* `--learning_rate`: Learning rate.
* `--weight_decay`: Weight decay.
* `--num_epochs`: Number of epochs.
* `--output_path`: Save path.
* `--remove_prefix_in_ckpt`: Remove prefix in checkpoint.
@@ -247,6 +248,7 @@ The script includes the following parameters:
* `--lora_base_model`: Which model to add LoRA to.
* `--lora_target_modules`: Which layers to add LoRA to.
* `--lora_rank`: Rank of LoRA.
* `--lora_checkpoint`: Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.
* Extra Model Inputs
* `--extra_inputs`: Extra model inputs, separated by commas.
* VRAM Management

View File

@@ -237,6 +237,7 @@ Qwen-Image 系列模型训练通过统一的 [`./model_training/train.py`](./mod
* `--tokenizer_path`: tokenizer 路径,留空将会自动下载。
* 训练
* `--learning_rate`: 学习率。
* `--weight_decay`:权重衰减大小。
* `--num_epochs`: 轮数Epoch
* `--output_path`: 保存路径。
* `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。
@@ -247,6 +248,7 @@ Qwen-Image 系列模型训练通过统一的 [`./model_training/train.py`](./mod
* `--lora_base_model`: LoRA 添加到哪个模型上。
* `--lora_target_modules`: LoRA 添加到哪一层上。
* `--lora_rank`: LoRA 的秩Rank
* `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径LoRA 将从此检查点加载。
* 额外模型输入
* `--extra_inputs`: 额外的模型输入,以逗号分隔。
* 显存管理

View File

@@ -1,5 +1,7 @@
import torch, os, json
from diffsynth import load_state_dict
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from diffsynth.pipelines.flux_image_new import ControlNetInput
from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, qwen_image_parser
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -11,7 +13,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
model_paths=None, model_id_with_origin_paths=None,
tokenizer_path=None,
trainable_models=None,
lora_base_model=None, lora_target_modules="", lora_rank=32,
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
use_gradient_checkpointing=True,
use_gradient_checkpointing_offload=False,
extra_inputs=None,
@@ -43,6 +45,12 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
target_modules=lora_target_modules.split(","),
lora_rank=lora_rank
)
if lora_checkpoint is not None:
state_dict = load_state_dict(lora_checkpoint)
state_dict = self.mapping_lora_state_dict(state_dict)
load_result = model.load_state_dict(state_dict, strict=False)
if len(load_result[1]) > 0:
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
setattr(self.pipe, lora_base_model, model)
# Store other configs
@@ -72,8 +80,18 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
}
# Extra inputs
controlnet_input = {}
for extra_input in self.extra_inputs:
inputs_shared[extra_input] = data[extra_input]
if extra_input.startswith("blockwise_controlnet_"):
controlnet_input[extra_input.replace("blockwise_controlnet_", "")] = data[extra_input]
elif extra_input.startswith("controlnet_"):
controlnet_input[extra_input.replace("controlnet_", "")] = data[extra_input]
else:
inputs_shared[extra_input] = data[extra_input]
if len(controlnet_input) > 0:
controlnet_key = "blockwise_controlnet_inputs" if "blockwise_controlnet_image" in self.extra_inputs else "controlnet_inputs"
inputs_shared[controlnet_key] = [ControlNetInput(**controlnet_input)]
# Pipeline units will automatically process the input parameters.
for unit in self.pipe.units:
@@ -101,12 +119,13 @@ if __name__ == "__main__":
lora_base_model=args.lora_base_model,
lora_target_modules=args.lora_target_modules,
lora_rank=args.lora_rank,
lora_checkpoint=args.lora_checkpoint,
use_gradient_checkpointing=args.use_gradient_checkpointing,
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
extra_inputs=args.extra_inputs,
)
model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt)
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate)
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
launch_training_task(
dataset, model, model_logger, optimizer, scheduler,