mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
DiffSynth-Studio 2.0 major update
This commit is contained in:
3
diffsynth/core/loader/__init__.py
Normal file
3
diffsynth/core/loader/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .file import load_state_dict, hash_state_dict_keys, hash_model_file
|
||||
from .model import load_model, load_model_with_disk_offload
|
||||
from .config import ModelConfig
|
||||
117
diffsynth/core/loader/config.py
Normal file
117
diffsynth/core/loader/config.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import torch, glob, os
|
||||
from typing import Optional, Union
|
||||
from dataclasses import dataclass
|
||||
from modelscope import snapshot_download
|
||||
from huggingface_hub import snapshot_download as hf_snapshot_download
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
path: Union[str, list[str]] = None
|
||||
model_id: str = None
|
||||
origin_file_pattern: Union[str, list[str]] = None
|
||||
download_source: str = None
|
||||
local_model_path: str = None
|
||||
skip_download: bool = None
|
||||
offload_device: Optional[Union[str, torch.device]] = None
|
||||
offload_dtype: Optional[torch.dtype] = None
|
||||
onload_device: Optional[Union[str, torch.device]] = None
|
||||
onload_dtype: Optional[torch.dtype] = None
|
||||
preparing_device: Optional[Union[str, torch.device]] = None
|
||||
preparing_dtype: Optional[torch.dtype] = None
|
||||
computation_device: Optional[Union[str, torch.device]] = None
|
||||
computation_dtype: Optional[torch.dtype] = None
|
||||
clear_parameters: bool = False
|
||||
|
||||
def check_input(self):
|
||||
if self.path is None and self.model_id is None:
|
||||
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. `skip_download=True` only supports the first one.""")
|
||||
|
||||
def parse_original_file_pattern(self):
|
||||
if self.origin_file_pattern is None or self.origin_file_pattern == "":
|
||||
return "*"
|
||||
elif self.origin_file_pattern.endswith("/"):
|
||||
return self.origin_file_pattern + "*"
|
||||
else:
|
||||
return self.origin_file_pattern
|
||||
|
||||
def parse_download_source(self):
|
||||
if self.download_source is None:
|
||||
if os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE') is not None:
|
||||
return os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE')
|
||||
else:
|
||||
return "modelscope"
|
||||
else:
|
||||
return self.download_source
|
||||
|
||||
def parse_skip_download(self):
|
||||
if self.skip_download is None:
|
||||
if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD') is not None:
|
||||
if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "true":
|
||||
return True
|
||||
elif os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "false":
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
else:
|
||||
return self.skip_download
|
||||
|
||||
def download(self):
|
||||
origin_file_pattern = self.parse_original_file_pattern()
|
||||
downloaded_files = glob.glob(origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
|
||||
download_source = self.parse_download_source()
|
||||
if download_source.lower() == "modelscope":
|
||||
snapshot_download(
|
||||
self.model_id,
|
||||
local_dir=os.path.join(self.local_model_path, self.model_id),
|
||||
allow_file_pattern=origin_file_pattern,
|
||||
ignore_file_pattern=downloaded_files,
|
||||
local_files_only=False
|
||||
)
|
||||
elif download_source.lower() == "huggingface":
|
||||
hf_snapshot_download(
|
||||
self.model_id,
|
||||
local_dir=os.path.join(self.local_model_path, self.model_id),
|
||||
allow_patterns=origin_file_pattern,
|
||||
ignore_patterns=downloaded_files,
|
||||
local_files_only=False
|
||||
)
|
||||
else:
|
||||
raise ValueError("`download_source` should be `modelscope` or `huggingface`.")
|
||||
|
||||
def require_downloading(self):
|
||||
if self.path is not None:
|
||||
return False
|
||||
skip_download = self.parse_skip_download()
|
||||
return not skip_download
|
||||
|
||||
def reset_local_model_path(self):
|
||||
if os.environ.get('DIFFSYNTH_MODEL_BASE_PATH') is not None:
|
||||
self.local_model_path = os.environ.get('DIFFSYNTH_MODEL_BASE_PATH')
|
||||
elif self.local_model_path is None:
|
||||
self.local_model_path = "./models"
|
||||
|
||||
def download_if_necessary(self):
|
||||
self.check_input()
|
||||
self.reset_local_model_path()
|
||||
if self.require_downloading():
|
||||
self.download()
|
||||
if self.origin_file_pattern is None or self.origin_file_pattern == "":
|
||||
self.path = os.path.join(self.local_model_path, self.model_id)
|
||||
else:
|
||||
self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))
|
||||
if isinstance(self.path, list) and len(self.path) == 1:
|
||||
self.path = self.path[0]
|
||||
|
||||
def vram_config(self):
|
||||
return {
|
||||
"offload_device": self.offload_device,
|
||||
"offload_dtype": self.offload_dtype,
|
||||
"onload_device": self.onload_device,
|
||||
"onload_dtype": self.onload_dtype,
|
||||
"preparing_device": self.preparing_device,
|
||||
"preparing_dtype": self.preparing_dtype,
|
||||
"computation_device": self.computation_device,
|
||||
"computation_dtype": self.computation_dtype,
|
||||
}
|
||||
121
diffsynth/core/loader/file.py
Normal file
121
diffsynth/core/loader/file.py
Normal file
@@ -0,0 +1,121 @@
|
||||
from safetensors import safe_open
|
||||
import torch, hashlib
|
||||
|
||||
|
||||
def load_state_dict(file_path, torch_dtype=None, device="cpu"):
|
||||
if isinstance(file_path, list):
|
||||
state_dict = {}
|
||||
for file_path_ in file_path:
|
||||
state_dict.update(load_state_dict(file_path_, torch_dtype, device))
|
||||
return state_dict
|
||||
if file_path.endswith(".safetensors"):
|
||||
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
|
||||
else:
|
||||
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
|
||||
|
||||
|
||||
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
|
||||
state_dict = {}
|
||||
with safe_open(file_path, framework="pt", device=str(device)) as f:
|
||||
for k in f.keys():
|
||||
state_dict[k] = f.get_tensor(k)
|
||||
if torch_dtype is not None:
|
||||
state_dict[k] = state_dict[k].to(torch_dtype)
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
|
||||
state_dict = torch.load(file_path, map_location=device, weights_only=True)
|
||||
if len(state_dict) == 1:
|
||||
if "state_dict" in state_dict:
|
||||
state_dict = state_dict["state_dict"]
|
||||
elif "module" in state_dict:
|
||||
state_dict = state_dict["module"]
|
||||
elif "model_state" in state_dict:
|
||||
state_dict = state_dict["model_state"]
|
||||
if torch_dtype is not None:
|
||||
for i in state_dict:
|
||||
if isinstance(state_dict[i], torch.Tensor):
|
||||
state_dict[i] = state_dict[i].to(torch_dtype)
|
||||
return state_dict
|
||||
|
||||
|
||||
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
|
||||
keys = []
|
||||
for key, value in state_dict.items():
|
||||
if isinstance(key, str):
|
||||
if isinstance(value, torch.Tensor):
|
||||
if with_shape:
|
||||
shape = "_".join(map(str, list(value.shape)))
|
||||
keys.append(key + ":" + shape)
|
||||
keys.append(key)
|
||||
elif isinstance(value, dict):
|
||||
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
|
||||
keys.sort()
|
||||
keys_str = ",".join(keys)
|
||||
return keys_str
|
||||
|
||||
|
||||
def hash_state_dict_keys(state_dict, with_shape=True):
|
||||
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
|
||||
keys_str = keys_str.encode(encoding="UTF-8")
|
||||
return hashlib.md5(keys_str).hexdigest()
|
||||
|
||||
|
||||
def load_keys_dict(file_path):
|
||||
if isinstance(file_path, list):
|
||||
state_dict = {}
|
||||
for file_path_ in file_path:
|
||||
state_dict.update(load_keys_dict(file_path_))
|
||||
return state_dict
|
||||
if file_path.endswith(".safetensors"):
|
||||
return load_keys_dict_from_safetensors(file_path)
|
||||
else:
|
||||
return load_keys_dict_from_bin(file_path)
|
||||
|
||||
|
||||
def load_keys_dict_from_safetensors(file_path):
|
||||
keys_dict = {}
|
||||
with safe_open(file_path, framework="pt", device="cpu") as f:
|
||||
for k in f.keys():
|
||||
keys_dict[k] = f.get_slice(k).get_shape()
|
||||
return keys_dict
|
||||
|
||||
|
||||
def convert_state_dict_to_keys_dict(state_dict):
|
||||
keys_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
keys_dict[k] = list(v.shape)
|
||||
else:
|
||||
keys_dict[k] = convert_state_dict_to_keys_dict(v)
|
||||
return keys_dict
|
||||
|
||||
|
||||
def load_keys_dict_from_bin(file_path):
|
||||
state_dict = load_state_dict_from_bin(file_path)
|
||||
keys_dict = convert_state_dict_to_keys_dict(state_dict)
|
||||
return keys_dict
|
||||
|
||||
|
||||
def convert_keys_dict_to_single_str(state_dict, with_shape=True):
|
||||
keys = []
|
||||
for key, value in state_dict.items():
|
||||
if isinstance(key, str):
|
||||
if isinstance(value, dict):
|
||||
keys.append(key + "|" + convert_keys_dict_to_single_str(value, with_shape=with_shape))
|
||||
else:
|
||||
if with_shape:
|
||||
shape = "_".join(map(str, list(value)))
|
||||
keys.append(key + ":" + shape)
|
||||
keys.append(key)
|
||||
keys.sort()
|
||||
keys_str = ",".join(keys)
|
||||
return keys_str
|
||||
|
||||
|
||||
def hash_model_file(path, with_shape=True):
|
||||
keys_dict = load_keys_dict(path)
|
||||
keys_str = convert_keys_dict_to_single_str(keys_dict, with_shape=with_shape)
|
||||
keys_str = keys_str.encode(encoding="UTF-8")
|
||||
return hashlib.md5(keys_str).hexdigest()
|
||||
79
diffsynth/core/loader/model.py
Normal file
79
diffsynth/core/loader/model.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from ..vram.initialization import skip_model_initialization
|
||||
from ..vram.disk_map import DiskMap
|
||||
from ..vram.layers import enable_vram_management
|
||||
from .file import load_state_dict
|
||||
import torch
|
||||
|
||||
|
||||
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None):
|
||||
config = {} if config is None else config
|
||||
# Why do we use `skip_model_initialization`?
|
||||
# It skips the random initialization of model parameters,
|
||||
# thereby speeding up model loading and avoiding excessive memory usage.
|
||||
with skip_model_initialization():
|
||||
model = model_class(**config)
|
||||
# What is `module_map`?
|
||||
# This is a module mapping table for VRAM management.
|
||||
if module_map is not None:
|
||||
devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"], vram_config["computation_device"]]
|
||||
device = [d for d in devices if d != "disk"][0]
|
||||
dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
|
||||
dtype = [d for d in dtypes if d != "disk"][0]
|
||||
if vram_config["offload_device"] != "disk":
|
||||
state_dict = DiskMap(path, device, torch_dtype=dtype)
|
||||
if state_dict_converter is not None:
|
||||
state_dict = state_dict_converter(state_dict)
|
||||
else:
|
||||
state_dict = {i: state_dict[i] for i in state_dict}
|
||||
model.load_state_dict(state_dict, assign=True)
|
||||
model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None, vram_limit=vram_limit)
|
||||
else:
|
||||
disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
|
||||
model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=vram_limit)
|
||||
else:
|
||||
# Why do we use `DiskMap`?
|
||||
# Sometimes a model file contains multiple models,
|
||||
# and DiskMap can load only the parameters of a single model,
|
||||
# avoiding the need to load all parameters in the file.
|
||||
if use_disk_map:
|
||||
state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
|
||||
else:
|
||||
state_dict = load_state_dict(path, torch_dtype, device)
|
||||
# Why do we use `state_dict_converter`?
|
||||
# Some models are saved in complex formats,
|
||||
# and we need to convert the state dict into the appropriate format.
|
||||
if state_dict_converter is not None:
|
||||
state_dict = state_dict_converter(state_dict)
|
||||
else:
|
||||
state_dict = {i: state_dict[i] for i in state_dict}
|
||||
model.load_state_dict(state_dict, assign=True)
|
||||
# Why do we call `to()`?
|
||||
# Because some models override the behavior of `to()`,
|
||||
# especially those from libraries like Transformers.
|
||||
model = model.to(dtype=torch_dtype, device=device)
|
||||
if hasattr(model, "eval"):
|
||||
model = model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, module_map=None):
|
||||
if isinstance(path, str):
|
||||
path = [path]
|
||||
config = {} if config is None else config
|
||||
with skip_model_initialization():
|
||||
model = model_class(**config)
|
||||
if hasattr(model, "eval"):
|
||||
model = model.eval()
|
||||
disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
|
||||
vram_config = {
|
||||
"offload_dtype": "disk",
|
||||
"offload_device": "disk",
|
||||
"onload_dtype": "disk",
|
||||
"onload_device": "disk",
|
||||
"preparing_dtype": torch.float8_e4m3fn,
|
||||
"preparing_device": device,
|
||||
"computation_dtype": torch_dtype,
|
||||
"computation_device": device,
|
||||
}
|
||||
enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=80)
|
||||
return model
|
||||
Reference in New Issue
Block a user