support resume from opensource format

This commit is contained in:
Artiprocher
2024-12-16 12:25:05 +08:00
parent 8c2671ce40
commit 919d399fdb
3 changed files with 52 additions and 2 deletions

View File

@@ -34,7 +34,7 @@ class LightningModelForT2ILoRA(pl.LightningModule):
self.pipe.denoising_model().train()
def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian", pretrained_lora_path=None):
def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian", pretrained_lora_path=None, state_dict_converter=None):
# Add LoRA to UNet
self.lora_alpha = lora_alpha
if init_lora_weights == "kaiming":
@@ -55,6 +55,8 @@ class LightningModelForT2ILoRA(pl.LightningModule):
# Lora pretrained lora weights
if pretrained_lora_path is not None:
state_dict = load_state_dict(pretrained_lora_path)
if state_dict_converter is not None:
state_dict = state_dict_converter(state_dict)
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
all_keys = [i for i, _ in model.named_parameters()]
num_updated_keys = len(all_keys) - len(missing_keys)