diff --git a/finetune/get_layer_and_embd.py b/finetune/get_layer_and_embd.py
index 17e3edd..04498aa 100644
--- a/finetune/get_layer_and_embd.py
+++ b/finetune/get_layer_and_embd.py
@@ -23,6 +23,8 @@ def file_cleaner(file):
     return cleaner
 
 
+# Optional argv[2]: highest RWKV version the caller accepts (defaults to 100).
+expected_max_version = float(sys.argv[2]) if len(sys.argv) > 2 else 100
 model_file = open(sys.argv[1], "rb")
 cleaner = file_cleaner(model_file)
 cleaner_thread = threading.Thread(target=cleaner, daemon=True)
@@ -34,8 +36,31 @@ gc.collect()
 n_embd = w["emb.weight"].shape[1]
 n_layer = 0
 keys = list(w.keys())
+# Infer the RWKV version from parameter names; 4 is the baseline.
+version = 4
+has_2d_time_decay = False
 for x in keys:
     layer_id = int(x.split(".")[1]) if ("blocks." in x) else 0
     n_layer = max(n_layer, layer_id + 1)
 
-print(f"--n_layer {n_layer} --n_embd {n_embd}", end="")
+    if "ln_x" in x:
+        version = max(5, version)
+    if "gate.weight" in x:
+        version = max(5.1, version)
+    if "att.time_decay" in x:
+        shape = w[x].shape
+        if len(shape) > 1 and shape[1] > 1:
+            has_2d_time_decay = True
+    if "time_maa" in x:
+        version = max(6, version)
+
+# Decide 5.2 after the loop so the result does not depend on key order.
+if int(version) == 5 and has_2d_time_decay:
+    version = max(5.2, version)
+
+if version <= expected_max_version:
+    print(f"--n_layer {n_layer} --n_embd {n_embd}", end="")
+else:
+    # NOTE(review): Train.tsx's 'rwkv{version} is not supported' key presumably
+    # matches this source line as echoed in the traceback; confirm before rewording.
+    raise Exception(f"RWKV{version} is not supported")
diff --git a/finetune/install-wsl-dep-and-train.sh b/finetune/install-wsl-dep-and-train.sh
index 6771ad3..83739f2 100644
--- a/finetune/install-wsl-dep-and-train.sh
+++ b/finetune/install-wsl-dep-and-train.sh
@@ -47,7 +47,7 @@ else
 fi
 
 echo "loading $loadModel"
-modelInfo=$(python3 ./finetune/get_layer_and_embd.py $loadModel)
+modelInfo=$(python3 ./finetune/get_layer_and_embd.py "$loadModel" 4)
 echo $modelInfo
 if [[ $modelInfo =~ "--n_layer" ]]; then
   python3 ./finetune/lora/train.py $modelInfo $@ --proj_dir lora-models --data_type binidx --lora \
diff --git a/frontend/src/_locales/ja/main.json b/frontend/src/_locales/ja/main.json
index b6befe2..e0eec78 100644
--- a/frontend/src/_locales/ja/main.json
+++ b/frontend/src/_locales/ja/main.json
@@ -302,5 +302,6 @@
   "Content Duration": "内容の長さ",
   "Please select a MIDI device first": "まずMIDIデバイスを選択してください",
   "Piano is the main instrument": "ピアノはメインの楽器です",
-  "Loss is too high, please check the training data, and ensure your gpu driver is up to date.": "Lossが大きすぎます、トレーニングデータを確認し、GPUドライバが最新であることを確認してください。"
+  "Loss is too high, please check the training data, and ensure your gpu driver is up to date.": "Lossが大きすぎます、トレーニングデータを確認し、GPUドライバが最新であることを確認してください。",
+  "This version of RWKV is not supported yet.": "このバージョンのRWKVはまだサポートされていません。"
 }
\ No newline at end of file
diff --git a/frontend/src/_locales/zh-hans/main.json b/frontend/src/_locales/zh-hans/main.json
index 105961e..ddd771f 100644
--- a/frontend/src/_locales/zh-hans/main.json
+++ b/frontend/src/_locales/zh-hans/main.json
@@ -302,5 +302,6 @@
   "Content Duration": "内容时长",
   "Please select a MIDI device first": "请先选择一个MIDI设备",
   "Piano is the main instrument": "钢琴为主",
-  "Loss is too high, please check the training data, and ensure your gpu driver is up to date.": "Loss过高,请检查训练数据,并确保你的显卡驱动是最新的"
+  "Loss is too high, please check the training data, and ensure your gpu driver is up to date.": "Loss过高,请检查训练数据,并确保你的显卡驱动是最新的",
+  "This version of RWKV is not supported yet.": "暂不支持此版本的RWKV"
 }
\ No newline at end of file
diff --git a/frontend/src/pages/Train.tsx b/frontend/src/pages/Train.tsx
index ba8e9ea..571d6ce 100644
--- a/frontend/src/pages/Train.tsx
+++ b/frontend/src/pages/Train.tsx
@@ -139,6 +139,7 @@ const errorsMap = Object.entries({
   'cuda_home environment variable is not set': 'Matched CUDA is not installed',
   'unsupported gpu architecture': 'Matched CUDA is not installed',
   'error building extension \'fused_adam\'': 'Matched CUDA is not installed',
+  'rwkv{version} is not supported': 'This version of RWKV is not supported yet.',
   'modelinfo is invalid': 'Failed to load model, try to increase the virtual memory (Swap of WSL) or use a smaller base model.'
 });
 