From a8cb4a21d151ec46b20e8f5c799d4a2702d12fc1 Mon Sep 17 00:00:00 2001 From: Zhongjie Duan <35051019+Artiprocher@users.noreply.github.com> Date: Thu, 12 Sep 2024 16:01:27 +0800 Subject: [PATCH] align flux lora format (#204) --- diffsynth/models/lora.py | 47 ++++++++++++++++++++++++++ diffsynth/trainers/text_to_image.py | 8 ++++- examples/train/README.md | 7 ++-- examples/train/flux/train_flux_lora.py | 15 ++++++-- 4 files changed, 71 insertions(+), 6 deletions(-) diff --git a/diffsynth/models/lora.py b/diffsynth/models/lora.py index 52a82fa..bae08f5 100644 --- a/diffsynth/models/lora.py +++ b/diffsynth/models/lora.py @@ -248,5 +248,52 @@ class GeneralLoRAFromPeft: return None +class FluxLoRAConverter: + def __init__(self): + pass + + def align_to_opensource_format(self, state_dict, alpha=1.0): + prefix_rename_dict = { + "single_blocks": "lora_unet_single_blocks", + "blocks": "lora_unet_double_blocks", + } + middle_rename_dict = { + "norm.linear": "modulation_lin", + "to_qkv_mlp": "linear1", + "proj_out": "linear2", + + "norm1_a.linear": "img_mod_lin", + "norm1_b.linear": "txt_mod_lin", + "attn.a_to_qkv": "img_attn_qkv", + "attn.b_to_qkv": "txt_attn_qkv", + "attn.a_to_out": "img_attn_proj", + "attn.b_to_out": "txt_attn_proj", + "ff_a.0": "img_mlp_0", + "ff_a.2": "img_mlp_2", + "ff_b.0": "txt_mlp_0", + "ff_b.2": "txt_mlp_2", + } + suffix_rename_dict = { + "lora_B.weight": "lora_up.weight", + "lora_A.weight": "lora_down.weight", + } + state_dict_ = {} + for name, param in state_dict.items(): + names = name.split(".") + if names[-2] != "lora_A" and names[-2] != "lora_B": + names.pop(-2) + prefix = names[0] + middle = ".".join(names[2:-2]) + suffix = ".".join(names[-2:]) + block_id = names[1] + if middle not in middle_rename_dict: + continue + rename = prefix_rename_dict[prefix] + "_" + block_id + "_" + middle_rename_dict[middle] + "." + suffix_rename_dict[suffix] + state_dict_[rename] = param + if rename.endswith("lora_up.weight"): + state_dict_[rename.replace("lora_up.weight", "alpha")] = torch.tensor((alpha,))[0] + return state_dict_ + + def get_lora_loaders(): return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), GeneralLoRAFromPeft(), FluxLoRAFromCivitai()] diff --git a/diffsynth/trainers/text_to_image.py b/diffsynth/trainers/text_to_image.py index d8fbd94..50132a5 100644 --- a/diffsynth/trainers/text_to_image.py +++ b/diffsynth/trainers/text_to_image.py @@ -11,11 +11,13 @@ class LightningModelForT2ILoRA(pl.LightningModule): self, learning_rate=1e-4, use_gradient_checkpointing=True, + state_dict_converter=None, ): super().__init__() # Set parameters self.learning_rate = learning_rate self.use_gradient_checkpointing = use_gradient_checkpointing + self.state_dict_converter = state_dict_converter def load_models(self): @@ -83,9 +85,13 @@ class LightningModelForT2ILoRA(pl.LightningModule): trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.denoising_model().named_parameters())) trainable_param_names = set([named_param[0] for named_param in trainable_param_names]) state_dict = self.pipe.denoising_model().state_dict() + lora_state_dict = {} for name, param in state_dict.items(): if name in trainable_param_names: - checkpoint[name] = param + lora_state_dict[name] = param + if self.state_dict_converter is not None: + lora_state_dict = self.state_dict_converter(lora_state_dict) + checkpoint.update(lora_state_dict) diff --git a/examples/train/README.md b/examples/train/README.md index bf3f777..50c2c1a 100644 --- a/examples/train/README.md +++ b/examples/train/README.md @@ -142,9 +142,12 @@ CUDA_VISIBLE_DEVICES="0" python examples/train/flux/train_flux_lora.py \ --learning_rate 1e-4 \ --lora_rank 4 \ --lora_alpha 4 \ - --use_gradient_checkpointing + --use_gradient_checkpointing \ + --align_to_opensource_format ``` +**`--align_to_opensource_format` means that this script will export the LoRA weights in the opensource format. This format can be loaded in both DiffSynth-Studio and other codebases.** + For more information about the parameters, please use `python examples/train/flux/train_flux_lora.py -h` to see the details. After training, use `model_manager.load_lora` to load the LoRA for inference. @@ -165,7 +168,7 @@ pipe = SDXLImagePipeline.from_model_manager(model_manager) torch.manual_seed(0) image = pipe( - prompt=prompt, + prompt="a dog is jumping, flowers around the dog, the background is mountains and clouds", num_inference_steps=30, embedded_guidance=3.5 ) image.save("image_with_lora.jpg") diff --git a/examples/train/flux/train_flux_lora.py b/examples/train/flux/train_flux_lora.py index 3c52352..65496ec 100644 --- a/examples/train/flux/train_flux_lora.py +++ b/examples/train/flux/train_flux_lora.py @@ -1,5 +1,6 @@ from diffsynth import ModelManager, FluxImagePipeline from diffsynth.trainers.text_to_image import LightningModelForT2ILoRA, add_general_parsers, launch_training_task +from diffsynth.models.lora import FluxLoRAConverter import torch, os, argparse os.environ["TOKENIZERS_PARALLELISM"] = "True" @@ -9,9 +10,10 @@ class LightningModel(LightningModelForT2ILoRA): self, torch_dtype=torch.float16, pretrained_weights=[], learning_rate=1e-4, use_gradient_checkpointing=True, - lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out" + lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", + state_dict_converter=None, ): - super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing) + super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing, state_dict_converter=state_dict_converter) # Load models model_manager = ModelManager(torch_dtype=torch_dtype, device=self.device) model_manager.load_models(pretrained_weights) @@ -58,6 +60,12 @@ def parse_args(): default="a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp", help="Layers with LoRA modules.", ) + parser.add_argument( + "--align_to_opensource_format", + default=False, + action="store_true", + help="Whether to export lora files aligned with other opensource format.", + ) parser = add_general_parsers(parser) args = parser.parse_args() return args @@ -72,6 +80,7 @@ if __name__ == '__main__': use_gradient_checkpointing=args.use_gradient_checkpointing, lora_rank=args.lora_rank, lora_alpha=args.lora_alpha, - lora_target_modules=args.lora_target_modules + lora_target_modules=args.lora_target_modules, + state_dict_converter=FluxLoRAConverter().align_to_opensource_format if args.align_to_opensource_format else None, ) launch_training_task(model, args)