From 919d399fdba678a0c1c25877729106d5547f88ef Mon Sep 17 00:00:00 2001
From: Artiprocher
Date: Mon, 16 Dec 2024 12:25:05 +0800
Subject: [PATCH] support resume from opensource format

---
 diffsynth/models/lora.py               | 47 ++++++++++++++++++++++++++
 diffsynth/trainers/text_to_image.py    |  4 ++-
 examples/train/flux/train_flux_lora.py |  3 +-
 3 files changed, 52 insertions(+), 2 deletions(-)

diff --git a/diffsynth/models/lora.py b/diffsynth/models/lora.py
index eebc4a2..d416864 100644
--- a/diffsynth/models/lora.py
+++ b/diffsynth/models/lora.py
@@ -306,6 +306,53 @@ class FluxLoRAConverter:
             state_dict_[rename.replace("lora_up.weight", "alpha")] = torch.tensor((alpha,))[0]
         return state_dict_
 
+    @staticmethod
+    def align_to_diffsynth_format(state_dict):
+        rename_dict = {
+            "lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight",
+            "lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight",
+            "lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight",
+            "lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight",
+            "lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight",
+            "lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight",
+            "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight",
+            "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight",
+            "lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight",
+            "lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight",
+            "lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight",
+            "lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight",
+            "lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight",
+            "lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight",
+            "lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight",
+            "lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight",
+            "lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight",
+            "lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight",
+            "lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight",
+            "lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight",
+            "lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight",
+            "lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight",
+            "lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight",
+            "lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight",
+            "lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight",
+            "lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight",
+        }
+        def guess_block_id(name):
+            names = name.split("_")
+            for i in names:
+                if i.isdigit():
+                    return i, name.replace(f"_{i}_", "_blockid_")
+            return None, None
+        state_dict_ = {}
+        for name, param in state_dict.items():
+            block_id, source_name = guess_block_id(name)
+            if source_name in rename_dict:
+                target_name = rename_dict[source_name]
+                target_name = target_name.replace(".blockid.", f".{block_id}.")
+                state_dict_[target_name] = param
+            else:
+                state_dict_[name] = param
+        return state_dict_
+
 
 def get_lora_loaders():
     return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), GeneralLoRAFromPeft()]

diff --git a/diffsynth/trainers/text_to_image.py b/diffsynth/trainers/text_to_image.py
index 5e49c98..5f7c723 100644
--- a/diffsynth/trainers/text_to_image.py
+++ b/diffsynth/trainers/text_to_image.py
@@ -34,7 +34,7 @@ class LightningModelForT2ILoRA(pl.LightningModule):
         self.pipe.denoising_model().train()
 
 
-    def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian", pretrained_lora_path=None):
+    def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian", pretrained_lora_path=None, state_dict_converter=None):
         # Add LoRA to UNet
         self.lora_alpha = lora_alpha
         if init_lora_weights == "kaiming":
@@ -55,6 +55,8 @@ class LightningModelForT2ILoRA(pl.LightningModule):
         # Lora pretrained lora weights
         if pretrained_lora_path is not None:
             state_dict = load_state_dict(pretrained_lora_path)
+            if state_dict_converter is not None:
+                state_dict = state_dict_converter(state_dict)
             missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
             all_keys = [i for i, _ in model.named_parameters()]
             num_updated_keys = len(all_keys) - len(missing_keys)

diff --git a/examples/train/flux/train_flux_lora.py b/examples/train/flux/train_flux_lora.py
index 681d6ba..eb5539a 100644
--- a/examples/train/flux/train_flux_lora.py
+++ b/examples/train/flux/train_flux_lora.py
@@ -40,7 +40,8 @@ class LightningModel(LightningModelForT2ILoRA):
             lora_alpha=lora_alpha,
             lora_target_modules=lora_target_modules,
             init_lora_weights=init_lora_weights,
-            pretrained_lora_path=pretrained_lora_path
+            pretrained_lora_path=pretrained_lora_path,
+            state_dict_converter=FluxLoRAConverter.align_to_diffsynth_format
         )