This commit is contained in:
josc146 2024-03-14 12:25:09 +08:00
parent acf5d02104
commit c5077f4ebc

View File

@ -314,8 +314,6 @@ if __name__ == "__main__":
from src.model import RWKV, LORA_CONFIG, LoraLinear from src.model import RWKV, LORA_CONFIG, LoraLinear
model = RWKV(args)
if args.lora: if args.lora:
assert args.lora_r > 0, "LoRA should have its `r` > 0" assert args.lora_r > 0, "LoRA should have its `r` > 0"
LORA_CONFIG["r"] = args.lora_r LORA_CONFIG["r"] = args.lora_r
@ -324,6 +322,9 @@ if __name__ == "__main__":
LORA_CONFIG["parts"] = set(str(args.lora_parts).split(",")) LORA_CONFIG["parts"] = set(str(args.lora_parts).split(","))
enable_time_finetune = "time" in LORA_CONFIG["parts"] enable_time_finetune = "time" in LORA_CONFIG["parts"]
enable_ln_finetune = "ln" in LORA_CONFIG["parts"] enable_ln_finetune = "ln" in LORA_CONFIG["parts"]
model = RWKV(args)
if args.lora:
model.requires_grad_(False) model.requires_grad_(False)
for name, module in model.named_modules(): for name, module in model.named_modules():