From e33858f110aa718faaefbe0c4bf15d7296cdb5f1 Mon Sep 17 00:00:00 2001 From: josc146 Date: Tue, 26 Dec 2023 23:50:51 +0800 Subject: [PATCH] improve memory usage and speed of convert_safetensors.py --- backend-python/convert_safetensors.py | 58 +++++++++++++-------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/backend-python/convert_safetensors.py b/backend-python/convert_safetensors.py index 131ee8f..feb4874 100644 --- a/backend-python/convert_safetensors.py +++ b/backend-python/convert_safetensors.py @@ -1,9 +1,8 @@ -import json +import collections +import numpy import os -import sys -import copy import torch -from safetensors.torch import load_file, save_file +from safetensors.torch import serialize_file, load_file import argparse @@ -26,7 +25,7 @@ def rename_key(rename, name): def convert_file(pt_filename: str, sf_filename: str, rename={}, transpose_names=[]): - loaded = torch.load(pt_filename, map_location="cpu") + loaded: collections.OrderedDict = torch.load(pt_filename, map_location="cpu") if "state_dict" in loaded: loaded = loaded["state_dict"] @@ -44,11 +43,9 @@ def convert_file(pt_filename: str, sf_filename: str, rename={}, transpose_names= if "time_maa" in x: version = max(6, version) - if version == 5.1 and "midi" in pt_filename.lower(): - import numpy as np + print(f"Model detected: v{version:.1f}") - np.set_printoptions(precision=4, suppress=True, linewidth=200) - kk = list(loaded.keys()) + if version == 5.1: _, n_emb = loaded["emb.weight"].shape for k in kk: if "time_decay" in k or "time_faaaa" in k: @@ -57,31 +54,34 @@ def convert_file(pt_filename: str, sf_filename: str, rename={}, transpose_names= loaded[k].unsqueeze(1).repeat(1, n_emb // loaded[k].shape[0]) ) - loaded = {k: v.clone().half() for k, v in loaded.items()} - # for k, v in loaded.items(): - # print(f'{k}\t{v.shape}\t{v.dtype}') - - loaded = {rename_key(rename, k).lower(): v.contiguous() for k, v in loaded.items()} - # For tensors to be contiguous - for k, v in loaded.items(): + for k in kk: + new_k = rename_key(rename, k).lower() + v = loaded[k].half() + del loaded[k] for transpose_name in transpose_names: if transpose_name in k: - loaded[k] = v.transpose(0, 1) - - loaded = {k: v.clone().half().contiguous() for k, v in loaded.items()} - - for k, v in loaded.items(): - print(f"{k}\t{v.shape}\t{v.dtype}") + v = v.transpose(0, 1) + print(f"{new_k}\t{v.shape}\t{v.dtype}") + loaded[new_k] = { + "dtype": str(v.dtype).split(".")[-1], + "shape": v.shape, + "data": v.numpy().tobytes(), + } dirname = os.path.dirname(sf_filename) os.makedirs(dirname, exist_ok=True) - save_file(loaded, sf_filename, metadata={"format": "pt"}) - reloaded = load_file(sf_filename) - for k in loaded: - pt_tensor = loaded[k] - sf_tensor = reloaded[k] - if not torch.equal(pt_tensor, sf_tensor): - raise RuntimeError(f"The output tensors do not match for key {k}") + serialize_file(loaded, sf_filename, metadata={"format": "pt"}) + # reloaded = load_file(sf_filename) + # for k in loaded: + # pt_tensor = torch.Tensor( + # numpy.frombuffer( + # bytearray(loaded[k]["data"]), + # dtype=getattr(numpy, loaded[k]["dtype"]), + # ).reshape(loaded[k]["shape"]) + # ) + # sf_tensor = reloaded[k] + # if not torch.equal(pt_tensor, sf_tensor): + # raise RuntimeError(f"The output tensors do not match for key {k}") if __name__ == "__main__":