mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 00:58:11 +00:00
add new quality metric
This commit is contained in:
@@ -4,10 +4,10 @@ from PIL import Image
|
||||
from io import BytesIO
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPFeatureExtractor, CLIPImageProcessor
|
||||
|
||||
from transformers import CLIPConfig
|
||||
from dataclasses import dataclass
|
||||
from transformers import CLIPModel as HFCLIPModel
|
||||
|
||||
from safetensors.torch import load_file
|
||||
from torch import nn, einsum
|
||||
|
||||
from .trainer.models.base_model import BaseModelConfig
|
||||
@@ -18,26 +18,27 @@ from typing import Any, Optional, Tuple, Union, List
|
||||
import torch
|
||||
|
||||
from .trainer.models.cross_modeling import Cross_model
|
||||
from .trainer.models import clip_model
|
||||
import torch.nn.functional as F
|
||||
|
||||
import gc
|
||||
import json
|
||||
from .config import MODEL_PATHS
|
||||
|
||||
class MPScore:
|
||||
def __init__(self, device: Union[str, torch.device], condition: str = 'overall'):
|
||||
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS, condition: str = 'overall'):
|
||||
"""Initialize the MPSModel with a processor, tokenizer, and model.
|
||||
|
||||
Args:
|
||||
device (Union[str, torch.device]): The device to load the model on.
|
||||
"""
|
||||
self.device = device
|
||||
processor_name_or_path = MODEL_PATHS.get("clip")
|
||||
processor_name_or_path = path.get("clip")
|
||||
self.image_processor = CLIPImageProcessor.from_pretrained(processor_name_or_path)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True)
|
||||
|
||||
model_ckpt_path = MODEL_PATHS.get("mps")
|
||||
self.model = torch.load(model_ckpt_path).eval().to(device)
|
||||
self.model = clip_model.CLIPModel(processor_name_or_path)
|
||||
state_dict = load_file(path.get("mps"))
|
||||
self.model.load_state_dict(state_dict, strict=False)
|
||||
self.model.to(device)
|
||||
self.condition = condition
|
||||
|
||||
def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
|
||||
|
||||
Reference in New Issue
Block a user