# Mirror of https://github.com/modelscope/DiffSynth-Studio.git
# (synced 2026-03-19; 91 lines, 3.1 KiB, Python)
from safetensors import safe_open
import torch, os
class SafetensorsCompatibleTensor:
    """Adapter giving a plain in-memory tensor the slice-object interface of
    safetensors.

    ``safe_open(...).get_slice(name)`` yields an object exposing
    ``get_shape()``; wrapping a tensor in this class lets it answer the same
    call, so safetensors and pickle-backed files can be handled uniformly.
    """

    def __init__(self, tensor):
        # Wrapped tensor; only its `.shape` is inspected by this adapter.
        self.tensor = tensor

    def get_shape(self):
        """Return the tensor's shape as a plain list, like a safetensors slice."""
        shape = self.tensor.shape
        return [int(dim) for dim in shape]
class SafetensorsCompatibleBinaryLoader:
    """Expose a pickle-based checkpoint (e.g. ``.bin``/``.pth``) through the
    same read API that ``safetensors.safe_open`` provides: ``keys()``,
    ``get_tensor()`` and ``get_slice()``.

    Unlike a memory-mapped safetensors file, the whole checkpoint is
    deserialized into memory up front, hence the warning on construction.
    """

    def __init__(self, path, device):
        print("Detected non-safetensors files, which may cause slower loading. It's recommended to convert it to a safetensors file.")
        # weights_only=True restricts unpickling to tensor data (safer than
        # arbitrary pickle); tensors land directly on the requested device.
        self.state_dict = torch.load(path, weights_only=True, map_location=device)

    def keys(self):
        """Return the tensor names stored in the checkpoint."""
        return self.state_dict.keys()

    def get_tensor(self, name):
        """Fetch the tensor stored under ``name``."""
        return self.state_dict[name]

    def get_slice(self, name):
        """Return a shape-queryable wrapper, mimicking a safetensors slice."""
        return SafetensorsCompatibleTensor(self.state_dict[name])
class DiskMap:
    """Lazy, dict-like view over one or more checkpoint files on disk.

    Tensors are read on demand via ``__getitem__`` rather than loaded all at
    once. After roughly ``buffer_size`` parameters have been served, the
    safetensors handles are reopened (``flush_files``) so their memory-mapped
    buffers can be released, keeping peak memory bounded.
    """

    def __init__(self, path, device, torch_dtype=None, state_dict_converter=None, buffer_size=10**9):
        """
        Args:
            path: A checkpoint path or list of paths. ``.safetensors`` files
                are opened with ``safe_open`` (memory-mapped); anything else
                goes through ``torch.load`` via the binary-loader shim.
            device: Device that fetched tensors are materialized on.
            torch_dtype: If not None, every fetched tensor is cast to this dtype.
            state_dict_converter: Optional callable that renames keys; it is
                fed ``{name: name}`` and must return ``{new_name: old_name}``.
            buffer_size: Parameter count after which file handles are
                reopened. The DIFFSYNTH_DISK_MAP_BUFFER_SIZE environment
                variable, when set, overrides this argument.
        """
        self.path = path if isinstance(path, list) else [path]
        self.device = device
        self.torch_dtype = torch_dtype
        # Single environment lookup (the original read the variable twice).
        env_buffer_size = os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE')
        if env_buffer_size is not None:
            self.buffer_size = int(env_buffer_size)
        else:
            self.buffer_size = buffer_size
        self.files = []
        self.flush_files()
        # Map each tensor name to the index of the file that contains it.
        self.name_map = {}
        for file_id, file in enumerate(self.files):
            for name in file.keys():
                self.name_map[name] = file_id
        self.rename_dict = self.fetch_rename_dict(state_dict_converter)

    def flush_files(self):
        """(Re)open the backing files and reset the served-parameter counter.

        On the first call the handle list is built from scratch. On later
        calls only ``.safetensors`` handles are reopened (releasing their
        memory maps); pickle-backed loaders are kept as-is, since reloading
        them would mean re-deserializing the whole checkpoint.
        """
        if len(self.files) == 0:
            for path in self.path:
                if path.endswith(".safetensors"):
                    self.files.append(safe_open(path, framework="pt", device=str(self.device)))
                else:
                    self.files.append(SafetensorsCompatibleBinaryLoader(path, device=self.device))
        else:
            for i, path in enumerate(self.path):
                if path.endswith(".safetensors"):
                    self.files[i] = safe_open(path, framework="pt", device=str(self.device))
        self.num_params = 0

    def __getitem__(self, name):
        """Read tensor ``name`` from disk, applying renaming and dtype cast."""
        # Translate a converted key back to its on-disk name if a converter
        # was supplied.
        if self.rename_dict is not None: name = self.rename_dict[name]
        file_id = self.name_map[name]
        param = self.files[file_id].get_tensor(name)
        if self.torch_dtype is not None:
            param = param.to(self.torch_dtype)
        # Track how many parameters have been served; once past buffer_size,
        # reopen the files so memory-mapped buffers can be released.
        self.num_params += param.numel()
        if self.num_params > self.buffer_size:
            self.flush_files()
        return param

    def fetch_rename_dict(self, state_dict_converter):
        """Build ``{converted_name: original_name}``, or None when no converter.

        The converter is fed an identity mapping ``{name: name}``; because it
        only renames keys, the values keep pointing at the on-disk names.
        """
        if state_dict_converter is None:
            return None
        state_dict = {name: name for file in self.files for name in file.keys()}
        return state_dict_converter(state_dict)

    def __iter__(self):
        """Iterate over the (possibly converted) tensor names."""
        if self.rename_dict is not None:
            return self.rename_dict.__iter__()
        else:
            return self.name_map.__iter__()

    def __contains__(self, x):
        """Membership test against the (possibly converted) tensor names."""
        if self.rename_dict is not None:
            return x in self.rename_dict
        else:
            return x in self.name_map