mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
Merge branch 'main' into ltx-2
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import torch, glob, os
|
||||
from typing import Optional, Union
|
||||
from typing import Optional, Union, Dict
|
||||
from dataclasses import dataclass
|
||||
from modelscope import snapshot_download
|
||||
from huggingface_hub import snapshot_download as hf_snapshot_download
|
||||
@@ -23,13 +23,14 @@ class ModelConfig:
|
||||
computation_device: Optional[Union[str, torch.device]] = None
|
||||
computation_dtype: Optional[torch.dtype] = None
|
||||
clear_parameters: bool = False
|
||||
state_dict: Dict[str, torch.Tensor] = None
|
||||
|
||||
def check_input(self):
|
||||
if self.path is None and self.model_id is None:
|
||||
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. `skip_download=True` only supports the first one.""")
|
||||
|
||||
def parse_original_file_pattern(self):
|
||||
if self.origin_file_pattern is None or self.origin_file_pattern == "":
|
||||
if self.origin_file_pattern in [None, "", "./"]:
|
||||
return "*"
|
||||
elif self.origin_file_pattern.endswith("/"):
|
||||
return self.origin_file_pattern + "*"
|
||||
@@ -98,7 +99,7 @@ class ModelConfig:
|
||||
if self.require_downloading():
|
||||
self.download()
|
||||
if self.path is None:
|
||||
if self.origin_file_pattern is None or self.origin_file_pattern == "":
|
||||
if self.origin_file_pattern in [None, "", "./"]:
|
||||
self.path = os.path.join(self.local_model_path, self.model_id)
|
||||
else:
|
||||
self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))
|
||||
|
||||
@@ -2,16 +2,25 @@ from safetensors import safe_open
|
||||
import torch, hashlib
|
||||
|
||||
|
||||
def load_state_dict(file_path, torch_dtype=None, device="cpu"):
|
||||
def load_state_dict(file_path, torch_dtype=None, device="cpu", pin_memory=False, verbose=0):
|
||||
if isinstance(file_path, list):
|
||||
state_dict = {}
|
||||
for file_path_ in file_path:
|
||||
state_dict.update(load_state_dict(file_path_, torch_dtype, device))
|
||||
return state_dict
|
||||
if file_path.endswith(".safetensors"):
|
||||
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
|
||||
state_dict.update(load_state_dict(file_path_, torch_dtype, device, pin_memory=pin_memory, verbose=verbose))
|
||||
else:
|
||||
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
|
||||
if verbose >= 1:
|
||||
print(f"Loading file [started]: {file_path}")
|
||||
if file_path.endswith(".safetensors"):
|
||||
state_dict = load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
|
||||
else:
|
||||
state_dict = load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
|
||||
# If load state dict in CPU memory, `pin_memory=True` will make `model.to("cuda")` faster.
|
||||
if pin_memory:
|
||||
for i in state_dict:
|
||||
state_dict[i] = state_dict[i].pin_memory()
|
||||
if verbose >= 1:
|
||||
print(f"Loading file [done]: {file_path}")
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
|
||||
|
||||
@@ -5,7 +5,7 @@ from .file import load_state_dict
|
||||
import torch
|
||||
|
||||
|
||||
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None):
|
||||
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None):
|
||||
config = {} if config is None else config
|
||||
# Why do we use `skip_model_initialization`?
|
||||
# It skips the random initialization of model parameters,
|
||||
@@ -20,7 +20,7 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
|
||||
dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
|
||||
dtype = [d for d in dtypes if d != "disk"][0]
|
||||
if vram_config["offload_device"] != "disk":
|
||||
state_dict = DiskMap(path, device, torch_dtype=dtype)
|
||||
if state_dict is None: state_dict = DiskMap(path, device, torch_dtype=dtype)
|
||||
if state_dict_converter is not None:
|
||||
state_dict = state_dict_converter(state_dict)
|
||||
else:
|
||||
@@ -35,7 +35,9 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
|
||||
# Sometimes a model file contains multiple models,
|
||||
# and DiskMap can load only the parameters of a single model,
|
||||
# avoiding the need to load all parameters in the file.
|
||||
if use_disk_map:
|
||||
if state_dict is not None:
|
||||
pass
|
||||
elif use_disk_map:
|
||||
state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
|
||||
else:
|
||||
state_dict = load_state_dict(path, torch_dtype, device)
|
||||
|
||||
Reference in New Issue
Block a user