improve memory usage and speed of convert_safetensors.py
This commit is contained in:
parent
da01a33152
commit
e33858f110
58
backend-python/convert_safetensors.py
vendored
58
backend-python/convert_safetensors.py
vendored
@ -1,9 +1,8 @@
|
||||
import json
|
||||
import collections
|
||||
import numpy
|
||||
import os
|
||||
import sys
|
||||
import copy
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from safetensors.torch import serialize_file, load_file
|
||||
|
||||
import argparse
|
||||
|
||||
@ -26,7 +25,7 @@ def rename_key(rename, name):
|
||||
|
||||
|
||||
def convert_file(pt_filename: str, sf_filename: str, rename={}, transpose_names=[]):
|
||||
loaded = torch.load(pt_filename, map_location="cpu")
|
||||
loaded: collections.OrderedDict = torch.load(pt_filename, map_location="cpu")
|
||||
if "state_dict" in loaded:
|
||||
loaded = loaded["state_dict"]
|
||||
|
||||
@ -44,11 +43,9 @@ def convert_file(pt_filename: str, sf_filename: str, rename={}, transpose_names=
|
||||
if "time_maa" in x:
|
||||
version = max(6, version)
|
||||
|
||||
if version == 5.1 and "midi" in pt_filename.lower():
|
||||
import numpy as np
|
||||
print(f"Model detected: v{version:.1f}")
|
||||
|
||||
np.set_printoptions(precision=4, suppress=True, linewidth=200)
|
||||
kk = list(loaded.keys())
|
||||
if version == 5.1:
|
||||
_, n_emb = loaded["emb.weight"].shape
|
||||
for k in kk:
|
||||
if "time_decay" in k or "time_faaaa" in k:
|
||||
@ -57,31 +54,34 @@ def convert_file(pt_filename: str, sf_filename: str, rename={}, transpose_names=
|
||||
loaded[k].unsqueeze(1).repeat(1, n_emb // loaded[k].shape[0])
|
||||
)
|
||||
|
||||
loaded = {k: v.clone().half() for k, v in loaded.items()}
|
||||
# for k, v in loaded.items():
|
||||
# print(f'{k}\t{v.shape}\t{v.dtype}')
|
||||
|
||||
loaded = {rename_key(rename, k).lower(): v.contiguous() for k, v in loaded.items()}
|
||||
# For tensors to be contiguous
|
||||
for k, v in loaded.items():
|
||||
for k in kk:
|
||||
new_k = rename_key(rename, k).lower()
|
||||
v = loaded[k].half()
|
||||
del loaded[k]
|
||||
for transpose_name in transpose_names:
|
||||
if transpose_name in k:
|
||||
loaded[k] = v.transpose(0, 1)
|
||||
|
||||
loaded = {k: v.clone().half().contiguous() for k, v in loaded.items()}
|
||||
|
||||
for k, v in loaded.items():
|
||||
print(f"{k}\t{v.shape}\t{v.dtype}")
|
||||
v = v.transpose(0, 1)
|
||||
print(f"{new_k}\t{v.shape}\t{v.dtype}")
|
||||
loaded[new_k] = {
|
||||
"dtype": str(v.dtype).split(".")[-1],
|
||||
"shape": v.shape,
|
||||
"data": v.numpy().tobytes(),
|
||||
}
|
||||
|
||||
dirname = os.path.dirname(sf_filename)
|
||||
os.makedirs(dirname, exist_ok=True)
|
||||
save_file(loaded, sf_filename, metadata={"format": "pt"})
|
||||
reloaded = load_file(sf_filename)
|
||||
for k in loaded:
|
||||
pt_tensor = loaded[k]
|
||||
sf_tensor = reloaded[k]
|
||||
if not torch.equal(pt_tensor, sf_tensor):
|
||||
raise RuntimeError(f"The output tensors do not match for key {k}")
|
||||
serialize_file(loaded, sf_filename, metadata={"format": "pt"})
|
||||
# reloaded = load_file(sf_filename)
|
||||
# for k in loaded:
|
||||
# pt_tensor = torch.Tensor(
|
||||
# numpy.frombuffer(
|
||||
# bytearray(loaded[k]["data"]),
|
||||
# dtype=getattr(numpy, loaded[k]["dtype"]),
|
||||
# ).reshape(loaded[k]["shape"])
|
||||
# )
|
||||
# sf_tensor = reloaded[k]
|
||||
# if not torch.equal(pt_tensor, sf_tensor):
|
||||
# raise RuntimeError(f"The output tensors do not match for key {k}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
Reference in New Issue
Block a user