import gc
import json
from dataclasses import dataclass
from io import BytesIO
from typing import Any, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch import einsum, nn
from tqdm.auto import tqdm
from transformers import (
    AutoModel,
    AutoProcessor,
    AutoTokenizer,
    CLIPConfig,
    CLIPFeatureExtractor,
    CLIPImageProcessor,
)
from transformers import CLIPModel as HFCLIPModel

from .config import MODEL_PATHS
from .trainer.models.base_model import BaseModelConfig
from .trainer.models.cross_modeling import Cross_model


class MPScore:
    """Wrapper around the MPS reward model: scores image/prompt pairs.

    The model combines CLIP text/image features through a cross-attention
    module (``cross_model``), with an attention mask derived from the
    similarity between the user prompt and a condition-specific aspect list.
    """

    # Condition name -> comma-separated aspect prompt fed to the text encoder.
    # The strings are part of the model contract; do not reword them.
    _CONDITION_PROMPTS = {
        'overall': 'light, color, clarity, tone, style, ambiance, artistry, shape, face, hair, hands, limbs, structure, instance, texture, quantity, attributes, position, number, location, word, things',
        'aesthetics': 'light, color, clarity, tone, style, ambiance, artistry',
        'quality': 'shape, face, hair, hands, limbs, structure, instance, texture',
        'semantic': 'quantity, attributes, position, number, location',
    }

    def __init__(self, device: Union[str, torch.device], condition: str = 'overall'):
        """Initialize the MPSModel with a processor, tokenizer, and model.

        Args:
            device (Union[str, torch.device]): The device to load the model on.
            condition (str): Scoring aspect; one of 'overall', 'aesthetics',
                'quality', or 'semantic'. Validated lazily at scoring time
                (matches historical behavior).
        """
        self.device = device
        processor_name_or_path = MODEL_PATHS.get("clip")
        self.image_processor = CLIPImageProcessor.from_pretrained(processor_name_or_path)
        self.tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True)
        model_ckpt_path = MODEL_PATHS.get("mps")
        # SECURITY NOTE: torch.load unpickles the checkpoint and can execute
        # arbitrary code — only point MODEL_PATHS["mps"] at trusted files.
        self.model = torch.load(model_ckpt_path).eval().to(device)
        self.condition = condition

    def _tokenize(self, caption: str) -> torch.Tensor:
        """Tokenize *caption* to fixed-length input ids (CPU tensor)."""
        return self.tokenizer(
            caption,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids

    def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
        """Calculate the reward score for a single image and prompt.

        Args:
            image (torch.Tensor): The processed image tensor (pixel values,
                already on ``self.device``).
            prompt (str): The prompt text.

        Returns:
            float: The reward score.

        Raises:
            ValueError: If ``self.condition`` is not a supported condition.
        """
        text_input = self._tokenize(prompt).to(self.device)

        try:
            condition_prompt = self._CONDITION_PROMPTS[self.condition]
        except KeyError:
            raise ValueError(
                f"Unsupported condition: {self.condition}. Choose 'overall', 'aesthetics', 'quality', or 'semantic'.")

        condition_batch = self._tokenize(condition_prompt).repeat(text_input.shape[0], 1).to(self.device)

        with torch.no_grad():
            text_f, text_features = self.model.model.get_text_features(text_input)
            image_f = self.model.model.get_image_features(image.half())
            condition_f, _ = self.model.model.get_text_features(condition_batch)

            # Token-level similarity between prompt and condition tokens,
            # normalized to [.., 1]; tokens below the 0.3 threshold are masked
            # out of the cross-attention with -inf.
            sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f)
            sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]
            sim_text_condition = sim_text_condition / sim_text_condition.max()
            mask = torch.where(sim_text_condition > 0.3, 0, float('-inf'))
            mask = mask.repeat(1, image_f.shape[1], 1)

            # Cross-attended image embedding (CLS position), then cosine
            # similarity against the pooled text embedding, scaled by the
            # learned logit scale.
            image_features = self.model.cross_model(image_f, text_f, mask.half())[:, 0, :]
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            image_score = self.model.logit_scale.exp() * text_features @ image_features.T

        return image_score[0].cpu().numpy().item()

    def _preprocess(self, img: Union[str, Image.Image]) -> torch.Tensor:
        """Load (if given a path) and preprocess one image to pixel values.

        Args:
            img (Union[str, Image.Image]): Image path or PIL image.

        Returns:
            torch.Tensor: Pixel values on ``self.device``.

        Raises:
            TypeError: If *img* is neither a path string nor a PIL image.
        """
        if isinstance(img, str):
            # Context manager closes the underlying file handle (the old code
            # leaked it).
            with Image.open(img) as im:
                pixel_values = self.image_processor(im, return_tensors="pt")["pixel_values"]
        elif isinstance(img, Image.Image):
            pixel_values = self.image_processor(img, return_tensors="pt")["pixel_values"]
        else:
            raise TypeError("The type of parameter img_path is illegal.")
        return pixel_values.to(self.device)

    def score(self, img_path: Union[str, List[str], Image.Image,
              List[Image.Image]], prompt: str) -> List[float]:
        """Score the images based on the prompt.

        Args:
            img_path (Union[str, List[str], Image.Image, List[Image.Image]]):
                Path(s) to the image(s) or PIL image(s).
            prompt (str): The prompt text.

        Returns:
            List[float]: List of reward scores for the images.

        Raises:
            RuntimeError: Wrapping any failure (bad input type, unreadable
                image, model error); the original exception is chained.
        """
        try:
            if isinstance(img_path, (str, Image.Image)):
                # Single image.
                return [self._calculate_score(self._preprocess(img_path), prompt)]
            if isinstance(img_path, list):
                # Multiple images; TypeError from _preprocess on a bad element
                # is wrapped below, matching the original behavior.
                return [self._calculate_score(self._preprocess(one), prompt) for one in img_path]
            raise TypeError("The type of parameter img_path is illegal.")
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(f"Error in scoring images: {e}") from e