mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 09:28:12 +00:00
add new quality metric
This commit is contained in:
@@ -7,7 +7,7 @@ import os
|
||||
from .config import MODEL_PATHS
|
||||
|
||||
class HPScore_v2:
|
||||
def __init__(self, device: torch.device, model_version: str = "v2"):
|
||||
def __init__(self, device: torch.device, path: str = MODEL_PATHS, model_version: str = "v2"):
|
||||
"""Initialize the Selector with a model and tokenizer.
|
||||
|
||||
Args:
|
||||
@@ -17,9 +17,9 @@ class HPScore_v2:
|
||||
self.device = device
|
||||
|
||||
if model_version == "v2":
|
||||
safetensors_path = MODEL_PATHS.get("hpsv2")
|
||||
safetensors_path = path.get("hpsv2")
|
||||
elif model_version == "v21":
|
||||
safetensors_path = MODEL_PATHS.get("hpsv2.1")
|
||||
safetensors_path = path.get("hpsv2.1")
|
||||
else:
|
||||
raise ValueError(f"Unsupported model version: {model_version}. Choose 'v2' or 'v21'.")
|
||||
|
||||
@@ -27,7 +27,7 @@ class HPScore_v2:
|
||||
model, _, self.preprocess_val = create_model_and_transforms(
|
||||
"ViT-H-14",
|
||||
# "laion2B-s32B-b79K",
|
||||
pretrained=MODEL_PATHS.get("open_clip"),
|
||||
pretrained=path.get("open_clip"),
|
||||
precision="amp",
|
||||
device=device,
|
||||
jit=False,
|
||||
|
||||
Reference in New Issue
Block a user