Mirror of https://github.com/modelscope/DiffSynth-Studio.git
Synced 2026-03-18 22:08:13 +00:00
111 lines, 4.3 KiB, Python
import torch
|
|
from PIL import Image
|
|
from transformers import AutoProcessor, AutoModel
|
|
from typing import List, Union
|
|
import os
|
|
from .config import MODEL_PATHS
|
|
|
|
class PickScore:
    """Score how well images match a text prompt using the PickScore model.

    The processor is loaded from ``MODEL_PATHS["clip"]`` and the model weights
    from ``MODEL_PATHS["pickscore"]``.
    """

    def __init__(self, device: Union[str, torch.device]):
        """Initialize the scorer with a processor and model.

        Args:
            device (Union[str, torch.device]): The device to load the model on.
        """
        self.device = device if isinstance(device, torch.device) else torch.device(device)
        processor_name_or_path = MODEL_PATHS.get("clip")
        model_pretrained_name_or_path = MODEL_PATHS.get("pickscore")
        self.processor = AutoProcessor.from_pretrained(processor_name_or_path)
        self.model = AutoModel.from_pretrained(model_pretrained_name_or_path).eval().to(self.device)

    @staticmethod
    def _load_image(image: Union[str, Image.Image]) -> Image.Image:
        """Return ``image`` as a PIL image, reading it from disk when given a path.

        Raises:
            TypeError: If ``image`` is neither a path string nor a PIL image.
        """
        if isinstance(image, Image.Image):
            return image
        if isinstance(image, str):
            # copy() loads the pixels into memory so the file can be closed
            # immediately; a bare Image.open() leaves the handle open.
            with Image.open(image) as opened:
                return opened.copy()
        raise TypeError("The type of parameter images is illegal.")

    def _pixel_values(self, pil_image: Image.Image) -> torch.Tensor:
        """Preprocess one PIL image into the model's ``pixel_values`` tensor on ``self.device``."""
        # Only image preprocessing happens here, so the tokenizer-only options
        # (padding / truncation / max_length) are not passed.
        image_inputs = self.processor(images=pil_image, return_tensors="pt").to(self.device)
        return image_inputs["pixel_values"]

    def _embed_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Return image embeddings for ``pixel_values``, L2-normalized along the last dim."""
        image_embs = self.model.get_image_features(pixel_values=pixel_values)
        return image_embs / torch.norm(image_embs, dim=-1, keepdim=True)

    def _embed_text(self, prompt: str) -> torch.Tensor:
        """Tokenize ``prompt`` and return its text embedding, L2-normalized along the last dim."""
        text_inputs = self.processor(
            text=prompt,
            padding=True,
            truncation=True,
            max_length=77,
            return_tensors="pt",
        ).to(self.device)
        text_embs = self.model.get_text_features(**text_inputs)
        return text_embs / torch.norm(text_embs, dim=-1, keepdim=True)

    def _calculate_score(self, image: torch.Tensor, prompt: str, softmax: bool = False) -> float:
        """Calculate the score for a single preprocessed image and a prompt.

        Args:
            image (torch.Tensor): The processed image tensor (``pixel_values``
                for exactly one image, since the result is read with ``.item()``).
            prompt (str): The prompt text.
            softmax (bool): Whether to apply the logit scale and softmax. With a
                single image the softmax result is always 1.0; use ``score`` to
                get a softmax across several images.

        Returns:
            float: Cosine similarity between the image and prompt embeddings
            (or the softmax probability when ``softmax`` is True).
        """
        with torch.no_grad():
            similarity = (self._embed_text(prompt) @ self._embed_images(image).T)[0]
            if softmax:
                similarity = torch.softmax(self.model.logit_scale.exp() * similarity, dim=-1)
        return similarity.cpu().item()

    def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str, softmax: bool = False) -> List[float]:
        """Score the images based on the prompt.

        Args:
            images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s)
                to the image(s) or PIL image(s).
            prompt (str): The prompt text.
            softmax (bool): If False, each score is the cosine similarity between
                the prompt and image embeddings. If True, the similarities are
                multiplied by the model's ``logit_scale.exp()`` and softmax-ed
                across all given images, so the scores sum to 1 (a single image
                always gets 1.0).

        Returns:
            List[float]: One score per input image, in input order (empty for an
            empty list).

        Raises:
            RuntimeError: Wraps any failure, including an unsupported ``images``
                type; the original exception is chained as ``__cause__``.
        """
        try:
            if isinstance(images, (str, Image.Image)):
                image_list = [images]
            elif isinstance(images, list):
                image_list = images
            else:
                raise TypeError("The type of parameter images is illegal.")

            if not image_list:
                return []

            with torch.no_grad():
                # The prompt embedding is identical for every image: compute it once.
                text_embs = self._embed_text(prompt)
                # Images are encoded one at a time so only a single image's
                # activations are resident at once.
                similarities = torch.cat([
                    (text_embs @ self._embed_images(self._pixel_values(self._load_image(one_image))).T)[0]
                    for one_image in image_list
                ])
                if softmax:
                    # Softmax across all candidate images for this prompt.
                    similarities = torch.softmax(self.model.logit_scale.exp() * similarities, dim=-1)

            return similarities.cpu().tolist()
        except Exception as e:
            raise RuntimeError(f"Error in scoring images: {e}") from e
|