upgrade to webgpu 0.2.2 (https://github.com/josStorer/ai00_rwkv_server)
This commit is contained in:
parent
0331bf47f7
commit
4a192f4057
@ -28,8 +28,7 @@ func (a *App) StartServer(python string, port int, host string, rwkvBeta bool) (
|
|||||||
|
|
||||||
func (a *App) StartWebGPUServer(port int, host string) (string, error) {
|
func (a *App) StartWebGPUServer(port int, host string) (string, error) {
|
||||||
args := []string{"./backend-rust/webgpu_server"}
|
args := []string{"./backend-rust/webgpu_server"}
|
||||||
args = append(args, "-a", "0", "-t", "backend-rust/assets/rwkv_vocab_v20230424.json",
|
args = append(args, "--port", strconv.Itoa(port), "--ip", host)
|
||||||
"--port", strconv.Itoa(port), "--ip", host)
|
|
||||||
return Cmd(args...)
|
return Cmd(args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
32
backend-python/convert_safetensors.py
vendored
32
backend-python/convert_safetensors.py
vendored
@ -18,20 +18,31 @@ parser.add_argument(
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def convert_file(
|
def rename_key(rename, name):
|
||||||
pt_filename: str,
|
for k, v in rename.items():
|
||||||
sf_filename: str,
|
if k in name:
|
||||||
):
|
name = name.replace(k, v)
|
||||||
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
def convert_file(pt_filename: str, sf_filename: str, transpose_names=[], rename={}):
|
||||||
loaded = torch.load(pt_filename, map_location="cpu")
|
loaded = torch.load(pt_filename, map_location="cpu")
|
||||||
if "state_dict" in loaded:
|
if "state_dict" in loaded:
|
||||||
loaded = loaded["state_dict"]
|
loaded = loaded["state_dict"]
|
||||||
|
|
||||||
loaded = {k: v.clone().half() for k, v in loaded.items()}
|
loaded = {k: v.clone().half() for k, v in loaded.items()}
|
||||||
for k, v in loaded.items():
|
# for k, v in loaded.items():
|
||||||
print(f"{k}\t{v.shape}\t{v.dtype}")
|
# print(f'{k}\t{v.shape}\t{v.dtype}')
|
||||||
|
|
||||||
# For tensors to be contiguous
|
# For tensors to be contiguous
|
||||||
loaded = {k: v.contiguous() for k, v in loaded.items()}
|
for k, v in loaded.items():
|
||||||
|
for transpose_name in transpose_names:
|
||||||
|
if transpose_name in k:
|
||||||
|
loaded[k] = v.transpose(0, 1)
|
||||||
|
loaded = {rename_key(rename, k).lower(): v.contiguous() for k, v in loaded.items()}
|
||||||
|
|
||||||
|
for k, v in loaded.items():
|
||||||
|
print(f"{k}\t{v.shape}\t{v.dtype}")
|
||||||
|
|
||||||
dirname = os.path.dirname(sf_filename)
|
dirname = os.path.dirname(sf_filename)
|
||||||
os.makedirs(dirname, exist_ok=True)
|
os.makedirs(dirname, exist_ok=True)
|
||||||
@ -46,7 +57,12 @@ def convert_file(
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
try:
|
try:
|
||||||
convert_file(args.input, args.output)
|
convert_file(
|
||||||
|
args.input,
|
||||||
|
args.output,
|
||||||
|
["lora_A"],
|
||||||
|
{"time_faaaa": "time_first", "lora_A": "lora.0", "lora_B": "lora.1"},
|
||||||
|
)
|
||||||
print(f"Saved to {args.output}")
|
print(f"Saved to {args.output}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
with open("error.txt", "w") as f:
|
with open("error.txt", "w") as f:
|
||||||
|
Loading…
Reference in New Issue
Block a user