From a403cb04f305e3504f02d268be964b17004d1153 Mon Sep 17 00:00:00 2001
From: Artiprocher
Date: Mon, 21 Oct 2024 14:03:58 +0800
Subject: [PATCH] support preset lora

---
 examples/train/flux/train_flux_lora.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/examples/train/flux/train_flux_lora.py b/examples/train/flux/train_flux_lora.py
index ead67d5..0bf118f 100644
--- a/examples/train/flux/train_flux_lora.py
+++ b/examples/train/flux/train_flux_lora.py
@@ -8,7 +8,7 @@ os.environ["TOKENIZERS_PARALLELISM"] = "True"
 class LightningModel(LightningModelForT2ILoRA):
     def __init__(
         self,
-        torch_dtype=torch.float16, pretrained_weights=[],
+        torch_dtype=torch.float16, pretrained_weights=[], preset_lora_path=None,
         learning_rate=1e-4, use_gradient_checkpointing=True,
         lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="kaiming",
         state_dict_converter=None, quantize = None
@@ -21,6 +21,8 @@ class LightningModel(LightningModelForT2ILoRA):
         else:
             model_manager.load_models(pretrained_weights[1:])
             model_manager.load_model(pretrained_weights[0], torch_dtype=quantize)
+        if preset_lora_path is not None:
+            model_manager.load_lora(preset_lora_path)
         self.pipe = FluxImagePipeline.from_model_manager(model_manager)
@@ -82,6 +84,12 @@ def parse_args():
         choices=["float8_e4m3fn"],
         help="Whether to use quantization when training the model, and in which format.",
     )
+    parser.add_argument(
+        "--preset_lora_path",
+        type=str,
+        default=None,
+        help="Preset LoRA path.",
+    )
     parser = add_general_parsers(parser)
     args = parser.parse_args()
     return args
@@ -92,6 +100,7 @@ if __name__ == '__main__':
     model = LightningModel(
         torch_dtype={"32": torch.float32, "bf16": torch.bfloat16}.get(args.precision, torch.float16),
         pretrained_weights=[args.pretrained_dit_path, args.pretrained_text_encoder_path, args.pretrained_text_encoder_2_path, args.pretrained_vae_path],
+        preset_lora_path=args.preset_lora_path,
         learning_rate=args.learning_rate,
         use_gradient_checkpointing=args.use_gradient_checkpointing,
         lora_rank=args.lora_rank,
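
Usage note (not part of the patch): with this change applied, a training run can be warm-started from an existing LoRA by passing the new --preset_lora_path flag; the checkpoint is loaded into the base weights via model_manager.load_lora(...) before the pipeline is built. A minimal invocation sketch follows, where every path and value is a placeholder and the non-LoRA flag names are inferred from the args.* references visible in the diff:

    python examples/train/flux/train_flux_lora.py \
        --pretrained_dit_path models/FLUX/flux1-dev.safetensors \
        --pretrained_text_encoder_path models/FLUX/text_encoder.safetensors \
        --pretrained_text_encoder_2_path models/FLUX/text_encoder_2.safetensors \
        --pretrained_vae_path models/FLUX/ae.safetensors \
        --preset_lora_path models/lora/existing_lora.safetensors \
        --lora_rank 4 --lora_alpha 4

Because preset_lora_path defaults to None, omitting the flag preserves the previous behavior of training a LoRA from scratch.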