add new quality metric

This commit is contained in:
YunhongLu-ZJU
2025-02-17 14:42:20 +08:00
parent 77d0f4d297
commit 991ba162bd
69 changed files with 88 additions and 1461 deletions

View File

@@ -188,15 +188,15 @@ class ImageReward(torch.nn.Module):
class ImageRewardScore:
def __init__(self, device: Union[str, torch.device]):
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS):
"""Initialize the Selector with a processor and model.
Args:
device (Union[str, torch.device]): The device to load the model on.
"""
self.device = device if isinstance(device, torch.device) else torch.device(device)
model_path = MODEL_PATHS.get("imagereward")
med_config = MODEL_PATHS.get("med_config")
model_path = path.get("imagereward")
med_config = path.get("med_config")
state_dict = load_file(model_path)
self.model = ImageReward(device=self.device, med_config=med_config).to(self.device)
self.model.load_state_dict(state_dict, strict=False)