mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
23 lines
780 B
Python
23 lines
780 B
Python
import torch
|
|
from diffsynth.models.model_loader import ModelPool
|
|
from diffsynth.core.loader import ModelConfig
|
|
|
|
|
|
def test_model_loading(model_name,
                       model_config: ModelConfig,
                       vram_limit: float = None,
                       device="cpu",
                       torch_dtype=torch.bfloat16):
    """Load a model through a fresh ``ModelPool`` and return the fetched model.

    The function calls ``model_config.download_if_necessary()``, takes the
    object returned by ``model_config.vram_config()``, overwrites its
    ``"computation_dtype"`` and ``"computation_device"`` entries with the
    caller's values, passes everything to ``ModelPool.auto_load_model`` and
    finally returns ``ModelPool.fetch_model(model_name)``.

    Args:
        model_name: Identifier handed to ``fetch_model`` to pick the model
            that is returned.
        model_config: Configuration object providing ``path``,
            ``clear_parameters``, ``vram_config()`` and
            ``download_if_necessary()``.
        vram_limit: Forwarded unchanged to ``auto_load_model``. Defaults to
            ``None`` (presumably "no limit" -- TODO confirm against
            ``ModelPool``).
        device: Value written to the ``"computation_device"`` entry.
        torch_dtype: Value written to the ``"computation_dtype"`` entry.

    Returns:
        Whatever ``ModelPool.fetch_model(model_name)`` returns.
    """
    pool = ModelPool()
    model_config.download_if_necessary()

    # The object returned by vram_config() is modified in place: the
    # caller-supplied dtype and device replace whatever it carried.
    offload_settings = model_config.vram_config()
    overrides = (
        ("computation_dtype", torch_dtype),
        ("computation_device", device),
    )
    for key, value in overrides:
        offload_settings[key] = value

    # The path is read only after download_if_necessary() has run, matching
    # the original statement order.
    pool.auto_load_model(
        model_config.path,
        vram_config=offload_settings,
        vram_limit=vram_limit,
        clear_parameters=model_config.clear_parameters,
    )

    loaded_model = pool.fetch_model(model_name)
    return loaded_model
|