add new quality metric

This commit is contained in:
YunhongLu-ZJU
2025-02-17 14:42:20 +08:00
parent 77d0f4d297
commit 991ba162bd
69 changed files with 88 additions and 1461 deletions

View File

@@ -50,31 +50,30 @@ class MLP(torch.nn.Module):
class AestheticScore:
def __init__(self, device: torch.device, model_path: str = MODEL_PATHS.get("aesthetic_predictor")):
def __init__(self, device: torch.device, path: str = MODEL_PATHS):
"""Initialize the Selector with a model and processor.
Args:
device (torch.device): The device to load the model on.
model_path (str): Path to the model weights file.
"""
self.device = device
self.aes_model_path = path.get("aesthetic_predictor")
# Load the MLP model
self.model = MLP(768)
try:
if model_path.endswith(".safetensors"):
state_dict = load_file(model_path)
if self.aes_model_path.endswith(".safetensors"):
state_dict = load_file(self.aes_model_path)
else:
state_dict = torch.load(model_path)
state_dict = torch.load(self.aes_model_path)
self.model.load_state_dict(state_dict)
except Exception as e:
raise ValueError(f"Error loading model weights from {model_path}: {e}")
raise ValueError(f"Error loading model weights from {self.aes_model_path}: {e}")
self.model.to(device)
self.model.eval()
# Load the CLIP model and processor
clip_model_name = MODEL_PATHS.get('clip-large')
clip_model_name = path.get('clip-large')
self.model2 = AutoModel.from_pretrained(clip_model_name).eval().to(device)
self.processor = AutoProcessor.from_pretrained(clip_model_name)