From c5077f4ebc659292dc152af085282adb42ef204e Mon Sep 17 00:00:00 2001 From: josc146 Date: Thu, 14 Mar 2024 12:25:09 +0800 Subject: [PATCH] fix v6 lora (https://github.com/JL-er/RWKV-LORA/commit/c03cdbbdafa498a7d65da37bf54a4228eff79132) --- finetune/lora/v6/train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/finetune/lora/v6/train.py b/finetune/lora/v6/train.py index e0079ea..d41f37d 100644 --- a/finetune/lora/v6/train.py +++ b/finetune/lora/v6/train.py @@ -314,8 +314,6 @@ if __name__ == "__main__": from src.model import RWKV, LORA_CONFIG, LoraLinear - model = RWKV(args) - if args.lora: assert args.lora_r > 0, "LoRA should have its `r` > 0" LORA_CONFIG["r"] = args.lora_r @@ -324,6 +322,9 @@ if __name__ == "__main__": LORA_CONFIG["parts"] = set(str(args.lora_parts).split(",")) enable_time_finetune = "time" in LORA_CONFIG["parts"] enable_ln_finetune = "ln" in LORA_CONFIG["parts"] + model = RWKV(args) + + if args.lora: model.requires_grad_(False) for name, module in model.named_modules():