lora finetune version check
@@ -23,6 +23,7 @@ def file_cleaner(file):
     return cleaner
 
 
+expected_max_version = float(sys.argv[2]) if len(sys.argv) > 2 else 100
 model_file = open(sys.argv[1], "rb")
 cleaner = file_cleaner(model_file)
 cleaner_thread = threading.Thread(target=cleaner, daemon=True)
@@ -34,8 +35,23 @@ gc.collect()
 n_embd = w["emb.weight"].shape[1]
 n_layer = 0
 keys = list(w.keys())
+version = 4
 for x in keys:
     layer_id = int(x.split(".")[1]) if ("blocks." in x) else 0
     n_layer = max(n_layer, layer_id + 1)
 
-print(f"--n_layer {n_layer} --n_embd {n_embd}", end="")
+    if "ln_x" in x:
+        version = max(5, version)
+    if "gate.weight" in x:
+        version = max(5.1, version)
+    if int(version) == 5 and "att.time_decay" in x:
+        if len(w[x].shape) > 1:
+            if w[x].shape[1] > 1:
+                version = max(5.2, version)
+    if "time_maa" in x:
+        version = max(6, version)
+
+if version <= expected_max_version:
+    print(f"--n_layer {n_layer} --n_embd {n_embd}", end="")
+else:
+    raise Exception(f"RWKV{version} is not supported")
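For reference, a minimal self-contained sketch (not part of the commit) of the detection heuristic added above, run against a handful of hypothetical state-dict keys instead of a real checkpoint; the key names and tensor shapes below are invented for illustration.

import torch

# Hypothetical checkpoint contents; real RWKV state dicts have many more keys.
w = {
    "emb.weight": torch.zeros(10, 8),                 # n_embd is read from dim 1
    "blocks.0.att.ln_x.weight": torch.zeros(8),       # "ln_x"        -> at least RWKV5
    "blocks.0.att.gate.weight": torch.zeros(8, 8),    # "gate.weight" -> at least RWKV5.1
    "blocks.1.att.time_decay": torch.zeros(2, 4),     # 2-D time_decay while at 5.x -> 5.2
    "blocks.1.ffn.value.weight": torch.zeros(8, 8),   # highest block id -> n_layer = 2
}
expected_max_version = 4  # the cap the LoRA script now passes as sys.argv[2]

n_embd = w["emb.weight"].shape[1]
n_layer = 0
version = 4
for x in w.keys():
    layer_id = int(x.split(".")[1]) if ("blocks." in x) else 0
    n_layer = max(n_layer, layer_id + 1)
    if "ln_x" in x:
        version = max(5, version)
    if "gate.weight" in x:
        version = max(5.1, version)
    if int(version) == 5 and "att.time_decay" in x:
        if len(w[x].shape) > 1 and w[x].shape[1] > 1:
            version = max(5.2, version)
    if "time_maa" in x:
        version = max(6, version)

if version <= expected_max_version:
    print(f"--n_layer {n_layer} --n_embd {n_embd}", end="")
else:
    raise Exception(f"RWKV{version} is not supported")  # fires here: 5.2 > 4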
@@ -47,7 +47,7 @@ else
 fi
 
 echo "loading $loadModel"
-modelInfo=$(python3 ./finetune/get_layer_and_embd.py $loadModel)
+modelInfo=$(python3 ./finetune/get_layer_and_embd.py $loadModel 4)
 echo $modelInfo
 if [[ $modelInfo =~ "--n_layer" ]]; then
   python3 ./finetune/lora/train.py $modelInfo $@ --proj_dir lora-models --data_type binidx --lora \
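A small illustration (also not part of the commit; the file name and argv values are made up) of how the new optional cap argument behaves: the LoRA finetune script appends 4, so checkpoints detected as RWKV5/5.1/5.2/6 are rejected before training starts, while callers that omit the argument keep the old accept-everything behaviour.

def expected_max_version(argv):
    # mirrors: float(sys.argv[2]) if len(sys.argv) > 2 else 100
    return float(argv[2]) if len(argv) > 2 else 100

print(expected_max_version(["get_layer_and_embd.py", "model.pth", "4"]))  # 4.0  (LoRA finetune cap)
print(expected_max_version(["get_layer_and_embd.py", "model.pth"]))       # 100  (no cap)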