add new quality metric

This commit is contained in:
YunhongLu-ZJU
2025-02-17 14:42:20 +08:00
parent 77d0f4d297
commit 991ba162bd
69 changed files with 88 additions and 1461 deletions

View File

@@ -5,7 +5,7 @@ from .open_clip import create_model_and_transforms, get_tokenizer
from .config import MODEL_PATHS
class CLIPScore:
def __init__(self, device: torch.device):
def __init__(self, device: torch.device, path: str = MODEL_PATHS):
"""Initialize the CLIPScore with a model and tokenizer.
Args:
@@ -17,7 +17,7 @@ class CLIPScore:
self.model, _, self.preprocess_val = create_model_and_transforms(
"ViT-H-14",
# "laion2B-s32B-b79K",
pretrained=MODEL_PATHS.get("open_clip"),
pretrained=path.get("open_clip"),
precision="amp",
device=device,
jit=False,