try:
    from safetensors import safe_open
except ImportError:  # safetensors is optional when only .pt/.bin checkpoints are loaded
    safe_open = None
import torch, os


class SafetensorsCompatibleTensor:
    """Wraps a tensor so it mimics the slice object returned by ``safe_open(...).get_slice``."""

    def __init__(self, tensor):
        self.tensor = tensor

    def get_shape(self):
        """Return the tensor's shape as a plain list, like a safetensors slice does."""
        return list(self.tensor.shape)


class SafetensorsCompatibleBinaryLoader:
    """Loads a torch-serialized checkpoint but exposes the ``safe_open`` file interface
    (``keys`` / ``get_tensor`` / ``get_slice``) so ``DiskMap`` can treat all files uniformly.

    Unlike a safetensors handle, the whole state dict is materialized in memory up front.
    """

    def __init__(self, path, device):
        print("Detected non-safetensors files, which may cause slower loading. It's recommended to convert it to a safetensors file.")
        # weights_only=True avoids arbitrary code execution from untrusted pickle payloads.
        self.state_dict = torch.load(path, weights_only=True, map_location=device)

    def keys(self):
        return self.state_dict.keys()

    def get_tensor(self, name):
        return self.state_dict[name]

    def get_slice(self, name):
        return SafetensorsCompatibleTensor(self.state_dict[name])


class DiskMap:
    """Dict-like, lazily-loaded view over one or more checkpoint files on disk.

    Tensors are fetched on demand via ``__getitem__``. After roughly ``buffer_size``
    tensor elements have been materialized, the safetensors handles are reopened
    (``flush_files``) to release memory held by their mmapped buffers.

    Args:
        path: A file path or list of file paths (``.safetensors`` or torch binaries).
        device: Device the tensors are loaded onto.
        torch_dtype: Optional dtype every returned tensor is cast to.
        state_dict_converter: Optional callable mapping a ``{name: name}`` dict to a
            ``{public_name: on_disk_name}`` dict, used to rename keys on access.
        buffer_size: Element-count threshold that triggers reopening the files.
            Overridden by the ``DIFFSYNTH_DISK_MAP_BUFFER_SIZE`` environment variable.
    """

    def __init__(self, path, device, torch_dtype=None, state_dict_converter=None, buffer_size=10**9):
        self.path = path if isinstance(path, list) else [path]
        self.device = device
        self.torch_dtype = torch_dtype
        # Read the environment override once instead of twice.
        env_buffer_size = os.environ.get("DIFFSYNTH_DISK_MAP_BUFFER_SIZE")
        self.buffer_size = int(env_buffer_size) if env_buffer_size is not None else buffer_size
        self.files = []
        self.flush_files()
        # Map every tensor name to the index of the file that contains it.
        self.name_map = {}
        for file_id, file in enumerate(self.files):
            for name in file.keys():
                self.name_map[name] = file_id
        self.rename_dict = self.fetch_rename_dict(state_dict_converter)

    def flush_files(self):
        """(Re)open the underlying files and reset the materialized-element counter.

        On the first call all files are opened. On later calls only safetensors
        handles are reopened (releasing their mmap memory); binary loaders keep
        their in-memory state dict, since re-reading those files would be costly.
        """
        if len(self.files) == 0:
            for path in self.path:
                if path.endswith(".safetensors"):
                    if safe_open is None:
                        raise ImportError(
                            "The 'safetensors' package is required to load .safetensors files. "
                            "Please install it with: pip install safetensors"
                        )
                    self.files.append(safe_open(path, framework="pt", device=str(self.device)))
                else:
                    self.files.append(SafetensorsCompatibleBinaryLoader(path, device=self.device))
        else:
            for i, path in enumerate(self.path):
                if path.endswith(".safetensors"):
                    self.files[i] = safe_open(path, framework="pt", device=str(self.device))
        self.num_params = 0

    def __getitem__(self, name):
        """Fetch a parameter by (possibly converted) name, casting and cloning as needed."""
        if self.rename_dict is not None:
            name = self.rename_dict[name]
        file_id = self.name_map[name]
        param = self.files[file_id].get_tensor(name)
        # Guard all tensor-specific handling: checkpoint entries are presumably
        # tensors, but the original code already tolerated non-tensors for the
        # element count, so tolerate them for the device check too.
        if isinstance(param, torch.Tensor):
            if self.torch_dtype is not None:
                param = param.to(self.torch_dtype)
            # Compare device.type rather than the device object against a string:
            # `torch.device("cpu") == "cpu"` is not reliable across torch versions.
            if param.device.type == "cpu":
                # Clone to detach the tensor from the file's mmap buffer.
                param = param.clone()
            self.num_params += param.numel()
            if self.num_params > self.buffer_size:
                self.flush_files()
        return param

    def fetch_rename_dict(self, state_dict_converter):
        """Build a mapping from public names to on-disk names, or None without a converter.

        The converter receives an identity ``{name: name}`` dict; the keys of its
        output become the public key set, its values the original on-disk names.
        """
        if state_dict_converter is None:
            return None
        state_dict = {name: name for file in self.files for name in file.keys()}
        return state_dict_converter(state_dict)

    def __iter__(self):
        keys = self.rename_dict if self.rename_dict is not None else self.name_map
        return iter(keys)

    def __contains__(self, x):
        return x in (self.rename_dict if self.rename_dict is not None else self.name_map)