"""Print the ``--n_layer`` / ``--n_embd`` flags that match a checkpoint file.

Usage: ``python <this script> MODEL_FILE``

The checkpoint is loaded on the CPU with ``torch.load``. While it is being
read, a background thread asks the kernel to drop the already-consumed part
of the file from the page cache. The result is written to stdout without a
trailing newline, e.g. ``--n_layer 24 --n_embd 1024``.
"""

import gc
import os
import sys
import threading
import time

import torch


def file_cleaner(file, stop_event=None, interval=0.1):
    """Build a worker that evicts the already-read part of *file* from cache.

    Args:
        file: An open file object supporting ``tell()`` and ``fileno()``.
        stop_event: Optional ``threading.Event``. When given, the worker
            returns as soon as the event is set. When ``None`` the worker
            loops forever, so it should be run in a daemon thread.
        interval: Seconds to wait between polls of the file position.

    Returns:
        A zero-argument callable suitable as a ``threading.Thread`` target.
    """
    last_pos = 0
    # posix_fadvise is not available on every platform; without it the worker
    # still tracks the position but issues no advice.
    can_advise = hasattr(os, "posix_fadvise")

    def cleaner():
        nonlocal last_pos
        while True:
            if stop_event is None:
                time.sleep(interval)
            elif stop_event.wait(interval):
                # Asked to stop: the reader is done with the file.
                return
            pos = file.tell()
            if pos > last_pos:
                if can_advise:
                    # Advise the kernel that [last_pos, pos) will not be
                    # needed again so its cached pages can be released.
                    os.posix_fadvise(
                        file.fileno(),
                        last_pos,
                        pos - last_pos,
                        os.POSIX_FADV_DONTNEED,
                    )
                last_pos = pos

    return cleaner


def count_layers(keys):
    """Return the layer count implied by state-dict *keys*.

    A key containing ``"blocks."`` contributes the integer in its second
    dot-separated field (``blocks.<i>. ...``); every other key counts as
    layer 0. The result is the highest index plus one, or 0 when *keys* is
    empty.

    Raises:
        ValueError: If a ``"blocks."`` key has a non-integer second field.
    """
    return max(
        (
            (int(key.split(".")[1]) if "blocks." in key else 0) + 1
            for key in keys
        ),
        default=0,
    )


def main(argv=None):
    """Load the checkpoint named by the first argument and print its flags.

    Args:
        argv: Argument list without the program name; defaults to
            ``sys.argv[1:]``. Only the first entry is used.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if not args:
        sys.exit(f"usage: {sys.argv[0]} MODEL_FILE")

    stop = threading.Event()
    with open(args[0], "rb") as model_file:
        cleaner_thread = threading.Thread(
            target=file_cleaner(model_file, stop), daemon=True
        )
        cleaner_thread.start()
        try:
            # NOTE(security): torch.load unpickles the file, which can run
            # arbitrary code. Only use this on checkpoints you trust.
            w = torch.load(model_file, map_location="cpu")
        finally:
            # Stop the worker before the file is closed so it never calls
            # tell()/fileno() on a closed handle.
            stop.set()
            cleaner_thread.join()
    gc.collect()

    n_embd = w["emb.weight"].shape[1]
    n_layer = count_layers(w.keys())
    print(f"--n_layer {n_layer} --n_embd {n_embd}", end="")


if __name__ == "__main__":
    main()