This commit is contained in:
josc146 2023-10-25 21:02:44 +08:00
parent 0331bf47f7
commit 4a192f4057
2 changed files with 25 additions and 10 deletions

View File

@ -28,8 +28,7 @@ func (a *App) StartServer(python string, port int, host string, rwkvBeta bool) (
func (a *App) StartWebGPUServer(port int, host string) (string, error) { func (a *App) StartWebGPUServer(port int, host string) (string, error) {
args := []string{"./backend-rust/webgpu_server"} args := []string{"./backend-rust/webgpu_server"}
args = append(args, "-a", "0", "-t", "backend-rust/assets/rwkv_vocab_v20230424.json", args = append(args, "--port", strconv.Itoa(port), "--ip", host)
"--port", strconv.Itoa(port), "--ip", host)
return Cmd(args...) return Cmd(args...)
} }

View File

@ -18,20 +18,31 @@ parser.add_argument(
args = parser.parse_args() args = parser.parse_args()
def convert_file( def rename_key(rename, name):
pt_filename: str, for k, v in rename.items():
sf_filename: str, if k in name:
): name = name.replace(k, v)
return name
def convert_file(pt_filename: str, sf_filename: str, transpose_names=[], rename={}):
loaded = torch.load(pt_filename, map_location="cpu") loaded = torch.load(pt_filename, map_location="cpu")
if "state_dict" in loaded: if "state_dict" in loaded:
loaded = loaded["state_dict"] loaded = loaded["state_dict"]
loaded = {k: v.clone().half() for k, v in loaded.items()} loaded = {k: v.clone().half() for k, v in loaded.items()}
for k, v in loaded.items(): # for k, v in loaded.items():
print(f"{k}\t{v.shape}\t{v.dtype}") # print(f'{k}\t{v.shape}\t{v.dtype}')
# For tensors to be contiguous # For tensors to be contiguous
loaded = {k: v.contiguous() for k, v in loaded.items()} for k, v in loaded.items():
for transpose_name in transpose_names:
if transpose_name in k:
loaded[k] = v.transpose(0, 1)
loaded = {rename_key(rename, k).lower(): v.contiguous() for k, v in loaded.items()}
for k, v in loaded.items():
print(f"{k}\t{v.shape}\t{v.dtype}")
dirname = os.path.dirname(sf_filename) dirname = os.path.dirname(sf_filename)
os.makedirs(dirname, exist_ok=True) os.makedirs(dirname, exist_ok=True)
@ -46,7 +57,12 @@ def convert_file(
if __name__ == "__main__": if __name__ == "__main__":
try: try:
convert_file(args.input, args.output) convert_file(
args.input,
args.output,
["lora_A"],
{"time_faaaa": "time_first", "lora_A": "lora.0", "lora_B": "lora.1"},
)
print(f"Saved to {args.output}") print(f"Saved to {args.output}")
except Exception as e: except Exception as e:
with open("error.txt", "w") as f: with open("error.txt", "w") as f: