add WebGPU Python Mode (https://github.com/cryscan/web-rwkv-py)
This commit is contained in:
27
backend-python/convert_safetensors.py
vendored
27
backend-python/convert_safetensors.py
vendored
@@ -30,6 +30,33 @@ def convert_file(pt_filename: str, sf_filename: str, rename={}, transpose_names=
|
||||
if "state_dict" in loaded:
|
||||
loaded = loaded["state_dict"]
|
||||
|
||||
kk = list(loaded.keys())
|
||||
version = 4
|
||||
for x in kk:
|
||||
if "ln_x" in x:
|
||||
version = max(5, version)
|
||||
if "gate.weight" in x:
|
||||
version = max(5.1, version)
|
||||
if int(version) == 5 and "att.time_decay" in x:
|
||||
if len(loaded[x].shape) > 1:
|
||||
if loaded[x].shape[1] > 1:
|
||||
version = max(5.2, version)
|
||||
if "time_maa" in x:
|
||||
version = max(6, version)
|
||||
|
||||
if version == 5.1 and "midi" in pt_filename.lower():
|
||||
import numpy as np
|
||||
|
||||
np.set_printoptions(precision=4, suppress=True, linewidth=200)
|
||||
kk = list(loaded.keys())
|
||||
_, n_emb = loaded["emb.weight"].shape
|
||||
for k in kk:
|
||||
if "time_decay" in k or "time_faaaa" in k:
|
||||
# print(k, mm[k].shape)
|
||||
loaded[k] = (
|
||||
loaded[k].unsqueeze(1).repeat(1, n_emb // loaded[k].shape[0])
|
||||
)
|
||||
|
||||
loaded = {k: v.clone().half() for k, v in loaded.items()}
|
||||
# for k, v in loaded.items():
|
||||
# print(f'{k}\t{v.shape}\t{v.dtype}')
|
||||
|
||||
Reference in New Issue
Block a user