fix convert_safetensors.py for rwkv6
This commit is contained in:
parent
18ab8b141f
commit
02bbd18acf
28
backend-python/convert_safetensors.py
vendored
28
backend-python/convert_safetensors.py
vendored
@ -54,19 +54,21 @@ def convert_file(pt_filename: str, sf_filename: str, rename={}, transpose_names=
|
|||||||
loaded[k].unsqueeze(1).repeat(1, n_emb // loaded[k].shape[0])
|
loaded[k].unsqueeze(1).repeat(1, n_emb // loaded[k].shape[0])
|
||||||
)
|
)
|
||||||
|
|
||||||
for k in kk:
|
with torch.no_grad():
|
||||||
new_k = rename_key(rename, k).lower()
|
for k in kk:
|
||||||
v = loaded[k].half()
|
new_k = rename_key(rename, k).lower()
|
||||||
del loaded[k]
|
v = loaded[k].half()
|
||||||
for transpose_name in transpose_names:
|
del loaded[k]
|
||||||
if transpose_name in k:
|
for transpose_name in transpose_names:
|
||||||
v = v.transpose(0, 1)
|
if transpose_name in new_k:
|
||||||
print(f"{new_k}\t{v.shape}\t{v.dtype}")
|
dims = len(v.shape)
|
||||||
loaded[new_k] = {
|
v = v.transpose(dims - 2, dims - 1)
|
||||||
"dtype": str(v.dtype).split(".")[-1],
|
print(f"{new_k}\t{v.shape}\t{v.dtype}")
|
||||||
"shape": v.shape,
|
loaded[new_k] = {
|
||||||
"data": v.numpy().tobytes(),
|
"dtype": str(v.dtype).split(".")[-1],
|
||||||
}
|
"shape": v.shape,
|
||||||
|
"data": v.numpy().tobytes(),
|
||||||
|
}
|
||||||
|
|
||||||
dirname = os.path.dirname(sf_filename)
|
dirname = os.path.dirname(sf_filename)
|
||||||
os.makedirs(dirname, exist_ok=True)
|
os.makedirs(dirname, exist_ok=True)
|
||||||
|
Loading…
Reference in New Issue
Block a user