mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 17:38:10 +00:00
add new quality metric
This commit is contained in:
@@ -188,15 +188,15 @@ class ImageReward(torch.nn.Module):
|
||||
|
||||
|
||||
class ImageRewardScore:
|
||||
def __init__(self, device: Union[str, torch.device]):
|
||||
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS):
|
||||
"""Initialize the Selector with a processor and model.
|
||||
|
||||
Args:
|
||||
device (Union[str, torch.device]): The device to load the model on.
|
||||
"""
|
||||
self.device = device if isinstance(device, torch.device) else torch.device(device)
|
||||
model_path = MODEL_PATHS.get("imagereward")
|
||||
med_config = MODEL_PATHS.get("med_config")
|
||||
model_path = path.get("imagereward")
|
||||
med_config = path.get("med_config")
|
||||
state_dict = load_file(model_path)
|
||||
self.model = ImageReward(device=self.device, med_config=med_config).to(self.device)
|
||||
self.model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
Reference in New Issue
Block a user