improve memory usage and speed of convert_safetensors.py

This commit is contained in:
josc146 2023-12-26 23:50:51 +08:00
parent da01a33152
commit e33858f110

View File

@@ -1,9 +1,8 @@
-import json
+import collections
+import numpy
 import os
-import sys
-import copy
 import torch
-from safetensors.torch import load_file, save_file
+from safetensors.torch import serialize_file, load_file
 import argparse
@@ -26,7 +25,7 @@ def rename_key(rename, name):
 def convert_file(pt_filename: str, sf_filename: str, rename={}, transpose_names=[]):
-    loaded = torch.load(pt_filename, map_location="cpu")
+    loaded: collections.OrderedDict = torch.load(pt_filename, map_location="cpu")
     if "state_dict" in loaded:
         loaded = loaded["state_dict"]
@@ -44,11 +43,9 @@ def convert_file(pt_filename: str, sf_filename: str, rename={}, transpose_names=[]):
         if "time_maa" in x:
             version = max(6, version)
 
-    if version == 5.1 and "midi" in pt_filename.lower():
-        import numpy as np
-        np.set_printoptions(precision=4, suppress=True, linewidth=200)
-        kk = list(loaded.keys())
+    print(f"Model detected: v{version:.1f}")
+
+    if version == 5.1:
         _, n_emb = loaded["emb.weight"].shape
         for k in kk:
             if "time_decay" in k or "time_faaaa" in k:
@@ -57,31 +54,34 @@ def convert_file(pt_filename: str, sf_filename: str, rename={}, transpose_names=[]):
                     loaded[k].unsqueeze(1).repeat(1, n_emb // loaded[k].shape[0])
                 )
 
-    loaded = {k: v.clone().half() for k, v in loaded.items()}
-    # for k, v in loaded.items():
-    #     print(f'{k}\t{v.shape}\t{v.dtype}')
-
-    loaded = {rename_key(rename, k).lower(): v.contiguous() for k, v in loaded.items()}
-
-    # For tensors to be contiguous
-    for k, v in loaded.items():
+    for k in kk:
+        new_k = rename_key(rename, k).lower()
+        v = loaded[k].half()
+        del loaded[k]
         for transpose_name in transpose_names:
             if transpose_name in k:
-                loaded[k] = v.transpose(0, 1)
-
-    loaded = {k: v.clone().half().contiguous() for k, v in loaded.items()}
-
-    for k, v in loaded.items():
-        print(f"{k}\t{v.shape}\t{v.dtype}")
+                v = v.transpose(0, 1)
+        print(f"{new_k}\t{v.shape}\t{v.dtype}")
+        loaded[new_k] = {
+            "dtype": str(v.dtype).split(".")[-1],
+            "shape": v.shape,
+            "data": v.numpy().tobytes(),
+        }
 
     dirname = os.path.dirname(sf_filename)
     os.makedirs(dirname, exist_ok=True)
-    save_file(loaded, sf_filename, metadata={"format": "pt"})
-    reloaded = load_file(sf_filename)
-    for k in loaded:
-        pt_tensor = loaded[k]
-        sf_tensor = reloaded[k]
-        if not torch.equal(pt_tensor, sf_tensor):
-            raise RuntimeError(f"The output tensors do not match for key {k}")
+    serialize_file(loaded, sf_filename, metadata={"format": "pt"})
+    # reloaded = load_file(sf_filename)
+    # for k in loaded:
+    #     pt_tensor = torch.Tensor(
+    #         numpy.frombuffer(
+    #             bytearray(loaded[k]["data"]),
+    #             dtype=getattr(numpy, loaded[k]["dtype"]),
+    #         ).reshape(loaded[k]["shape"])
+    #     )
+    #     sf_tensor = reloaded[k]
+    #     if not torch.equal(pt_tensor, sf_tensor):
+    #         raise RuntimeError(f"The output tensors do not match for key {k}")
 
 if __name__ == "__main__":