diff --git a/diffsynth/trainers/text_to_image.py b/diffsynth/trainers/text_to_image.py index 6dc24e6..3177474 100644 --- a/diffsynth/trainers/text_to_image.py +++ b/diffsynth/trainers/text_to_image.py @@ -18,6 +18,7 @@ class LightningModelForT2ILoRA(pl.LightningModule): self.learning_rate = learning_rate self.use_gradient_checkpointing = use_gradient_checkpointing self.state_dict_converter = state_dict_converter + self.lora_alpha = None def load_models(self): @@ -34,6 +35,7 @@ class LightningModelForT2ILoRA(pl.LightningModule): def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian"): # Add LoRA to UNet + self.lora_alpha = lora_alpha if init_lora_weights == "kaiming": init_lora_weights = True @@ -94,7 +96,7 @@ class LightningModelForT2ILoRA(pl.LightningModule): if name in trainable_param_names: lora_state_dict[name] = param if self.state_dict_converter is not None: - lora_state_dict = self.state_dict_converter(lora_state_dict) + lora_state_dict = self.state_dict_converter(lora_state_dict, alpha=self.lora_alpha) checkpoint.update(lora_state_dict)