support resume training

This commit is contained in:
Artiprocher
2024-12-16 11:08:14 +08:00
parent 7c0520d029
commit 8c2671ce40
7 changed files with 77 additions and 13 deletions

View File

@@ -9,7 +9,7 @@ class LightningModel(LightningModelForT2ILoRA):
self,
torch_dtype=torch.float16, pretrained_weights=[], preset_lora_path=None,
learning_rate=1e-4, use_gradient_checkpointing=True,
lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian",
lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian", pretrained_lora_path=None,
):
super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing)
# Load models
@@ -24,7 +24,14 @@ class LightningModel(LightningModelForT2ILoRA):
model_manager.load_lora(path)
self.freeze_parameters()
self.add_lora_to_model(self.pipe.denoising_model(), lora_rank=lora_rank, lora_alpha=lora_alpha, lora_target_modules=lora_target_modules, init_lora_weights=init_lora_weights)
self.add_lora_to_model(
self.pipe.denoising_model(),
lora_rank=lora_rank,
lora_alpha=lora_alpha,
lora_target_modules=lora_target_modules,
init_lora_weights=init_lora_weights,
pretrained_lora_path=pretrained_lora_path,
)
def parse_args():
@@ -70,6 +77,7 @@ if __name__ == '__main__':
lora_rank=args.lora_rank,
lora_alpha=args.lora_alpha,
init_lora_weights=args.init_lora_weights,
pretrained_lora_path=args.pretrained_lora_path,
lora_target_modules=args.lora_target_modules
)
launch_training_task(model, args)