mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 09:28:12 +00:00
add new quality metric
This commit is contained in:
@@ -50,31 +50,30 @@ class MLP(torch.nn.Module):
|
||||
|
||||
|
||||
class AestheticScore:
|
||||
def __init__(self, device: torch.device, model_path: str = MODEL_PATHS.get("aesthetic_predictor")):
|
||||
def __init__(self, device: torch.device, path: str = MODEL_PATHS):
|
||||
"""Initialize the Selector with a model and processor.
|
||||
|
||||
Args:
|
||||
device (torch.device): The device to load the model on.
|
||||
model_path (str): Path to the model weights file.
|
||||
"""
|
||||
self.device = device
|
||||
|
||||
self.aes_model_path = path.get("aesthetic_predictor")
|
||||
# Load the MLP model
|
||||
self.model = MLP(768)
|
||||
try:
|
||||
if model_path.endswith(".safetensors"):
|
||||
state_dict = load_file(model_path)
|
||||
if self.aes_model_path.endswith(".safetensors"):
|
||||
state_dict = load_file(self.aes_model_path)
|
||||
else:
|
||||
state_dict = torch.load(model_path)
|
||||
state_dict = torch.load(self.aes_model_path)
|
||||
self.model.load_state_dict(state_dict)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error loading model weights from {model_path}: {e}")
|
||||
raise ValueError(f"Error loading model weights from {self.aes_model_path}: {e}")
|
||||
|
||||
self.model.to(device)
|
||||
self.model.eval()
|
||||
|
||||
# Load the CLIP model and processor
|
||||
clip_model_name = MODEL_PATHS.get('clip-large')
|
||||
clip_model_name = path.get('clip-large')
|
||||
self.model2 = AutoModel.from_pretrained(clip_model_name).eval().to(device)
|
||||
self.processor = AutoProcessor.from_pretrained(clip_model_name)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user