finetune/lora/v6/merge/merge.py (vendored, new file, 92 lines)
@@ -0,0 +1,92 @@
from collections import OrderedDict
import os
import sys
from typing import Dict
import typing
import torch
import bitsandbytes as bnb
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--type", default="pissa", type=str)
parser.add_argument("--base_model", default="", type=str)
parser.add_argument("--lora_init", default="none", type=str)
parser.add_argument("--lora_checkpoint", default="", type=str)
parser.add_argument("--output", default="", type=str)
parser.add_argument("--quant", default="none", type=str)
parser.add_argument("--device", default="cuda", type=str)
parser.add_argument("--lora_alpha", default=16, type=int)
args = parser.parse_args()
device = args.device
base_model = args.base_model
init_lora = args.lora_init
lora = args.lora_checkpoint
output = args.output
quant = args.quant
lora_alpha = args.lora_alpha

with torch.no_grad():
    w: Dict[str, torch.Tensor] = torch.load(base_model, map_location='cpu')
    # merge LoRA-only slim checkpoint into the main weights
    w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu')

    if args.type == 'pissa':
        w_init_lora: Dict[str, torch.Tensor] = torch.load(init_lora, map_location='cpu')
    for k in w_lora.keys():
        w[k] = w_lora[k]
    output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict()
    # merge LoRA weights
    keys = list(w.keys())
    for k in keys:
        if k.endswith('.weight'):
            prefix = k[:-len('.weight')]
            lora_A = prefix + '.lora_A'
            lora_B = prefix + '.lora_B'
            init_lora_A = prefix + '.init_lora_A'
            init_lora_B = prefix + '.init_lora_B'
            if lora_A in keys:
                assert lora_B in keys
                print(f'merging {lora_A} and {lora_B} into {k}')
                assert w[lora_B].shape[1] == w[lora_A].shape[0]
                lora_r = w[lora_B].shape[1]
                w[k] = w[k].to(device=device)
                w[lora_A] = w[lora_A].to(device=device)
                w[lora_B] = w[lora_B].to(device=device)

                if args.type == 'pissa':
                    w_init_lora[init_lora_A] = w_init_lora[init_lora_A].to(device=device)
                    w_init_lora[init_lora_B] = w_init_lora[init_lora_B].to(device=device)
                    if quant == '4bit':
                        qw, qs = bnb.functional.quantize_4bit(w[k] - w_init_lora[init_lora_B] @ w_init_lora[init_lora_A])
                        w[k] = (bnb.functional.dequantize_4bit(qw, quant_state=qs)).to(dtype=torch.bfloat16)
                    elif quant == 'nf4':
                        qw, qs = bnb.functional.quantize_nf4(w[k] - w_init_lora[init_lora_B] @ w_init_lora[init_lora_A])
                        w[k] = (bnb.functional.dequantize_nf4(qw, quant_state=qs)).to(dtype=torch.bfloat16)
                    elif quant == 'fp4':
                        qw, qs = bnb.functional.quantize_fp4(w[k] - w_init_lora[init_lora_B] @ w_init_lora[init_lora_A])
                        w[k] = (bnb.functional.dequantize_fp4(qw, quant_state=qs)).to(dtype=torch.bfloat16)
                    else:
                        w[k] = (w[k] - w_init_lora[init_lora_B] @ w_init_lora[init_lora_A]).to(dtype=torch.bfloat16)
                    w[k] += w[lora_B] @ w[lora_A]
                else:
                    if quant == '4bit':
                        qw, qs = bnb.functional.quantize_4bit(w[k])
                        w[k] = (bnb.functional.dequantize_4bit(qw, quant_state=qs)).to(dtype=torch.bfloat16)
                    elif quant == 'nf4':
                        qw, qs = bnb.functional.quantize_nf4(w[k])
                        w[k] = (bnb.functional.dequantize_nf4(qw, quant_state=qs)).to(dtype=torch.bfloat16)
                    elif quant == 'fp4':
                        qw, qs = bnb.functional.quantize_fp4(w[k])
                        w[k] = (bnb.functional.dequantize_fp4(qw, quant_state=qs)).to(dtype=torch.bfloat16)
                    w[k] += w[lora_B] @ w[lora_A] * (lora_alpha / lora_r)
                output_w[k] = w[k].to(device='cpu', copy=True)
                del w[k]
                del w[lora_A]
                del w[lora_B]
                continue

        if 'lora' not in k:
            print(f'retaining {k}')
            output_w[k] = w[k].clone()
            del w[k]
    torch.save(output_w, output)
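
For orientation, the arithmetic above reduces to two closed-form merges: the non-PiSSA branch computes W + (lora_alpha / lora_r) * B @ A, while the PiSSA branch first removes the initial factorization B0 @ A0 and then adds the trained B @ A with no scaling. A minimal sketch with toy shapes, for illustration only (all names below are placeholders, not part of the commit):

import torch

out_dim, in_dim, r, alpha = 8, 6, 2, 16
W  = torch.randn(out_dim, in_dim)   # base '.weight' tensor
A  = torch.randn(r, in_dim)         # trained '.lora_A'
B  = torch.randn(out_dim, r)        # trained '.lora_B'
A0 = torch.randn(r, in_dim)         # '.init_lora_A' (PiSSA initialization)
B0 = torch.randn(out_dim, r)        # '.init_lora_B'

lora_merged  = W + B @ A * (alpha / r)   # non-PiSSA branch, same formula as merge_lora.py
pissa_merged = (W - B0 @ A0) + B @ A     # PiSSA branch, same formula as merge_pissa.py
print(lora_merged.shape, pissa_merged.shape)  # both keep the original (out_dim, in_dim) shape
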

finetune/lora/v6/merge/merge_lora.py (vendored, new file, 52 lines)
@@ -0,0 +1,52 @@
from collections import OrderedDict
import os
import sys
from typing import Dict
import typing
import torch

if '-h' in sys.argv or '--help' in sys.argv:
    print(f'Usage: python3 {sys.argv[0]} [--use-gpu] <lora_alpha> <base_model.pth> <lora_checkpoint.pth> <output.pth>')

if sys.argv[1] == '--use-gpu':
    device = 'cuda'
    lora_alpha, base_model, lora, output = float(sys.argv[2]), sys.argv[3], sys.argv[4], sys.argv[5]
else:
    device = 'cpu'
    lora_alpha, base_model, lora, output = float(sys.argv[1]), sys.argv[2], sys.argv[3], sys.argv[4]


with torch.no_grad():
    w: Dict[str, torch.Tensor] = torch.load(base_model, map_location='cpu')
    # merge LoRA-only slim checkpoint into the main weights
    w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu')
    for k in w_lora.keys():
        w[k] = w_lora[k]
    output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict()
    # merge LoRA weights
    keys = list(w.keys())
    for k in keys:
        if k.endswith('.weight'):
            prefix = k[:-len('.weight')]
            lora_A = prefix + '.lora_A'
            lora_B = prefix + '.lora_B'
            if lora_A in keys:
                assert lora_B in keys
                print(f'merging {lora_A} and {lora_B} into {k}')
                assert w[lora_B].shape[1] == w[lora_A].shape[0]
                lora_r = w[lora_B].shape[1]
                w[k] = w[k].to(device=device)
                w[lora_A] = w[lora_A].to(device=device)
                w[lora_B] = w[lora_B].to(device=device)
                w[k] += w[lora_B] @ w[lora_A] * (lora_alpha / lora_r)
                output_w[k] = w[k].to(device='cpu', copy=True)
                del w[k]
                del w[lora_A]
                del w[lora_B]
                continue

        if 'lora' not in k:
            print(f'retaining {k}')
            output_w[k] = w[k].clone()
            del w[k]
    torch.save(output_w, output)
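
Either way, the saved checkpoint should contain only plain model tensors; every lora_A/lora_B factor is consumed by the merge and dropped. A quick post-merge sanity check one might run (the path is a placeholder for the script's output file):

import torch

merged = torch.load('merged.pth', map_location='cpu')  # placeholder path for <output.pth>
assert not any('lora' in k for k in merged), 'LoRA factors should not survive the merge'
print(f'{len(merged)} tensors retained, e.g. {next(iter(merged))}')
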

finetune/lora/v6/merge/merge_pissa.py (vendored, new file, 58 lines)
@@ -0,0 +1,58 @@
from collections import OrderedDict
import os
import sys
from typing import Dict
import typing
import torch

if '-h' in sys.argv or '--help' in sys.argv:
    print(f'Usage: python3 {sys.argv[0]} [--use-gpu] <base_model.pth> <lora_init.pth> <lora_checkpoint.pth> <output.pth>')

if sys.argv[1] == '--use-gpu':
    device = 'cuda'
    base_model, init_lora, lora, output = sys.argv[2], sys.argv[3], sys.argv[4], sys.argv[5]
else:
    device = 'cpu'
    base_model, init_lora, lora, output = sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4]


with torch.no_grad():
    w: Dict[str, torch.Tensor] = torch.load(base_model, map_location='cpu')
    # merge LoRA-only slim checkpoint into the main weights
    w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu')
    w_init_lora: Dict[str, torch.Tensor] = torch.load(init_lora, map_location='cpu')
    for k in w_lora.keys():
        w[k] = w_lora[k]
    output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict()
    # merge LoRA weights
    keys = list(w.keys())
    for k in keys:
        if k.endswith('.weight'):
            prefix = k[:-len('.weight')]
            lora_A = prefix + '.lora_A'
            lora_B = prefix + '.lora_B'
            init_lora_A = prefix + '.init_lora_A'
            init_lora_B = prefix + '.init_lora_B'
            if lora_A in keys:
                assert lora_B in keys
                print(f'merging {lora_A} and {lora_B} into {k}')
                assert w[lora_B].shape[1] == w[lora_A].shape[0]
                lora_r = w[lora_B].shape[1]
                w[k] = w[k].to(device=device)
                w[lora_A] = w[lora_A].to(device=device)
                w[lora_B] = w[lora_B].to(device=device)
                w_init_lora[init_lora_A] = w_init_lora[init_lora_A].to(device=device)
                w_init_lora[init_lora_B] = w_init_lora[init_lora_B].to(device=device)
                w[k] = (w[k] - w_init_lora[init_lora_B] @ w_init_lora[init_lora_A]).to(dtype=torch.bfloat16)
                w[k] += w[lora_B] @ w[lora_A]
                output_w[k] = w[k].to(device='cpu', copy=True)
                del w[k]
                del w[lora_A]
                del w[lora_B]
                continue

        if 'lora' not in k:
            print(f'retaining {k}')
            output_w[k] = w[k].clone()
            del w[k]
    torch.save(output_w, output)
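
One property of the PiSSA formula worth noting: if the trained factors are still identical to the initial ones (no training has happened yet), (W - B0 @ A0) + B @ A returns the base weight unchanged, up to the bfloat16 cast applied above. A small check of that identity, illustration only:

import torch

W  = torch.randn(8, 6)
A0 = torch.randn(2, 6)
B0 = torch.randn(8, 2)
B, A = B0.clone(), A0.clone()       # "untrained" factors equal the initialization

merged = (W - B0 @ A0).to(torch.bfloat16) + B @ A
print(torch.allclose(merged, W, atol=0.1))  # True: only bfloat16 rounding error remains
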

finetune/lora/v6/merge/merge_state.py (vendored, new file, 36 lines)
@@ -0,0 +1,36 @@
from collections import OrderedDict
import os
import sys
from typing import Dict
import typing
import torch
import bitsandbytes as bnb
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--base_model", default="", type=str)
parser.add_argument("--state_checkpoint", default="", type=str)
parser.add_argument("--output", default="", type=str)
# parser.add_argument("--quant", default="none", type=str)
parser.add_argument("--device", default="cuda", type=str)
# parser.add_argument("--lora_alpha", default=16, type=int)
args = parser.parse_args()
device = args.device
base_model = args.base_model
state = args.state_checkpoint
output = args.output


with torch.no_grad():
    w: Dict[str, torch.Tensor] = torch.load(base_model, map_location='cpu')
    # merge the state-tuning checkpoint into the main weights
    w_state: Dict[str, torch.Tensor] = torch.load(state, map_location='cpu')

    for k in w_state.keys():
        print(k)
        w[k] = w_state[k]
    # list every key in the merged checkpoint
    for k in w.keys():
        print(k)

    torch.save(w, output)
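
Unlike the three scripts above, merge_state.py performs no matrix arithmetic: it copies every tensor from the state checkpoint into the base state dict, overwriting or adding keys, and saves the union. A toy illustration with stand-in key names (the real names depend on the model checkpoint):

import torch

base  = {'emb.weight': torch.zeros(4, 4), 'head.weight': torch.zeros(4, 4)}  # stand-in base weights
state = {'blocks.0.att.time_state': torch.ones(2, 2)}                        # stand-in tuned state tensor

merged = dict(base)
merged.update(state)                    # same effect as the key-copy loop above
torch.save(merged, 'merged_state.pth')  # placeholder output path
print(sorted(merged))
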