improve memory usage and speed of convert_safetensors.py
commit e33858f110 (parent da01a33152)
backend-python/convert_safetensors.py (vendored): 58 changed lines
@@ -1,9 +1,8 @@
-import json
+import collections
+import numpy
 import os
-import sys
-import copy
 import torch
-from safetensors.torch import load_file, save_file
+from safetensors.torch import serialize_file, load_file
 
 import argparse
 
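Note: the import swap above is the core of the change. save_file needs a dict of real torch.Tensor values, so the whole converted model must exist as tensors at save time, while serialize_file (also importable from safetensors.torch) accepts plain {"dtype", "shape", "data"} descriptors over raw bytes. A minimal sketch of the two call shapes, with a toy tensor and made-up file names:

    import torch
    from safetensors.torch import load_file, save_file, serialize_file

    t = torch.ones(4, 2, dtype=torch.float16)

    # save_file: values must be contiguous torch.Tensor objects.
    save_file({"w": t}, "a.safetensors", metadata={"format": "pt"})

    # serialize_file: values are byte buffers plus dtype/shape descriptors,
    # which is what lets the conversion loop below drop each tensor early.
    serialize_file(
        {"w": {"dtype": "float16", "shape": list(t.shape), "data": t.numpy().tobytes()}},
        "b.safetensors",
        metadata={"format": "pt"},
    )

    assert torch.equal(load_file("a.safetensors")["w"], load_file("b.safetensors")["w"])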
@@ -26,7 +25,7 @@ def rename_key(rename, name):
 
 
 def convert_file(pt_filename: str, sf_filename: str, rename={}, transpose_names=[]):
-    loaded = torch.load(pt_filename, map_location="cpu")
+    loaded: collections.OrderedDict = torch.load(pt_filename, map_location="cpu")
     if "state_dict" in loaded:
         loaded = loaded["state_dict"]
 
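Note: the new annotation just documents what torch.load returns for these checkpoints. For reference, a hedged sketch of the load-and-unwrap step on its own (the checkpoint path is hypothetical):

    import collections
    import torch

    # map_location="cpu" keeps every weight in host RAM, so the
    # conversion does not require a GPU.
    loaded = torch.load("model.pth", map_location="cpu")  # hypothetical path

    # Some training setups nest the weights one level deeper.
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]

    print(type(loaded))  # typically collections.OrderedDict of name -> Tensor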
@@ -44,11 +43,9 @@ def convert_file(pt_filename: str, sf_filename: str, rename={}, transpose_names=
         if "time_maa" in x:
             version = max(6, version)
 
-    if version == 5.1 and "midi" in pt_filename.lower():
-        import numpy as np
+    print(f"Model detected: v{version:.1f}")
 
-        np.set_printoptions(precision=4, suppress=True, linewidth=200)
-        kk = list(loaded.keys())
+    if version == 5.1:
         _, n_emb = loaded["emb.weight"].shape
         for k in kk:
             if "time_decay" in k or "time_faaaa" in k:
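Note: the midi-specific guard and the numpy print setup are gone; the script now logs the detected version once, and the v5.1 branch runs for every v5.1 model, expanding each per-head time_decay/time_faaaa vector into, in effect, an (n_head, head_size) matrix. A toy illustration of that unsqueeze(1).repeat(...) expansion (sizes invented for the example):

    import torch

    n_emb = 8                                        # toy width; real models are much larger
    time_decay = torch.tensor([0.1, 0.2, 0.3, 0.4])  # per-head values, shape (4,)

    # (4,) -> (4, 1) -> (4, n_emb // 4): each head's value is repeated
    # across that head's channels.
    expanded = time_decay.unsqueeze(1).repeat(1, n_emb // time_decay.shape[0])
    print(expanded.shape)  # torch.Size([4, 2])
    print(expanded[1])     # tensor([0.2000, 0.2000])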
@@ -57,31 +54,34 @@ def convert_file(pt_filename: str, sf_filename: str, rename={}, transpose_names=
                     loaded[k].unsqueeze(1).repeat(1, n_emb // loaded[k].shape[0])
                 )
 
-    loaded = {k: v.clone().half() for k, v in loaded.items()}
-    # for k, v in loaded.items():
-    #     print(f'{k}\t{v.shape}\t{v.dtype}')
-
-    loaded = {rename_key(rename, k).lower(): v.contiguous() for k, v in loaded.items()}
-    # For tensors to be contiguous
-    for k, v in loaded.items():
+    for k in kk:
+        new_k = rename_key(rename, k).lower()
+        v = loaded[k].half()
+        del loaded[k]
         for transpose_name in transpose_names:
             if transpose_name in k:
-                loaded[k] = v.transpose(0, 1)
-
-    loaded = {k: v.clone().half().contiguous() for k, v in loaded.items()}
-
-    for k, v in loaded.items():
-        print(f"{k}\t{v.shape}\t{v.dtype}")
+                v = v.transpose(0, 1)
+        print(f"{new_k}\t{v.shape}\t{v.dtype}")
+        loaded[new_k] = {
+            "dtype": str(v.dtype).split(".")[-1],
+            "shape": v.shape,
+            "data": v.numpy().tobytes(),
+        }
 
     dirname = os.path.dirname(sf_filename)
     os.makedirs(dirname, exist_ok=True)
-    save_file(loaded, sf_filename, metadata={"format": "pt"})
-    reloaded = load_file(sf_filename)
-    for k in loaded:
-        pt_tensor = loaded[k]
-        sf_tensor = reloaded[k]
-        if not torch.equal(pt_tensor, sf_tensor):
-            raise RuntimeError(f"The output tensors do not match for key {k}")
+    serialize_file(loaded, sf_filename, metadata={"format": "pt"})
+    # reloaded = load_file(sf_filename)
+    # for k in loaded:
+    #     pt_tensor = torch.Tensor(
+    #         numpy.frombuffer(
+    #             bytearray(loaded[k]["data"]),
+    #             dtype=getattr(numpy, loaded[k]["dtype"]),
+    #         ).reshape(loaded[k]["shape"])
+    #     )
+    #     sf_tensor = reloaded[k]
+    #     if not torch.equal(pt_tensor, sf_tensor):
+    #         raise RuntimeError(f"The output tensors do not match for key {k}")
 
 
 if __name__ == "__main__":
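Note: the rewritten loop is where both the memory and the speed win come from. The old code rebuilt the entire dict three times (.clone().half(), the rename plus .contiguous(), then .clone().half().contiguous() again), so multiple full copies of the model were alive at peak; the new loop converts one tensor at a time, and del loaded[k] frees each source tensor as soon as its fp16 bytes exist. The post-save round-trip verification is also commented out, trading that check for speed. A schematic comparison of the two patterns (not the literal code):

    import torch

    def convert_old(loaded: dict) -> dict:
        # Whole-dict comprehension: the new dict is fully built before
        # the old one can be collected -> peak is roughly 2x the model
        # size, and the old script made several such passes.
        return {k: v.clone().half().contiguous() for k, v in loaded.items()}

    def convert_new(loaded: dict) -> dict:
        # Key-by-key: each original tensor is dropped right after it is
        # converted -> peak stays near 1x the model size.
        for k in list(loaded.keys()):   # snapshot keys; the dict is mutated
            v = loaded[k].half()
            del loaded[k]               # free the source tensor now
            loaded[k.lower()] = v       # rename happens inline
        return loaded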