fix v6 lora (c03cdbbdaf
)
This commit is contained in:
parent
acf5d02104
commit
c5077f4ebc
5
finetune/lora/v6/train.py
vendored
5
finetune/lora/v6/train.py
vendored
@ -314,8 +314,6 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
from src.model import RWKV, LORA_CONFIG, LoraLinear
|
from src.model import RWKV, LORA_CONFIG, LoraLinear
|
||||||
|
|
||||||
model = RWKV(args)
|
|
||||||
|
|
||||||
if args.lora:
|
if args.lora:
|
||||||
assert args.lora_r > 0, "LoRA should have its `r` > 0"
|
assert args.lora_r > 0, "LoRA should have its `r` > 0"
|
||||||
LORA_CONFIG["r"] = args.lora_r
|
LORA_CONFIG["r"] = args.lora_r
|
||||||
@ -324,6 +322,9 @@ if __name__ == "__main__":
|
|||||||
LORA_CONFIG["parts"] = set(str(args.lora_parts).split(","))
|
LORA_CONFIG["parts"] = set(str(args.lora_parts).split(","))
|
||||||
enable_time_finetune = "time" in LORA_CONFIG["parts"]
|
enable_time_finetune = "time" in LORA_CONFIG["parts"]
|
||||||
enable_ln_finetune = "ln" in LORA_CONFIG["parts"]
|
enable_ln_finetune = "ln" in LORA_CONFIG["parts"]
|
||||||
|
model = RWKV(args)
|
||||||
|
|
||||||
|
if args.lora:
|
||||||
model.requires_grad_(False)
|
model.requires_grad_(False)
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user