support wan-series models

This commit is contained in:
Artiprocher
2025-11-13 17:30:19 +08:00
parent cb70126c88
commit 5be5c32fe4
64 changed files with 9915 additions and 24 deletions

View File

@@ -28,7 +28,7 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
# and DiskMap can load only the parameters of a single model,
# avoiding the need to load all parameters in the file.
if use_disk_map:
state_dict = DiskMap(path, device)
state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
else:
state_dict = load_state_dict(path, torch_dtype, device)
# Why do we use `state_dict_converter`?