add new quality metric

This commit is contained in:
YunhongLu-ZJU
2025-02-17 14:42:20 +08:00
parent 77d0f4d297
commit 991ba162bd
69 changed files with 88 additions and 1461 deletions

View File

@@ -4,10 +4,10 @@ from PIL import Image
from io import BytesIO
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPImageProcessor
from transformers import CLIPConfig
from dataclasses import dataclass
from transformers import CLIPModel as HFCLIPModel
from safetensors.torch import load_file
from torch import nn, einsum
from .trainer.models.base_model import BaseModelConfig
@@ -18,26 +18,27 @@ from typing import Any, Optional, Tuple, Union, List
import torch
from .trainer.models.cross_modeling import Cross_model
from .trainer.models import clip_model
import torch.nn.functional as F
import gc
import json
from .config import MODEL_PATHS
class MPScore:
def __init__(self, device: Union[str, torch.device], condition: str = 'overall'):
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS, condition: str = 'overall'):
"""Initialize the MPSModel with a processor, tokenizer, and model.
Args:
device (Union[str, torch.device]): The device to load the model on.
"""
self.device = device
processor_name_or_path = MODEL_PATHS.get("clip")
processor_name_or_path = path.get("clip")
self.image_processor = CLIPImageProcessor.from_pretrained(processor_name_or_path)
self.tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True)
model_ckpt_path = MODEL_PATHS.get("mps")
self.model = torch.load(model_ckpt_path).eval().to(device)
self.model = clip_model.CLIPModel(processor_name_or_path)
state_dict = load_file(path.get("mps"))
self.model.load_state_dict(state_dict, strict=False)
self.model.to(device)
self.condition = condition
def _calculate_score(self, image: torch.Tensor, prompt: str) -> float: