add new quality metric

This commit is contained in:
YunhongLu-ZJU
2025-02-17 14:42:20 +08:00
parent 77d0f4d297
commit 991ba162bd
69 changed files with 88 additions and 1461 deletions

View File

@@ -6,15 +6,15 @@ import os
from .config import MODEL_PATHS
class PickScore:
def __init__(self, device: Union[str, torch.device]):
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS):
"""Initialize the Selector with a processor and model.
Args:
device (Union[str, torch.device]): The device to load the model on.
"""
self.device = device if isinstance(device, torch.device) else torch.device(device)
processor_name_or_path = MODEL_PATHS.get("clip")
model_pretrained_name_or_path = MODEL_PATHS.get("pickscore")
processor_name_or_path = path.get("clip")
model_pretrained_name_or_path = path.get("pickscore")
self.processor = AutoProcessor.from_pretrained(processor_name_or_path)
self.model = AutoModel.from_pretrained(model_pretrained_name_or_path).eval().to(self.device)