From 6f4e38276e2de7b4099c7afa57ebf52a4f262ff1 Mon Sep 17 00:00:00 2001
From: Artiprocher
Date: Wed, 6 Aug 2025 15:41:22 +0800
Subject: [PATCH] remove default in qwen-image lora

---
 diffsynth/models/lora.py                    | 15 +++++++++++++++
 examples/qwen_image/model_training/train.py |  2 ++
 2 files changed, 17 insertions(+)

diff --git a/diffsynth/models/lora.py b/diffsynth/models/lora.py
index 11b34e3..0278bb1 100644
--- a/diffsynth/models/lora.py
+++ b/diffsynth/models/lora.py
@@ -383,5 +383,20 @@ class WanLoRAConverter:
         return state_dict
 
 
+class QwenImageLoRAConverter:
+    def __init__(self):
+        pass
+
+    @staticmethod
+    def align_to_opensource_format(state_dict, **kwargs):
+        state_dict = {name.replace(".default.", "."): param for name, param in state_dict.items()}
+        return state_dict
+
+    @staticmethod
+    def align_to_diffsynth_format(state_dict, **kwargs):
+        state_dict = {name.replace(".lora_A.weight", ".lora_A.default.weight").replace(".lora_B.weight", ".lora_B.default.weight"): param for name, param in state_dict.items()}
+        return state_dict
+
+
 def get_lora_loaders():
     return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()]
diff --git a/examples/qwen_image/model_training/train.py b/examples/qwen_image/model_training/train.py
index 48d2d1a..17b234a 100644
--- a/examples/qwen_image/model_training/train.py
+++ b/examples/qwen_image/model_training/train.py
@@ -1,6 +1,7 @@
 import torch, os, json
 from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
 from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, qwen_image_parser
+from diffsynth.models.lora import QwenImageLoRAConverter
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 
@@ -108,6 +109,7 @@ if __name__ == "__main__":
     model_logger = ModelLogger(
         args.output_path,
         remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
+        state_dict_converter=QwenImageLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x:x,
     )
     optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate)
     scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)