From 540c0369885d346abc255c383243c40c54df5fb7 Mon Sep 17 00:00:00 2001
From: Artiprocher
Date: Mon, 14 Oct 2024 18:57:54 +0800
Subject: [PATCH] add alpha to lora converter

---
 diffsynth/trainers/text_to_image.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/diffsynth/trainers/text_to_image.py b/diffsynth/trainers/text_to_image.py
index 6dc24e6..3177474 100644
--- a/diffsynth/trainers/text_to_image.py
+++ b/diffsynth/trainers/text_to_image.py
@@ -18,6 +18,7 @@ class LightningModelForT2ILoRA(pl.LightningModule):
         self.learning_rate = learning_rate
         self.use_gradient_checkpointing = use_gradient_checkpointing
         self.state_dict_converter = state_dict_converter
+        self.lora_alpha = None
 
 
     def load_models(self):
@@ -34,6 +35,7 @@ class LightningModelForT2ILoRA(pl.LightningModule):
 
     def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian"):
         # Add LoRA to UNet
+        self.lora_alpha = lora_alpha
        if init_lora_weights == "kaiming":
             init_lora_weights = True
 
@@ -94,7 +96,7 @@ class LightningModelForT2ILoRA(pl.LightningModule):
             if name in trainable_param_names:
                 lora_state_dict[name] = param
         if self.state_dict_converter is not None:
-            lora_state_dict = self.state_dict_converter(lora_state_dict)
+            lora_state_dict = self.state_dict_converter(lora_state_dict, alpha=self.lora_alpha)
         checkpoint.update(lora_state_dict)
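
A note on the change: a LoRA layer scales its update by alpha / rank, so a
converter that exports the trained weights for other loaders needs alpha to
preserve the effective scale. For reference, the sketch below shows the
minimal converter shape the updated call site expects. It is an assumption,
not the project's actual converter: example_lora_converter is hypothetical,
and the kohya-style ".alpha" entries are one common convention for carrying
lora_alpha in a saved state dict.

    # Hypothetical sketch, not diffsynth's real converter: it accepts the new
    # `alpha` keyword and stores one alpha entry per LoRA module.
    import torch

    def example_lora_converter(lora_state_dict, alpha=None):
        converted = {}
        for name, param in lora_state_dict.items():
            # A real converter would also rename keys to the target format;
            # here the tensors pass through unchanged.
            converted[name] = param
            if alpha is not None and ".lora_A." in name:
                # Key the alpha entry off the module that owns the A matrix,
                # one entry per LoRA pair.
                module_key = name.split(".lora_A.")[0]
                converted[module_key + ".alpha"] = torch.tensor(float(alpha))
        return converted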