mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
DiffSynth-Studio 2.0 major update
This commit is contained in:
93
diffsynth/core/vram/disk_map.py
Normal file
93
diffsynth/core/vram/disk_map.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from safetensors import safe_open
|
||||
import torch, os
|
||||
|
||||
|
||||
class SafetensorsCompatibleTensor:
|
||||
def __init__(self, tensor):
|
||||
self.tensor = tensor
|
||||
|
||||
def get_shape(self):
|
||||
return list(self.tensor.shape)
|
||||
|
||||
|
||||
class SafetensorsCompatibleBinaryLoader:
|
||||
def __init__(self, path, device):
|
||||
print("Detected non-safetensors files, which may cause slower loading. It's recommended to convert it to a safetensors file.")
|
||||
self.state_dict = torch.load(path, weights_only=True, map_location=device)
|
||||
|
||||
def keys(self):
|
||||
return self.state_dict.keys()
|
||||
|
||||
def get_tensor(self, name):
|
||||
return self.state_dict[name]
|
||||
|
||||
def get_slice(self, name):
|
||||
return SafetensorsCompatibleTensor(self.state_dict[name])
|
||||
|
||||
|
||||
class DiskMap:
|
||||
|
||||
def __init__(self, path, device, torch_dtype=None, state_dict_converter=None, buffer_size=10**9):
|
||||
self.path = path if isinstance(path, list) else [path]
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
if os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE') is not None:
|
||||
self.buffer_size = int(os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE'))
|
||||
else:
|
||||
self.buffer_size = buffer_size
|
||||
self.files = []
|
||||
self.flush_files()
|
||||
self.name_map = {}
|
||||
for file_id, file in enumerate(self.files):
|
||||
for name in file.keys():
|
||||
self.name_map[name] = file_id
|
||||
self.rename_dict = self.fetch_rename_dict(state_dict_converter)
|
||||
|
||||
def flush_files(self):
|
||||
if len(self.files) == 0:
|
||||
for path in self.path:
|
||||
if path.endswith(".safetensors"):
|
||||
self.files.append(safe_open(path, framework="pt", device=str(self.device)))
|
||||
else:
|
||||
self.files.append(SafetensorsCompatibleBinaryLoader(path, device=self.device))
|
||||
else:
|
||||
for i, path in enumerate(self.path):
|
||||
if path.endswith(".safetensors"):
|
||||
self.files[i] = safe_open(path, framework="pt", device=str(self.device))
|
||||
self.num_params = 0
|
||||
|
||||
def __getitem__(self, name):
|
||||
if self.rename_dict is not None: name = self.rename_dict[name]
|
||||
file_id = self.name_map[name]
|
||||
param = self.files[file_id].get_tensor(name)
|
||||
if self.torch_dtype is not None and isinstance(param, torch.Tensor):
|
||||
param = param.to(self.torch_dtype)
|
||||
if isinstance(param, torch.Tensor) and param.device == "cpu":
|
||||
param = param.clone()
|
||||
if isinstance(param, torch.Tensor):
|
||||
self.num_params += param.numel()
|
||||
if self.num_params > self.buffer_size:
|
||||
self.flush_files()
|
||||
return param
|
||||
|
||||
def fetch_rename_dict(self, state_dict_converter):
|
||||
if state_dict_converter is None:
|
||||
return None
|
||||
state_dict = {}
|
||||
for file in self.files:
|
||||
for name in file.keys():
|
||||
state_dict[name] = name
|
||||
state_dict = state_dict_converter(state_dict)
|
||||
return state_dict
|
||||
|
||||
def __iter__(self):
|
||||
if self.rename_dict is not None:
|
||||
return self.rename_dict.__iter__()
|
||||
else:
|
||||
return self.name_map.__iter__()
|
||||
|
||||
def __contains__(self, x):
|
||||
if self.rename_dict is not None:
|
||||
return x in self.rename_dict
|
||||
else:
|
||||
return x in self.name_map
|
||||
Reference in New Issue
Block a user