add new quality metric

This commit is contained in:
YunhongLu-ZJU
2025-02-17 14:42:20 +08:00
parent 77d0f4d297
commit 991ba162bd
69 changed files with 88 additions and 1461 deletions

View File

@@ -7,7 +7,7 @@ import os
from .config import MODEL_PATHS
class HPScore_v2:
def __init__(self, device: torch.device, model_version: str = "v2"):
def __init__(self, device: torch.device, path: str = MODEL_PATHS, model_version: str = "v2"):
"""Initialize the Selector with a model and tokenizer.
Args:
@@ -17,9 +17,9 @@ class HPScore_v2:
self.device = device
if model_version == "v2":
safetensors_path = MODEL_PATHS.get("hpsv2")
safetensors_path = path.get("hpsv2")
elif model_version == "v21":
safetensors_path = MODEL_PATHS.get("hpsv2.1")
safetensors_path = path.get("hpsv2.1")
else:
raise ValueError(f"Unsupported model version: {model_version}. Choose 'v2' or 'v21'.")
@@ -27,7 +27,7 @@ class HPScore_v2:
model, _, self.preprocess_val = create_model_and_transforms(
"ViT-H-14",
# "laion2B-s32B-b79K",
pretrained=MODEL_PATHS.get("open_clip"),
pretrained=path.get("open_clip"),
precision="amp",
device=device,
jit=False,