support wan-series models

This commit is contained in:
Artiprocher
2025-11-13 17:30:19 +08:00
parent cb70126c88
commit 5be5c32fe4
64 changed files with 9915 additions and 24 deletions

View File

@@ -27,7 +27,8 @@ class ModelConfig:
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. `skip_download=True` only supports the first one.""")
def download(self):
downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
origin_file_pattern = self.origin_file_pattern + ("*" if self.origin_file_pattern.endswith("/") else "")
downloaded_files = glob.glob(origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
snapshot_download(
self.model_id,
local_dir=os.path.join(self.local_model_path, self.model_id),

View File

@@ -31,6 +31,8 @@ def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
state_dict = state_dict["state_dict"]
elif "module" in state_dict:
state_dict = state_dict["module"]
elif "model_state" in state_dict:
state_dict = state_dict["model_state"]
if torch_dtype is not None:
for i in state_dict:
if isinstance(state_dict[i], torch.Tensor):

View File

@@ -28,7 +28,7 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
# and DiskMap can load only the parameters of a single model,
# avoiding the need to load all parameters in the file.
if use_disk_map:
state_dict = DiskMap(path, device)
state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
else:
state_dict = load_state_dict(path, torch_dtype, device)
# Why do we use `state_dict_converter`?