Compare commits

...

3 Commits

Author SHA1 Message Date
Artiprocher
53fe42af1b update version 2026-01-30 13:49:27 +08:00
Artiprocher
ee9a3b4405 support loading models from state dict 2026-01-30 13:47:36 +08:00
Zhongjie Duan
22695e9be0 Merge pull request #1233 from modelscope/z-image-release
Z-Image and Z-Image-i2L
2026-01-27 18:41:28 +08:00
6 changed files with 28 additions and 14 deletions

View File

@@ -1,5 +1,5 @@
import torch, glob, os import torch, glob, os
from typing import Optional, Union from typing import Optional, Union, Dict
from dataclasses import dataclass from dataclasses import dataclass
from modelscope import snapshot_download from modelscope import snapshot_download
from huggingface_hub import snapshot_download as hf_snapshot_download from huggingface_hub import snapshot_download as hf_snapshot_download
@@ -23,6 +23,7 @@ class ModelConfig:
computation_device: Optional[Union[str, torch.device]] = None computation_device: Optional[Union[str, torch.device]] = None
computation_dtype: Optional[torch.dtype] = None computation_dtype: Optional[torch.dtype] = None
clear_parameters: bool = False clear_parameters: bool = False
state_dict: Dict[str, torch.Tensor] = None
def check_input(self): def check_input(self):
if self.path is None and self.model_id is None: if self.path is None and self.model_id is None:

View File

@@ -2,16 +2,25 @@ from safetensors import safe_open
import torch, hashlib import torch, hashlib
def load_state_dict(file_path, torch_dtype=None, device="cpu"): def load_state_dict(file_path, torch_dtype=None, device="cpu", pin_memory=False, verbose=0):
if isinstance(file_path, list): if isinstance(file_path, list):
state_dict = {} state_dict = {}
for file_path_ in file_path: for file_path_ in file_path:
state_dict.update(load_state_dict(file_path_, torch_dtype, device)) state_dict.update(load_state_dict(file_path_, torch_dtype, device, pin_memory=pin_memory, verbose=verbose))
return state_dict
if file_path.endswith(".safetensors"):
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
else: else:
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device) if verbose >= 1:
print(f"Loading file [started]: {file_path}")
if file_path.endswith(".safetensors"):
state_dict = load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
else:
state_dict = load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
# When the state dict is loaded into CPU memory, `pin_memory=True` makes a later `model.to("cuda")` faster.
if pin_memory:
for i in state_dict:
state_dict[i] = state_dict[i].pin_memory()
if verbose >= 1:
print(f"Loading file [done]: {file_path}")
return state_dict
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"): def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):

View File

@@ -5,7 +5,7 @@ from .file import load_state_dict
import torch import torch
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None): def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None):
config = {} if config is None else config config = {} if config is None else config
# Why do we use `skip_model_initialization`? # Why do we use `skip_model_initialization`?
# It skips the random initialization of model parameters, # It skips the random initialization of model parameters,
@@ -20,7 +20,7 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]] dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
dtype = [d for d in dtypes if d != "disk"][0] dtype = [d for d in dtypes if d != "disk"][0]
if vram_config["offload_device"] != "disk": if vram_config["offload_device"] != "disk":
state_dict = DiskMap(path, device, torch_dtype=dtype) if state_dict is None: state_dict = DiskMap(path, device, torch_dtype=dtype)
if state_dict_converter is not None: if state_dict_converter is not None:
state_dict = state_dict_converter(state_dict) state_dict = state_dict_converter(state_dict)
else: else:
@@ -35,7 +35,9 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
# Sometimes a model file contains multiple models, # Sometimes a model file contains multiple models,
# and DiskMap can load only the parameters of a single model, # and DiskMap can load only the parameters of a single model,
# avoiding the need to load all parameters in the file. # avoiding the need to load all parameters in the file.
if use_disk_map: if state_dict is not None:
pass
elif use_disk_map:
state_dict = DiskMap(path, device, torch_dtype=torch_dtype) state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
else: else:
state_dict = load_state_dict(path, torch_dtype, device) state_dict = load_state_dict(path, torch_dtype, device)

View File

@@ -296,6 +296,7 @@ class BasePipeline(torch.nn.Module):
vram_config=vram_config, vram_config=vram_config,
vram_limit=vram_limit, vram_limit=vram_limit,
clear_parameters=model_config.clear_parameters, clear_parameters=model_config.clear_parameters,
state_dict=model_config.state_dict,
) )
return model_pool return model_pool

View File

@@ -29,7 +29,7 @@ class ModelPool:
module_map = None module_map = None
return module_map return module_map
def load_model_file(self, config, path, vram_config, vram_limit=None): def load_model_file(self, config, path, vram_config, vram_limit=None, state_dict=None):
model_class = self.import_model_class(config["model_class"]) model_class = self.import_model_class(config["model_class"])
model_config = config.get("extra_kwargs", {}) model_config = config.get("extra_kwargs", {})
if "state_dict_converter" in config: if "state_dict_converter" in config:
@@ -43,6 +43,7 @@ class ModelPool:
state_dict_converter, state_dict_converter,
use_disk_map=True, use_disk_map=True,
vram_config=vram_config, module_map=module_map, vram_limit=vram_limit, vram_config=vram_config, module_map=module_map, vram_limit=vram_limit,
state_dict=state_dict,
) )
return model return model
@@ -59,7 +60,7 @@ class ModelPool:
} }
return vram_config return vram_config
def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False): def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False, state_dict=None):
print(f"Loading models from: {json.dumps(path, indent=4)}") print(f"Loading models from: {json.dumps(path, indent=4)}")
if vram_config is None: if vram_config is None:
vram_config = self.default_vram_config() vram_config = self.default_vram_config()
@@ -67,7 +68,7 @@ class ModelPool:
loaded = False loaded = False
for config in MODEL_CONFIGS: for config in MODEL_CONFIGS:
if config["model_hash"] == model_hash: if config["model_hash"] == model_hash:
model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit) model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit, state_dict=state_dict)
if clear_parameters: self.clear_parameters(model) if clear_parameters: self.clear_parameters(model)
self.model.append(model) self.model.append(model)
model_name = config["model_name"] model_name = config["model_name"]

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "diffsynth" name = "diffsynth"
version = "2.0.3" version = "2.0.4"
description = "Enjoy the magic of Diffusion models!" description = "Enjoy the magic of Diffusion models!"
authors = [{name = "ModelScope Team"}] authors = [{name = "ModelScope Team"}]
license = {text = "Apache-2.0"} license = {text = "Apache-2.0"}