"""Helpers for memory-mapped indexed datasets (``.idx`` / ``.bin`` file pairs)."""

import os
import shutil
import struct
from functools import lru_cache
from itertools import accumulate

import numpy as np
import torch

try:
    # Not referenced in the visible part of this module (possibly an
    # accidental auto-import). ``lib2to3`` was removed in Python 3.13, so the
    # import is guarded to keep the module importable there.
    from lib2to3.pgen2 import token
except ImportError:
    token = None


def print_rank_0(*message):
    """Do nothing; rank-0 printing is currently disabled.

    Args:
        *message: Values that would be printed. They are ignored.

    Returns:
        None.
    """
    # The original implementation is kept for reference:
    # if torch.distributed.is_initialized():
    #     if torch.distributed.get_rank() == 0:
    #         print(*message, flush=True)
    # else:
    #     print(*message, flush=True)
    pass


def _warmup_mmap_file(path):
    """Do nothing; reading the file ahead of mapping is currently disabled.

    Args:
        path: Path of the file that would be read. It is ignored.

    Returns:
        None.
    """
    # The original implementation is kept for reference:
    # with open(path, "rb") as stream:
    #     while stream.read(100 * 1024 * 1024):
    #         pass
    pass


# Integer code -> element type. ``code`` performs the reverse lookup.
dtypes = {
    1: np.uint8,
    2: np.int8,
    3: np.int16,
    4: np.int32,
    5: np.int64,
    6: float,
    7: np.double,
    8: np.uint16,
}


def code(dtype):
    """Return the integer code registered for ``dtype`` in ``dtypes``.

    Entries are checked in insertion order using ``==`` and the first match
    wins.

    Args:
        dtype: The element type to look up.

    Returns:
        The integer key of the first matching entry in ``dtypes``.

    Raises:
        ValueError: If no entry in ``dtypes`` compares equal to ``dtype``.
    """
    for type_code, element_type in dtypes.items():
        if element_type == dtype:
            return type_code
    raise ValueError(dtype)


def index_file_path(prefix_path):
    """Return the index file path for ``prefix_path`` (adds ``.idx``)."""
    return prefix_path + ".idx"


def data_file_path(prefix_path):
    """Return the data file path for ``prefix_path`` (adds ``.bin``)."""
    return prefix_path + ".bin"


# NOTE(review): the source of this module is truncated partway through
# ``MMapIndexedDataset.Index.writer``. The visible fragment is preserved below
# verbatim, but commented out, because it cannot be completed without
# guessing. Restore the full class definition before relying on this module.
#
# class MMapIndexedDataset(torch.utils.data.Dataset):
#     class Index(object):
#         _HDR_MAGIC = b"MMIDIDX\x00\x00"
#
#         @classmethod
#         def writer(cls, path, dtype):
#             class _Writer(object):
#                 def __enter__(self):
#                     self._file = open(path, "wb")
#                     # Write Magic string so we can check the file format then opening it again.
#                     self._file.write(cls._HDR_MAGIC)
#                     # Write version number
#                     # Little endian unsigned 64 Bit integer
#                     self._file.write(struct.pack("