From 89c4e3bdb6e564eb35402673fdf2464fc93d5c3b Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Fri, 8 Aug 2025 18:55:13 +0800 Subject: [PATCH] lora-fix --- diffsynth/trainers/utils.py | 1 - examples/qwen_image/README.md | 2 -- examples/qwen_image/README_zh.md | 2 -- examples/qwen_image/model_training/train.py | 7 +------ 4 files changed, 1 insertion(+), 11 deletions(-) diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index 65e4e50..ac13b2e 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -531,7 +531,6 @@ def qwen_image_parser(): parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.") parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.") parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.") - parser.add_argument("--align_to_opensource_format", default=False, action="store_true", help="Whether to align the lora format to opensource format. Only for DiT's LoRA.") parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.") parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.") parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.") diff --git a/examples/qwen_image/README.md b/examples/qwen_image/README.md index aba8120..d487bb7 100644 --- a/examples/qwen_image/README.md +++ b/examples/qwen_image/README.md @@ -247,8 +247,6 @@ The script includes the following parameters: * `--use_gradient_checkpointing`: Whether to enable gradient checkpointing. * `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory. * `--gradient_accumulation_steps`: Number of gradient accumulation steps. -* Others - * `--align_to_opensource_format`: Whether to align DiT LoRA format with open-source version. Only works for LoRA training. In addition, the training framework is built on [`accelerate`](https://huggingface.co/docs/accelerate/index). Run `accelerate config` before training to set GPU-related settings. For some training tasks (e.g., full training of 20B model), we provide suggested `accelerate` config files. Check the corresponding training script for details. diff --git a/examples/qwen_image/README_zh.md b/examples/qwen_image/README_zh.md index 1259dc3..1cc208c 100644 --- a/examples/qwen_image/README_zh.md +++ b/examples/qwen_image/README_zh.md @@ -247,8 +247,6 @@ Qwen-Image 系列模型训练通过统一的 [`./model_training/train.py`](./mod * `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。 * `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。 * `--gradient_accumulation_steps`: 梯度累积步数。 -* 其他 - * `--align_to_opensource_format`: 是否将 DiT LoRA 的格式与开源版本对齐,仅对 LoRA 训练生效。 此外,训练框架基于 [`accelerate`](https://huggingface.co/docs/accelerate/index) 构建,在开始训练前运行 `accelerate config` 可配置 GPU 的相关参数。对于部分模型训练(例如 20B 模型的全量训练)脚本,我们提供了建议的 `accelerate` 配置文件,可在对应的训练脚本中查看。 diff --git a/examples/qwen_image/model_training/train.py b/examples/qwen_image/model_training/train.py index 4e8cf48..d8f6343 100644 --- a/examples/qwen_image/model_training/train.py +++ b/examples/qwen_image/model_training/train.py @@ -1,7 +1,6 @@ import torch, os, json from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, qwen_image_parser -from diffsynth.models.lora import QwenImageLoRAConverter os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -106,11 +105,7 @@ if __name__ == "__main__": use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, extra_inputs=args.extra_inputs, ) - model_logger = ModelLogger( - args.output_path, - remove_prefix_in_ckpt=args.remove_prefix_in_ckpt, - state_dict_converter=QwenImageLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x:x, - ) + model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt) optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate) scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer) launch_training_task(