This commit is contained in:
josc146 2024-03-14 12:25:09 +08:00
parent acf5d02104
commit c5077f4ebc

View File

@ -314,8 +314,6 @@ if __name__ == "__main__":
from src.model import RWKV, LORA_CONFIG, LoraLinear
model = RWKV(args)
if args.lora:
assert args.lora_r > 0, "LoRA should have its `r` > 0"
LORA_CONFIG["r"] = args.lora_r
@ -324,6 +322,9 @@ if __name__ == "__main__":
LORA_CONFIG["parts"] = set(str(args.lora_parts).split(","))
enable_time_finetune = "time" in LORA_CONFIG["parts"]
enable_ln_finetune = "ln" in LORA_CONFIG["parts"]
model = RWKV(args)
if args.lora:
model.requires_grad_(False)
for name, module in model.named_modules():