mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 00:38:11 +00:00
support wan-series models
This commit is contained in:
@@ -27,7 +27,8 @@ class ModelConfig:
|
||||
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. `skip_download=True` only supports the first one.""")
|
||||
|
||||
def download(self):
|
||||
downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
|
||||
origin_file_pattern = self.origin_file_pattern + ("*" if self.origin_file_pattern.endswith("/") else "")
|
||||
downloaded_files = glob.glob(origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
|
||||
snapshot_download(
|
||||
self.model_id,
|
||||
local_dir=os.path.join(self.local_model_path, self.model_id),
|
||||
|
||||
@@ -31,6 +31,8 @@ def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
|
||||
state_dict = state_dict["state_dict"]
|
||||
elif "module" in state_dict:
|
||||
state_dict = state_dict["module"]
|
||||
elif "model_state" in state_dict:
|
||||
state_dict = state_dict["model_state"]
|
||||
if torch_dtype is not None:
|
||||
for i in state_dict:
|
||||
if isinstance(state_dict[i], torch.Tensor):
|
||||
|
||||
@@ -28,7 +28,7 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
|
||||
# and DiskMap can load only the parameters of a single model,
|
||||
# avoiding the need to load all parameters in the file.
|
||||
if use_disk_map:
|
||||
state_dict = DiskMap(path, device)
|
||||
state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
|
||||
else:
|
||||
state_dict = load_state_dict(path, torch_dtype, device)
|
||||
# Why do we use `state_dict_converter`?
|
||||
|
||||
Reference in New Issue
Block a user