lora finetune version check
This commit is contained in:
parent
24cc8be085
commit
a8b4f0bb7e
@ -23,6 +23,7 @@ def file_cleaner(file):
|
||||
return cleaner
|
||||
|
||||
|
||||
expected_max_version = float(sys.argv[2]) if len(sys.argv) > 2 else 100
|
||||
model_file = open(sys.argv[1], "rb")
|
||||
cleaner = file_cleaner(model_file)
|
||||
cleaner_thread = threading.Thread(target=cleaner, daemon=True)
|
||||
@ -34,8 +35,23 @@ gc.collect()
|
||||
n_embd = w["emb.weight"].shape[1]
|
||||
n_layer = 0
|
||||
keys = list(w.keys())
|
||||
version = 4
|
||||
for x in keys:
|
||||
layer_id = int(x.split(".")[1]) if ("blocks." in x) else 0
|
||||
n_layer = max(n_layer, layer_id + 1)
|
||||
|
||||
print(f"--n_layer {n_layer} --n_embd {n_embd}", end="")
|
||||
if "ln_x" in x:
|
||||
version = max(5, version)
|
||||
if "gate.weight" in x:
|
||||
version = max(5.1, version)
|
||||
if int(version) == 5 and "att.time_decay" in x:
|
||||
if len(w[x].shape) > 1:
|
||||
if w[x].shape[1] > 1:
|
||||
version = max(5.2, version)
|
||||
if "time_maa" in x:
|
||||
version = max(6, version)
|
||||
|
||||
if version <= expected_max_version:
|
||||
print(f"--n_layer {n_layer} --n_embd {n_embd}", end="")
|
||||
else:
|
||||
raise Exception(f"RWKV{version} is not supported")
|
||||
|
@ -47,7 +47,7 @@ else
|
||||
fi
|
||||
|
||||
echo "loading $loadModel"
|
||||
modelInfo=$(python3 ./finetune/get_layer_and_embd.py $loadModel)
|
||||
modelInfo=$(python3 ./finetune/get_layer_and_embd.py $loadModel 4)
|
||||
echo $modelInfo
|
||||
if [[ $modelInfo =~ "--n_layer" ]]; then
|
||||
python3 ./finetune/lora/train.py $modelInfo $@ --proj_dir lora-models --data_type binidx --lora \
|
||||
|
@ -302,5 +302,6 @@
|
||||
"Content Duration": "内容の長さ",
|
||||
"Please select a MIDI device first": "まずMIDIデバイスを選択してください",
|
||||
"Piano is the main instrument": "ピアノはメインの楽器です",
|
||||
"Loss is too high, please check the training data, and ensure your gpu driver is up to date.": "Lossが大きすぎます、トレーニングデータを確認し、GPUドライバが最新であることを確認してください。"
|
||||
"Loss is too high, please check the training data, and ensure your gpu driver is up to date.": "Lossが大きすぎます、トレーニングデータを確認し、GPUドライバが最新であることを確認してください。",
|
||||
"This version of RWKV is not supported yet.": "このバージョンのRWKVはまだサポートされていません。"
|
||||
}
|
@ -302,5 +302,6 @@
|
||||
"Content Duration": "内容时长",
|
||||
"Please select a MIDI device first": "请先选择一个MIDI设备",
|
||||
"Piano is the main instrument": "钢琴为主",
|
||||
"Loss is too high, please check the training data, and ensure your gpu driver is up to date.": "Loss过高,请检查训练数据,并确保你的显卡驱动是最新的"
|
||||
"Loss is too high, please check the training data, and ensure your gpu driver is up to date.": "Loss过高,请检查训练数据,并确保你的显卡驱动是最新的",
|
||||
"This version of RWKV is not supported yet.": "暂不支持此版本的RWKV"
|
||||
}
|
@ -139,6 +139,7 @@ const errorsMap = Object.entries({
|
||||
'cuda_home environment variable is not set': 'Matched CUDA is not installed',
|
||||
'unsupported gpu architecture': 'Matched CUDA is not installed',
|
||||
'error building extension \'fused_adam\'': 'Matched CUDA is not installed',
|
||||
'rwkv{version} is not supported': 'This version of RWKV is not supported yet.',
|
||||
'modelinfo is invalid': 'Failed to load model, try to increase the virtual memory (Swap of WSL) or use a smaller base model.'
|
||||
});
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user