fix v6 lora (c03cdbbdaf)
parent acf5d02104
commit c5077f4ebc
finetune/lora/v6/train.py (vendored): 5 changed lines
@@ -314,8 +314,6 @@ if __name__ == "__main__":
 
     from src.model import RWKV, LORA_CONFIG, LoraLinear
 
-    model = RWKV(args)
-
     if args.lora:
         assert args.lora_r > 0, "LoRA should have its `r` > 0"
         LORA_CONFIG["r"] = args.lora_r
@@ -324,6 +322,9 @@ if __name__ == "__main__":
         LORA_CONFIG["parts"] = set(str(args.lora_parts).split(","))
         enable_time_finetune = "time" in LORA_CONFIG["parts"]
         enable_ln_finetune = "ln" in LORA_CONFIG["parts"]
+    model = RWKV(args)
+
+    if args.lora:
         model.requires_grad_(False)
         for name, module in model.named_modules():
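For context, a minimal sketch of why the construction order matters: the LoRA layers read the module-level LORA_CONFIG when they are built, so constructing RWKV(args) before the dict is filled in (the old order in the v6 script) produces a model with no LoRA parameters to train. The commit moves the model construction after the config block. The sketch below uses reduced stand-ins (TinyModel, a simplified LoraLinear and LORA_CONFIG) that are assumptions for illustration, not the repository's actual implementations.

import torch
import torch.nn as nn

# Simplified stand-in for src.model.LORA_CONFIG: a module-level dict that
# LoraLinear consults when it is constructed.
LORA_CONFIG = {"r": 0, "alpha": 0, "dropout": 0.0, "parts": {"att", "ffn"}}

class LoraLinear(nn.Module):
    # Reduced placeholder: captures LORA_CONFIG values at construction time.
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(out_features, in_features))
        self.r = LORA_CONFIG["r"]  # read once, here: this is the key point
        if self.r > 0:
            self.lora_A = nn.Parameter(torch.zeros(self.r, in_features))
            self.lora_B = nn.Parameter(torch.zeros(out_features, self.r))

    def forward(self, x):
        out = x @ self.weight.t()
        if self.r > 0:
            out = out + (x @ self.lora_A.t()) @ self.lora_B.t()
        return out

class TinyModel(nn.Module):
    # Reduced placeholder for RWKV(args): builds its projections immediately.
    def __init__(self):
        super().__init__()
        self.proj = LoraLinear(8, 8)

# Buggy order (before this commit): the model is built while LORA_CONFIG["r"] == 0,
# so no LoRA matrices exist and the later requires_grad_ pass has nothing to enable.
broken = TinyModel()
assert broken.proj.r == 0

# Fixed order (this commit): fill LORA_CONFIG first, then construct the model.
LORA_CONFIG["r"] = 8
LORA_CONFIG["alpha"] = 16
fixed = TinyModel()
assert fixed.proj.r == 8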