mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
130 lines
5.7 KiB
Python
130 lines
5.7 KiB
Python
import numpy as np
|
|
import torch
|
|
from PIL import Image
|
|
from io import BytesIO
|
|
from tqdm.auto import tqdm
|
|
from transformers import CLIPFeatureExtractor, CLIPImageProcessor
|
|
|
|
from dataclasses import dataclass
|
|
from transformers import CLIPModel as HFCLIPModel
|
|
|
|
from torch import nn, einsum
|
|
|
|
from .trainer.models.base_model import BaseModelConfig
|
|
|
|
from transformers import CLIPConfig
|
|
from transformers import AutoProcessor, AutoModel, AutoTokenizer
|
|
from typing import Any, Optional, Tuple, Union, List
|
|
import torch
|
|
|
|
from .trainer.models.cross_modeling import Cross_model
|
|
import torch.nn.functional as F
|
|
|
|
import gc
|
|
import json
|
|
from .config import MODEL_PATHS
|
|
|
|
class MPScore:
    """Score images against a text prompt with the checkpoint registered as "mps".

    The image processor and tokenizer are loaded from ``MODEL_PATHS["clip"]``
    and the scoring model from ``MODEL_PATHS["mps"]``. The ``condition``
    attribute selects which list of aspect words is used to mask the prompt
    tokens before the image/text cross-attention step.
    """

    # Condition name -> comma-separated aspect words. The selected string is
    # tokenized and compared with the prompt tokens to build the attention mask.
    _CONDITION_PROMPTS = {
        'overall': 'light, color, clarity, tone, style, ambiance, artistry, shape, face, hair, hands, limbs, structure, instance, texture, quantity, attributes, position, number, location, word, things',
        'aesthetics': 'light, color, clarity, tone, style, ambiance, artistry',
        'quality': 'shape, face, hair, hands, limbs, structure, instance, texture',
        'semantic': 'quantity, attributes, position, number, location',
    }

    def __init__(self, device: Union[str, torch.device], condition: str = 'overall'):
        """Load the image processor, tokenizer and scoring model.

        Args:
            device (Union[str, torch.device]): The device to load the model on.
            condition (str): Which aspect list to condition on: 'overall',
                'aesthetics', 'quality' or 'semantic'. It is validated when a
                score is computed, not here.
        """
        self.device = device
        processor_name_or_path = MODEL_PATHS.get("clip")
        self.image_processor = CLIPImageProcessor.from_pretrained(processor_name_or_path)
        self.tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True)

        model_ckpt_path = MODEL_PATHS.get("mps")
        # SECURITY NOTE: torch.load unpickles the checkpoint, and this file holds a
        # whole module object (``.eval()`` is called on the result), so only load
        # checkpoints from a trusted source.
        # NOTE(review): recent PyTorch releases default to ``weights_only=True``,
        # which rejects pickled modules - confirm against the pinned torch version.
        # map_location="cpu" lets a checkpoint that was saved on a GPU be restored
        # on any host; the model is moved to the requested device right after.
        self.model = torch.load(model_ckpt_path, map_location="cpu").eval().to(device)
        self.condition = condition

    def _tokenize(self, caption: str) -> torch.Tensor:
        """Return the token ids of ``caption``, padded/truncated to the tokenizer's max length."""
        return self.tokenizer(
            caption,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids

    def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
        """Calculate the reward score for a single image and prompt.

        Args:
            image (torch.Tensor): The processed image tensor. It is cast to
                half precision before being passed to the model.
            prompt (str): The prompt text.

        Returns:
            float: The reward score.

        Raises:
            ValueError: If ``self.condition`` is not a supported condition name.
        """
        # Validate the condition first so a bad value fails before any
        # tokenizer or model work is done.
        try:
            condition_prompt = self._CONDITION_PROMPTS[self.condition]
        except (KeyError, TypeError):
            raise ValueError(
                f"Unsupported condition: {self.condition}. Choose 'overall', 'aesthetics', 'quality', or 'semantic'.") from None

        text_input = self._tokenize(prompt).to(self.device)
        # One copy of the condition tokens per prompt row.
        condition_batch = self._tokenize(condition_prompt).repeat(text_input.shape[0], 1).to(self.device)

        with torch.no_grad():
            text_f, text_features = self.model.model.get_text_features(text_input)

            image_f = self.model.model.get_image_features(image.half())
            condition_f, _ = self.model.model.get_text_features(condition_batch)

            # (b, i, d) x (b, j, d) -> (b, j, i): similarity between every
            # condition token j and every prompt token i.
            sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f)
            # Keep, for each prompt token, its best match over the condition tokens.
            sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]
            sim_text_condition = sim_text_condition / sim_text_condition.max()
            # Additive mask: 0 keeps a prompt token, -inf drops it (relative
            # similarity of 0.3 or less).
            mask = torch.where(sim_text_condition > 0.3, 0, float('-inf'))
            # Repeat the mask along dim 1 to match image_f.shape[1]
            # (presumably one row per image token - TODO confirm).
            mask = mask.repeat(1, image_f.shape[1], 1)
            image_features = self.model.cross_model(image_f, text_f, mask.half())[:, 0, :]

            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            image_score = self.model.logit_scale.exp() * text_features @ image_features.T

        return image_score[0].cpu().numpy().item()

    def _load_pixel_values(self, image: Union[str, Image.Image]) -> torch.Tensor:
        """Turn one image path or PIL image into a pixel tensor on ``self.device``."""
        if isinstance(image, str):
            # Image.open keeps the file handle open until the image is closed;
            # the context manager releases it once the processor has read the pixels.
            with Image.open(image) as opened:
                return self.image_processor(opened, return_tensors="pt")["pixel_values"].to(self.device)
        if isinstance(image, Image.Image):
            return self.image_processor(image, return_tensors="pt")["pixel_values"].to(self.device)
        raise TypeError("The type of parameter img_path is illegal.")

    def score(self, img_path: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
        """Score the images based on the prompt.

        Args:
            img_path (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
            prompt (str): The prompt text.

        Returns:
            List[float]: List of reward scores for the images, one per input
                image (a single image yields a one-element list).

        Raises:
            RuntimeError: If anything fails while loading or scoring, including
                an unsupported ``img_path`` type. The original exception is
                attached as ``__cause__``.
        """
        try:
            if isinstance(img_path, list):
                # Multiple images
                return [
                    self._calculate_score(self._load_pixel_values(one_img_path), prompt)
                    for one_img_path in img_path
                ]
            # Single image (any unsupported type is rejected by the loader)
            return [self._calculate_score(self._load_pixel_values(img_path), prompt)]
        except Exception as e:
            raise RuntimeError(f"Error in scoring images: {e}") from e
|