2023-07-03 17:41:47 +08:00
|
|
|
from collections import OrderedDict
|
|
|
|
import os
|
|
|
|
import sys
|
|
|
|
from typing import Dict
|
|
|
|
import typing
|
|
|
|
import torch
|
|
|
|
|
2023-07-07 20:16:35 +08:00
|
|
|
def usage(prog: str) -> str:
    """Return the one-line command-line usage string for this script."""
    return f"Usage: python3 {prog} [--use-gpu] <lora_alpha> <base_model.pth> <lora_checkpoint.pth> <output.pth>"


def parse_args(argv: typing.Sequence[str]) -> typing.Tuple[str, float, str, str, str]:
    """Parse ``argv`` into ``(device, lora_alpha, base_model, lora, output)``.

    An optional ``--use-gpu`` as the first argument selects ``"cuda"``;
    otherwise the device is ``"cpu"``. Arguments after the four positional
    ones are ignored, as in the original script.

    Raises:
        ValueError: if fewer than four positional arguments are given (the
            message is the usage string), or if ``lora_alpha`` is not a number.
    """
    prog = argv[0] if argv else "<script>"
    args = list(argv[1:])
    device = "cpu"
    if args and args[0] == "--use-gpu":
        device = "cuda"
        args = args[1:]
    if len(args) < 4:
        # Previously this surfaced as an opaque IndexError.
        raise ValueError(usage(prog))
    return device, float(args[0]), args[1], args[2], args[3]


def merge_lora_weights(
    w: Dict[str, torch.Tensor], lora_alpha: float, device: str = "cpu"
) -> "typing.OrderedDict[str, torch.Tensor]":
    """Fold LoRA factor pairs into their base weights and return the new state dict.

    For every key ``<prefix>.weight`` that has a matching ``<prefix>.lora_A``,
    the output weight is ``W + (B @ A) * (lora_alpha / r)``, where ``B`` is
    ``<prefix>.lora_B`` and ``r = B.shape[1]``. Keys containing ``"lora"``
    that are not consumed by a merge are dropped. Every other key is cloned
    through unchanged, in the original key order.

    ``w`` is consumed: processed entries are removed from it as the loop runs,
    so tensors can be freed early.

    Args:
        w: state dict holding base weights plus ``.lora_A`` / ``.lora_B`` tensors.
        lora_alpha: LoRA scaling numerator.
        device: device on which the merge matmul runs (``"cpu"`` or ``"cuda"``).

    Returns:
        An ``OrderedDict`` of CPU tensors. Merged weights keep the dtype of
        their base weight.

    Raises:
        KeyError: a ``.lora_A`` tensor has no matching ``.lora_B``.
        ValueError: the LoRA ranks or the merged shape do not line up.
    """
    output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict()
    keys = list(w.keys())
    key_set = set(keys)  # O(1) membership instead of scanning the list per key
    for k in keys:
        if k.endswith(".weight"):
            prefix = k[: -len(".weight")]
            lora_A = prefix + ".lora_A"
            lora_B = prefix + ".lora_B"
            if lora_A in key_set:
                # Explicit errors (not asserts) so error.txt gets a real message.
                if lora_B not in key_set:
                    raise KeyError(f"{lora_A} found but {lora_B} is missing")
                print(f"merging {lora_A} and {lora_B} into {k}")
                a, b, base = w.pop(lora_A), w.pop(lora_B), w.pop(k)
                if b.shape[1] != a.shape[0]:
                    raise ValueError(
                        f"rank mismatch: {lora_B} has shape {tuple(b.shape)} "
                        f"but {lora_A} has shape {tuple(a.shape)}"
                    )
                lora_r = b.shape[1]
                # Compute in float32: fp16 matmul is unsupported on CPU in older
                # torch builds, and this avoids rounding the delta in half precision.
                base_f = base.to(device=device, dtype=torch.float32)
                delta = (
                    b.to(device=device, dtype=torch.float32)
                    @ a.to(device=device, dtype=torch.float32)
                ) * (lora_alpha / lora_r)
                if delta.shape != base_f.shape:
                    raise ValueError(
                        f"LoRA delta shape {tuple(delta.shape)} does not match "
                        f"{k} shape {tuple(base_f.shape)}"
                    )
                output_w[k] = (base_f + delta).to(device="cpu", dtype=base.dtype)
                continue

        if "lora" not in k:
            print(f"retaining {k}")
            output_w[k] = w.pop(k).clone()
    return output_w


def main() -> None:
    """Merge a LoRA checkpoint into a base model as described by ``sys.argv``.

    On any ``Exception``, the traceback goes to stderr and the exception
    message is written to ``error.txt`` in the current directory. The
    exception is not re-raised, matching the original script.
    """
    try:
        if "-h" in sys.argv or "--help" in sys.argv:
            print(usage(sys.argv[0]))
            # Previously execution continued and failed parsing "-h" as a float.
            sys.exit(0)

        device, lora_alpha, base_model, lora, output = parse_args(sys.argv)

        with torch.no_grad():
            # NOTE(review): torch.load unpickles its input; only load trusted checkpoints.
            w: Dict[str, torch.Tensor] = torch.load(base_model, map_location="cpu")
            # merge LoRA-only slim checkpoint into the main weights
            w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location="cpu")
            w.update(w_lora)
            del w_lora
            output_w = merge_lora_weights(w, lora_alpha, device)

        torch.save(output_w, output)
    except Exception as e:
        import traceback

        # Keep the full traceback visible; error.txt still gets the short message.
        traceback.print_exc()
        # Explicit UTF-8 so non-ASCII paths in the message cannot fail the write.
        with open("error.txt", "w", encoding="utf-8") as f:
            f.write(str(e) or type(e).__name__)


main()
|