RWKV-Runner/finetune/lora/v6/merge/merge.py
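
"""Merge a trained LoRA/PiSSA checkpoint into an RWKV v6 base model.

Loads the base-model state dict, overlays the slim checkpoint that contains only
the trained adapter tensors, folds each low-rank pair back into its target
".weight" matrix, and saves a standalone merged model.
"""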

from collections import OrderedDict
import os
import sys
from typing import Dict
import typing
import torch
import bitsandbytes as bnb
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--type", default="pissa", type=str)
parser.add_argument("--base_model", default="", type=str)
parser.add_argument("--lora_init", default="none", type=str)
parser.add_argument("--lora_checkpoint", default="", type=str)
parser.add_argument("--output", default="", type=str)
parser.add_argument("--quant", default="none", type=str)
parser.add_argument("--device", default="cuda", type=str)
parser.add_argument("--lora_alpha", default=16, type=int)
args = parser.parse_args()
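
# Example invocation (file names are placeholders, not files shipped with the repo):
#   python merge.py --type pissa --base_model base.pth --lora_init init_pissa.pth \
#       --lora_checkpoint rwkv-final.pth --output merged.pth --quant none \
#       --device cuda --lora_alpha 16
# Any --type other than 'pissa' is treated as a plain LoRA merge and --lora_init is ignored.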

device = args.device
base_model = args.base_model
init_lora = args.lora_init
lora = args.lora_checkpoint
output = args.output
quant = args.quant
lora_alpha = args.lora_alpha

with torch.no_grad():
    w: Dict[str, torch.Tensor] = torch.load(base_model, map_location='cpu')
    # merge LoRA-only slim checkpoint into the main weights
    w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu')
    if args.type == 'pissa':
        w_init_lora: Dict[str, torch.Tensor] = torch.load(init_lora, map_location='cpu')
    for k in w_lora.keys():
        w[k] = w_lora[k]
    output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict()
    # merge LoRA weights
    keys = list(w.keys())
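    # For every "<prefix>.weight" tensor that has companion "<prefix>.lora_A" /
    # "<prefix>.lora_B" tensors, fold the low-rank update into the dense weight;
    # all other non-LoRA tensors are copied through unchanged.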
    for k in keys:
        if k.endswith('.weight'):
            prefix = k[:-len('.weight')]
            lora_A = prefix + '.lora_A'
            lora_B = prefix + '.lora_B'
            init_lora_A = prefix + '.init_lora_A'
            init_lora_B = prefix + '.init_lora_B'
            if lora_A in keys:
                assert lora_B in keys
                print(f'merging {lora_A} and {lora_B} into {k}')
                assert w[lora_B].shape[1] == w[lora_A].shape[0]
                lora_r = w[lora_B].shape[1]
                w[k] = w[k].to(device=device)
                w[lora_A] = w[lora_A].to(device=device)
                w[lora_B] = w[lora_B].to(device=device)
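                # PiSSA merge: the trained factors replace the initial SVD factors, so
                # subtract init_lora_B @ init_lora_A from the base weight first and then
                # add lora_B @ lora_A without any alpha/r scaling. If the adapter was
                # trained against a quantized base, the residual is round-tripped through
                # the matching bitsandbytes quantizer so the merge sees the same values
                # training did.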
                if args.type == 'pissa':
                    w_init_lora[init_lora_A] = w_init_lora[init_lora_A].to(device=device)
                    w_init_lora[init_lora_B] = w_init_lora[init_lora_B].to(device=device)
                    if quant == '4bit':
                        qw, qs = bnb.functional.quantize_4bit(w[k] - w_init_lora[init_lora_B] @ w_init_lora[init_lora_A])
                        w[k] = bnb.functional.dequantize_4bit(qw, quant_state=qs).to(dtype=torch.bfloat16)
                    elif quant == 'nf4':
                        qw, qs = bnb.functional.quantize_nf4(w[k] - w_init_lora[init_lora_B] @ w_init_lora[init_lora_A])
                        w[k] = bnb.functional.dequantize_nf4(qw, quant_state=qs).to(dtype=torch.bfloat16)
                    elif quant == 'fp4':
                        qw, qs = bnb.functional.quantize_fp4(w[k] - w_init_lora[init_lora_B] @ w_init_lora[init_lora_A])
                        w[k] = bnb.functional.dequantize_fp4(qw, quant_state=qs).to(dtype=torch.bfloat16)
                    else:
                        w[k] = (w[k] - w_init_lora[init_lora_B] @ w_init_lora[init_lora_A]).to(dtype=torch.bfloat16)
                    w[k] += w[lora_B] @ w[lora_A]
                else:
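                    # Plain LoRA merge: optionally round-trip the frozen base weight
                    # through the training-time quantizer, then add the scaled update
                    # (lora_alpha / lora_r) * lora_B @ lora_A.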
                    if quant == '4bit':
                        qw, qs = bnb.functional.quantize_4bit(w[k])
                        w[k] = bnb.functional.dequantize_4bit(qw, quant_state=qs).to(dtype=torch.bfloat16)
                    elif quant == 'nf4':
                        qw, qs = bnb.functional.quantize_nf4(w[k])
                        w[k] = bnb.functional.dequantize_nf4(qw, quant_state=qs).to(dtype=torch.bfloat16)
                    elif quant == 'fp4':
                        qw, qs = bnb.functional.quantize_fp4(w[k])
                        w[k] = bnb.functional.dequantize_fp4(qw, quant_state=qs).to(dtype=torch.bfloat16)
                    w[k] += w[lora_B] @ w[lora_A] * (lora_alpha / lora_r)
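                # Copy the merged tensor back to CPU for saving and drop the working
                # copies to keep peak memory bounded.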
                output_w[k] = w[k].to(device='cpu', copy=True)
                del w[k]
                del w[lora_A]
                del w[lora_B]
                continue
        if 'lora' not in k:
            print(f'retaining {k}')
            output_w[k] = w[k].clone()
            del w[k]

    torch.save(output_w, output)