mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
wan-series
This commit is contained in:
@@ -59,7 +59,7 @@ class ModelPool:
|
||||
}
|
||||
return vram_config
|
||||
|
||||
def auto_load_model(self, path, vram_config=None, vram_limit=None):
|
||||
def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False):
|
||||
print(f"Loading models from: {json.dumps(path, indent=4)}")
|
||||
if vram_config is None:
|
||||
vram_config = self.default_vram_config()
|
||||
@@ -68,6 +68,7 @@ class ModelPool:
|
||||
for config in MODEL_CONFIGS:
|
||||
if config["model_hash"] == model_hash:
|
||||
model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit)
|
||||
if clear_parameters: self.clear_parameters(model)
|
||||
self.model.append(model)
|
||||
model_name = config["model_name"]
|
||||
self.model_name.append(model_name)
|
||||
@@ -102,3 +103,9 @@ class ModelPool:
|
||||
model = fetched_models
|
||||
print(f"More than one {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths, indent=4)}.")
|
||||
return model
|
||||
|
||||
def clear_parameters(self, model: torch.nn.Module):
|
||||
for name, module in model.named_children():
|
||||
self.clear_parameters(module)
|
||||
for name, param in model.named_parameters(recurse=False):
|
||||
setattr(model, name, None)
|
||||
|
||||
Reference in New Issue
Block a user