mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 10:18:12 +00:00
update preference models
This commit is contained in:
@@ -19,9 +19,8 @@ def default_bert():
|
|||||||
model_path = os.path.join(project_root, 'models', 'QualityMetric')
|
model_path = os.path.join(project_root, 'models', 'QualityMetric')
|
||||||
return os.path.join(model_path, "bert-base-uncased")
|
return os.path.join(model_path, "bert-base-uncased")
|
||||||
|
|
||||||
bert_model_path = default_bert()
|
|
||||||
|
|
||||||
def init_tokenizer():
|
def init_tokenizer(bert_model_path):
|
||||||
tokenizer = BertTokenizer.from_pretrained(bert_model_path)
|
tokenizer = BertTokenizer.from_pretrained(bert_model_path)
|
||||||
tokenizer.add_special_tokens({'bos_token':'[DEC]'})
|
tokenizer.add_special_tokens({'bos_token':'[DEC]'})
|
||||||
tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
|
tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
|
||||||
@@ -20,6 +20,7 @@ class BLIP_Pretrain(nn.Module):
|
|||||||
embed_dim = 256,
|
embed_dim = 256,
|
||||||
queue_size = 57600,
|
queue_size = 57600,
|
||||||
momentum = 0.995,
|
momentum = 0.995,
|
||||||
|
bert_model_path = ""
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -31,7 +32,7 @@ class BLIP_Pretrain(nn.Module):
|
|||||||
|
|
||||||
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0)
|
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0)
|
||||||
|
|
||||||
self.tokenizer = init_tokenizer()
|
self.tokenizer = init_tokenizer(bert_model_path)
|
||||||
encoder_config = BertConfig.from_json_file(med_config)
|
encoder_config = BertConfig.from_json_file(med_config)
|
||||||
encoder_config.encoder_width = vision_width
|
encoder_config.encoder_width = vision_width
|
||||||
self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)
|
self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)
|
||||||
@@ -14,7 +14,7 @@ from timm.models.registry import register_model
|
|||||||
from timm.models.layers import trunc_normal_, DropPath
|
from timm.models.layers import trunc_normal_, DropPath
|
||||||
from timm.models.helpers import named_apply, adapt_input_conv
|
from timm.models.helpers import named_apply, adapt_input_conv
|
||||||
|
|
||||||
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
|
# from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
|
||||||
|
|
||||||
class Mlp(nn.Module):
|
class Mlp(nn.Module):
|
||||||
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
|
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
|
||||||
@@ -96,9 +96,9 @@ class Block(nn.Module):
|
|||||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||||
|
|
||||||
if use_grad_checkpointing:
|
# if use_grad_checkpointing:
|
||||||
self.attn = checkpoint_wrapper(self.attn)
|
# self.attn = checkpoint_wrapper(self.attn)
|
||||||
self.mlp = checkpoint_wrapper(self.mlp)
|
# self.mlp = checkpoint_wrapper(self.mlp)
|
||||||
|
|
||||||
def forward(self, x, register_hook=False):
|
def forward(self, x, register_hook=False):
|
||||||
x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
|
x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
|
||||||
148
diffsynth/extensions/ImageQualityMetric/__init__.py
Normal file
148
diffsynth/extensions/ImageQualityMetric/__init__.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
from modelscope import snapshot_download
|
||||||
|
from typing_extensions import Literal, TypeAlias
|
||||||
|
import os
|
||||||
|
from diffsynth.extensions.ImageQualityMetric.aesthetic import AestheticScore
|
||||||
|
from diffsynth.extensions.ImageQualityMetric.imagereward import ImageRewardScore
|
||||||
|
from diffsynth.extensions.ImageQualityMetric.pickscore import PickScore
|
||||||
|
from diffsynth.extensions.ImageQualityMetric.clip import CLIPScore
|
||||||
|
from diffsynth.extensions.ImageQualityMetric.hps import HPScore_v2
|
||||||
|
from diffsynth.extensions.ImageQualityMetric.mps import MPScore
|
||||||
|
|
||||||
|
|
||||||
|
preference_model_id: TypeAlias = Literal[
|
||||||
|
"ImageReward",
|
||||||
|
"Aesthetic",
|
||||||
|
"PickScore",
|
||||||
|
"CLIP",
|
||||||
|
"HPSv2",
|
||||||
|
"HPSv2.1",
|
||||||
|
"MPS",
|
||||||
|
]
|
||||||
|
model_dict = {
|
||||||
|
"ImageReward": {
|
||||||
|
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||||
|
"allow_file_pattern": [
|
||||||
|
"ImageReward/ImageReward.safetensors",
|
||||||
|
"ImageReward/med_config.json",
|
||||||
|
"bert-base-uncased/config.json",
|
||||||
|
"bert-base-uncased/model.safetensors",
|
||||||
|
"bert-base-uncased/tokenizer.json",
|
||||||
|
"bert-base-uncased/tokenizer_config.json",
|
||||||
|
"bert-base-uncased/vocab.txt",
|
||||||
|
],
|
||||||
|
"load_path": {
|
||||||
|
"imagereward": "ImageReward/ImageReward.safetensors",
|
||||||
|
"med_config": "ImageReward/med_config.json",
|
||||||
|
"bert_model_path": "bert-base-uncased",
|
||||||
|
},
|
||||||
|
"model_class": ImageRewardScore
|
||||||
|
},
|
||||||
|
"Aesthetic": {
|
||||||
|
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||||
|
"allow_file_pattern": [
|
||||||
|
"aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors",
|
||||||
|
"clip-vit-large-patch14/config.json",
|
||||||
|
"clip-vit-large-patch14/merges.txt",
|
||||||
|
"clip-vit-large-patch14/model.safetensors",
|
||||||
|
"clip-vit-large-patch14/preprocessor_config.json",
|
||||||
|
"clip-vit-large-patch14/special_tokens_map.json",
|
||||||
|
"clip-vit-large-patch14/tokenizer.json",
|
||||||
|
"clip-vit-large-patch14/tokenizer_config.json",
|
||||||
|
"clip-vit-large-patch14/vocab.json",
|
||||||
|
],
|
||||||
|
"load_path": {
|
||||||
|
"aesthetic_predictor": "aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors",
|
||||||
|
"clip-large": "clip-vit-large-patch14",
|
||||||
|
},
|
||||||
|
"model_class": AestheticScore
|
||||||
|
},
|
||||||
|
"PickScore": {
|
||||||
|
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||||
|
"allow_file_pattern": [
|
||||||
|
"PickScore_v1/*",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/config.json",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/merges.txt",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/preprocessor_config.json",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/special_tokens_map.json",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer.json",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer_config.json",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/vocab.json",
|
||||||
|
],
|
||||||
|
"load_path": {
|
||||||
|
"pickscore": "PickScore_v1",
|
||||||
|
"clip": "CLIP-ViT-H-14-laion2B-s32B-b79K",
|
||||||
|
},
|
||||||
|
"model_class": PickScore
|
||||||
|
},
|
||||||
|
"CLIP": {
|
||||||
|
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||||
|
"allow_file_pattern": [
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin",
|
||||||
|
"bpe_simple_vocab_16e6.txt.gz",
|
||||||
|
],
|
||||||
|
"load_path": {
|
||||||
|
"open_clip": "CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin",
|
||||||
|
"open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
|
||||||
|
},
|
||||||
|
"model_class": CLIPScore
|
||||||
|
},
|
||||||
|
"HPSv2": {
|
||||||
|
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||||
|
"allow_file_pattern": [
|
||||||
|
"HPS_v2/HPS_v2_compressed.safetensors",
|
||||||
|
"bpe_simple_vocab_16e6.txt.gz",
|
||||||
|
],
|
||||||
|
"load_path": {
|
||||||
|
"hpsv2": "HPS_v2/HPS_v2_compressed.safetensors",
|
||||||
|
"open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
|
||||||
|
},
|
||||||
|
"model_class": HPScore_v2,
|
||||||
|
"extra_kwargs": {"model_version": "v2"}
|
||||||
|
},
|
||||||
|
"HPSv2.1": {
|
||||||
|
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||||
|
"allow_file_pattern": [
|
||||||
|
"HPS_v2/HPS_v2.1_compressed.safetensors",
|
||||||
|
"bpe_simple_vocab_16e6.txt.gz",
|
||||||
|
],
|
||||||
|
"load_path": {
|
||||||
|
"hpsv2.1": "HPS_v2/HPS_v2.1_compressed.safetensors",
|
||||||
|
"open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
|
||||||
|
},
|
||||||
|
"model_class": HPScore_v2,
|
||||||
|
"extra_kwargs": {"model_version": "v21"}
|
||||||
|
},
|
||||||
|
"MPS": {
|
||||||
|
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||||
|
"allow_file_pattern": [
|
||||||
|
"MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/config.json",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/merges.txt",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/preprocessor_config.json",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/special_tokens_map.json",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer.json",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer_config.json",
|
||||||
|
"CLIP-ViT-H-14-laion2B-s32B-b79K/vocab.json",
|
||||||
|
],
|
||||||
|
"load_path": {
|
||||||
|
"mps": "MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors",
|
||||||
|
"clip": "CLIP-ViT-H-14-laion2B-s32B-b79K",
|
||||||
|
},
|
||||||
|
"model_class": MPScore
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def download_preference_model(model_name: preference_model_id, cache_dir="models"):
|
||||||
|
metadata = model_dict[model_name]
|
||||||
|
snapshot_download(model_id=metadata["model_id"], allow_file_pattern=metadata["allow_file_pattern"], cache_dir=cache_dir)
|
||||||
|
load_path = metadata["load_path"]
|
||||||
|
load_path = {key: os.path.join(cache_dir, metadata["model_id"], path) for key, path in load_path.items()}
|
||||||
|
return load_path
|
||||||
|
|
||||||
|
|
||||||
|
def load_preference_model(model_name: preference_model_id, device = "cuda", path = None):
|
||||||
|
model_class = model_dict[model_name]["model_class"]
|
||||||
|
extra_kwargs = model_dict[model_name].get("extra_kwargs", {})
|
||||||
|
preference_model = model_class(device=device, path=path, **extra_kwargs)
|
||||||
|
return preference_model
|
||||||
@@ -49,13 +49,9 @@ class MLP(torch.nn.Module):
|
|||||||
return torch.optim.Adam(self.parameters(), lr=1e-3)
|
return torch.optim.Adam(self.parameters(), lr=1e-3)
|
||||||
|
|
||||||
|
|
||||||
class AestheticScore:
|
class AestheticScore(torch.nn.Module):
|
||||||
def __init__(self, device: torch.device, path: str = MODEL_PATHS):
|
def __init__(self, device: torch.device, path: str = MODEL_PATHS):
|
||||||
"""Initialize the Selector with a model and processor.
|
super().__init__()
|
||||||
|
|
||||||
Args:
|
|
||||||
device (torch.device): The device to load the model on.
|
|
||||||
"""
|
|
||||||
self.device = device
|
self.device = device
|
||||||
self.aes_model_path = path.get("aesthetic_predictor")
|
self.aes_model_path = path.get("aesthetic_predictor")
|
||||||
# Load the MLP model
|
# Load the MLP model
|
||||||
@@ -96,7 +92,8 @@ class AestheticScore:
|
|||||||
|
|
||||||
return score
|
return score
|
||||||
|
|
||||||
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]]) -> List[float]:
|
@torch.no_grad()
|
||||||
|
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str = "") -> List[float]:
|
||||||
"""Score the images based on their aesthetic quality.
|
"""Score the images based on their aesthetic quality.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -4,8 +4,9 @@ import torch
|
|||||||
from .open_clip import create_model_and_transforms, get_tokenizer
|
from .open_clip import create_model_and_transforms, get_tokenizer
|
||||||
from .config import MODEL_PATHS
|
from .config import MODEL_PATHS
|
||||||
|
|
||||||
class CLIPScore:
|
class CLIPScore(torch.nn.Module):
|
||||||
def __init__(self, device: torch.device, path: str = MODEL_PATHS):
|
def __init__(self, device: torch.device, path: str = MODEL_PATHS):
|
||||||
|
super().__init__()
|
||||||
"""Initialize the CLIPScore with a model and tokenizer.
|
"""Initialize the CLIPScore with a model and tokenizer.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -36,7 +37,7 @@ class CLIPScore:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Initialize tokenizer
|
# Initialize tokenizer
|
||||||
self.tokenizer = get_tokenizer("ViT-H-14")
|
self.tokenizer = get_tokenizer("ViT-H-14", path["open_clip_bpe"])
|
||||||
self.model = self.model.to(device)
|
self.model = self.model.to(device)
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
|
||||||
@@ -62,37 +63,35 @@ class CLIPScore:
|
|||||||
|
|
||||||
return clip_score[0].item()
|
return clip_score[0].item()
|
||||||
|
|
||||||
def score(self, img_path: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
@torch.no_grad()
|
||||||
|
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
||||||
"""Score the images based on the prompt.
|
"""Score the images based on the prompt.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
img_path (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
||||||
prompt (str): The prompt text.
|
prompt (str): The prompt text.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[float]: List of CLIP scores for the images.
|
List[float]: List of CLIP scores for the images.
|
||||||
"""
|
"""
|
||||||
try:
|
if isinstance(images, (str, Image.Image)):
|
||||||
if isinstance(img_path, (str, Image.Image)):
|
# Single image
|
||||||
# Single image
|
if isinstance(images, str):
|
||||||
if isinstance(img_path, str):
|
image = self.preprocess_val(Image.open(images)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||||
image = self.preprocess_val(Image.open(img_path)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
|
||||||
else:
|
|
||||||
image = self.preprocess_val(img_path).unsqueeze(0).to(device=self.device, non_blocking=True)
|
|
||||||
return [self._calculate_score(image, prompt)]
|
|
||||||
elif isinstance(img_path, list):
|
|
||||||
# Multiple images
|
|
||||||
scores = []
|
|
||||||
for one_img_path in img_path:
|
|
||||||
if isinstance(one_img_path, str):
|
|
||||||
image = self.preprocess_val(Image.open(one_img_path)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
|
||||||
elif isinstance(one_img_path, Image.Image):
|
|
||||||
image = self.preprocess_val(one_img_path).unsqueeze(0).to(device=self.device, non_blocking=True)
|
|
||||||
else:
|
|
||||||
raise TypeError("The type of parameter img_path is illegal.")
|
|
||||||
scores.append(self._calculate_score(image, prompt))
|
|
||||||
return scores
|
|
||||||
else:
|
else:
|
||||||
raise TypeError("The type of parameter img_path is illegal.")
|
image = self.preprocess_val(images).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||||
except Exception as e:
|
return [self._calculate_score(image, prompt)]
|
||||||
raise RuntimeError(f"Error in scoring images: {e}")
|
elif isinstance(images, list):
|
||||||
|
# Multiple images
|
||||||
|
scores = []
|
||||||
|
for one_images in images:
|
||||||
|
if isinstance(one_images, str):
|
||||||
|
image = self.preprocess_val(Image.open(one_images)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||||
|
elif isinstance(one_images, Image.Image):
|
||||||
|
image = self.preprocess_val(one_images).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||||
|
else:
|
||||||
|
raise TypeError("The type of parameter images is illegal.")
|
||||||
|
scores.append(self._calculate_score(image, prompt))
|
||||||
|
return scores
|
||||||
|
else:
|
||||||
|
raise TypeError("The type of parameter images is illegal.")
|
||||||
@@ -6,8 +6,9 @@ from safetensors.torch import load_file
|
|||||||
import os
|
import os
|
||||||
from .config import MODEL_PATHS
|
from .config import MODEL_PATHS
|
||||||
|
|
||||||
class HPScore_v2:
|
class HPScore_v2(torch.nn.Module):
|
||||||
def __init__(self, device: torch.device, path: str = MODEL_PATHS, model_version: str = "v2"):
|
def __init__(self, device: torch.device, path: str = MODEL_PATHS, model_version: str = "v2"):
|
||||||
|
super().__init__()
|
||||||
"""Initialize the Selector with a model and tokenizer.
|
"""Initialize the Selector with a model and tokenizer.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -53,7 +54,7 @@ class HPScore_v2:
|
|||||||
raise ValueError(f"Error loading model weights from {safetensors_path}: {e}")
|
raise ValueError(f"Error loading model weights from {safetensors_path}: {e}")
|
||||||
|
|
||||||
# Initialize tokenizer and model
|
# Initialize tokenizer and model
|
||||||
self.tokenizer = get_tokenizer("ViT-H-14")
|
self.tokenizer = get_tokenizer("ViT-H-14", path["open_clip_bpe"])
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
self.model = model
|
self.model = model
|
||||||
@@ -80,37 +81,38 @@ class HPScore_v2:
|
|||||||
|
|
||||||
return hps_score[0].item()
|
return hps_score[0].item()
|
||||||
|
|
||||||
def score(self, img_path: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
@torch.no_grad()
|
||||||
|
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
||||||
"""Score the images based on the prompt.
|
"""Score the images based on the prompt.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
img_path (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
||||||
prompt (str): The prompt text.
|
prompt (str): The prompt text.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[float]: List of HPS scores for the images.
|
List[float]: List of HPS scores for the images.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if isinstance(img_path, (str, Image.Image)):
|
if isinstance(images, (str, Image.Image)):
|
||||||
# Single image
|
# Single image
|
||||||
if isinstance(img_path, str):
|
if isinstance(images, str):
|
||||||
image = self.preprocess_val(Image.open(img_path)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
image = self.preprocess_val(Image.open(images)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||||
else:
|
else:
|
||||||
image = self.preprocess_val(img_path).unsqueeze(0).to(device=self.device, non_blocking=True)
|
image = self.preprocess_val(images).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||||
return [self._calculate_score(image, prompt)]
|
return [self._calculate_score(image, prompt)]
|
||||||
elif isinstance(img_path, list):
|
elif isinstance(images, list):
|
||||||
# Multiple images
|
# Multiple images
|
||||||
scores = []
|
scores = []
|
||||||
for one_img_path in img_path:
|
for one_images in images:
|
||||||
if isinstance(one_img_path, str):
|
if isinstance(one_images, str):
|
||||||
image = self.preprocess_val(Image.open(one_img_path)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
image = self.preprocess_val(Image.open(one_images)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||||
elif isinstance(one_img_path, Image.Image):
|
elif isinstance(one_images, Image.Image):
|
||||||
image = self.preprocess_val(one_img_path).unsqueeze(0).to(device=self.device, non_blocking=True)
|
image = self.preprocess_val(one_images).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||||
else:
|
else:
|
||||||
raise TypeError("The type of parameter img_path is illegal.")
|
raise TypeError("The type of parameter images is illegal.")
|
||||||
scores.append(self._calculate_score(image, prompt))
|
scores.append(self._calculate_score(image, prompt))
|
||||||
return scores
|
return scores
|
||||||
else:
|
else:
|
||||||
raise TypeError("The type of parameter img_path is illegal.")
|
raise TypeError("The type of parameter images is illegal.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Error in scoring images: {e}")
|
raise RuntimeError(f"Error in scoring images: {e}")
|
||||||
@@ -52,11 +52,11 @@ class MLP(torch.nn.Module):
|
|||||||
return self.layers(input)
|
return self.layers(input)
|
||||||
|
|
||||||
class ImageReward(torch.nn.Module):
|
class ImageReward(torch.nn.Module):
|
||||||
def __init__(self, med_config, device='cpu'):
|
def __init__(self, med_config, device='cpu', bert_model_path=""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config)
|
self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config, bert_model_path=bert_model_path)
|
||||||
self.preprocess = _transform(224)
|
self.preprocess = _transform(224)
|
||||||
self.mlp = MLP(768)
|
self.mlp = MLP(768)
|
||||||
|
|
||||||
@@ -88,7 +88,7 @@ class ImageReward(torch.nn.Module):
|
|||||||
rewards = (rewards - self.mean) / self.std
|
rewards = (rewards - self.mean) / self.std
|
||||||
return rewards
|
return rewards
|
||||||
|
|
||||||
def score(self, prompt: str, images: Union[str, List[str], Image.Image, List[Image.Image]]) -> List[float]:
|
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str = "") -> List[float]:
|
||||||
"""Score the images based on the prompt.
|
"""Score the images based on the prompt.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -187,21 +187,18 @@ class ImageReward(torch.nn.Module):
|
|||||||
return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
|
return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
|
||||||
|
|
||||||
|
|
||||||
class ImageRewardScore:
|
class ImageRewardScore(torch.nn.Module):
|
||||||
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS):
|
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS):
|
||||||
"""Initialize the Selector with a processor and model.
|
super().__init__()
|
||||||
|
|
||||||
Args:
|
|
||||||
device (Union[str, torch.device]): The device to load the model on.
|
|
||||||
"""
|
|
||||||
self.device = device if isinstance(device, torch.device) else torch.device(device)
|
self.device = device if isinstance(device, torch.device) else torch.device(device)
|
||||||
model_path = path.get("imagereward")
|
model_path = path.get("imagereward")
|
||||||
med_config = path.get("med_config")
|
med_config = path.get("med_config")
|
||||||
state_dict = load_file(model_path)
|
state_dict = load_file(model_path)
|
||||||
self.model = ImageReward(device=self.device, med_config=med_config).to(self.device)
|
self.model = ImageReward(device=self.device, med_config=med_config, bert_model_path=path.get("bert_model_path")).to(self.device)
|
||||||
self.model.load_state_dict(state_dict, strict=False)
|
self.model.load_state_dict(state_dict, strict=False)
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
||||||
"""Score the images based on the prompt.
|
"""Score the images based on the prompt.
|
||||||
|
|
||||||
@@ -212,4 +209,4 @@ class ImageRewardScore:
|
|||||||
Returns:
|
Returns:
|
||||||
List[float]: List of scores for the images.
|
List[float]: List of scores for the images.
|
||||||
"""
|
"""
|
||||||
return self.model.score(prompt, images)
|
return self.model.score(images, prompt)
|
||||||
@@ -24,8 +24,9 @@ import gc
|
|||||||
import json
|
import json
|
||||||
from .config import MODEL_PATHS
|
from .config import MODEL_PATHS
|
||||||
|
|
||||||
class MPScore:
|
class MPScore(torch.nn.Module):
|
||||||
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS, condition: str = 'overall'):
|
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS, condition: str = 'overall'):
|
||||||
|
super().__init__()
|
||||||
"""Initialize the MPSModel with a processor, tokenizer, and model.
|
"""Initialize the MPSModel with a processor, tokenizer, and model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -35,7 +36,7 @@ class MPScore:
|
|||||||
processor_name_or_path = path.get("clip")
|
processor_name_or_path = path.get("clip")
|
||||||
self.image_processor = CLIPImageProcessor.from_pretrained(processor_name_or_path)
|
self.image_processor = CLIPImageProcessor.from_pretrained(processor_name_or_path)
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True)
|
self.tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True)
|
||||||
self.model = clip_model.CLIPModel(processor_name_or_path)
|
self.model = clip_model.CLIPModel(processor_name_or_path, config_file=True)
|
||||||
state_dict = load_file(path.get("mps"))
|
state_dict = load_file(path.get("mps"))
|
||||||
self.model.load_state_dict(state_dict, strict=False)
|
self.model.load_state_dict(state_dict, strict=False)
|
||||||
self.model.to(device)
|
self.model.to(device)
|
||||||
@@ -94,37 +95,35 @@ class MPScore:
|
|||||||
|
|
||||||
return image_score[0].cpu().numpy().item()
|
return image_score[0].cpu().numpy().item()
|
||||||
|
|
||||||
def score(self, img_path: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
@torch.no_grad()
|
||||||
|
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
||||||
"""Score the images based on the prompt.
|
"""Score the images based on the prompt.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
img_path (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
||||||
prompt (str): The prompt text.
|
prompt (str): The prompt text.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[float]: List of reward scores for the images.
|
List[float]: List of reward scores for the images.
|
||||||
"""
|
"""
|
||||||
try:
|
if isinstance(images, (str, Image.Image)):
|
||||||
if isinstance(img_path, (str, Image.Image)):
|
# Single image
|
||||||
# Single image
|
if isinstance(images, str):
|
||||||
if isinstance(img_path, str):
|
image = self.image_processor(Image.open(images), return_tensors="pt")["pixel_values"].to(self.device)
|
||||||
image = self.image_processor(Image.open(img_path), return_tensors="pt")["pixel_values"].to(self.device)
|
|
||||||
else:
|
|
||||||
image = self.image_processor(img_path, return_tensors="pt")["pixel_values"].to(self.device)
|
|
||||||
return [self._calculate_score(image, prompt)]
|
|
||||||
elif isinstance(img_path, list):
|
|
||||||
# Multiple images
|
|
||||||
scores = []
|
|
||||||
for one_img_path in img_path:
|
|
||||||
if isinstance(one_img_path, str):
|
|
||||||
image = self.image_processor(Image.open(one_img_path), return_tensors="pt")["pixel_values"].to(self.device)
|
|
||||||
elif isinstance(one_img_path, Image.Image):
|
|
||||||
image = self.image_processor(one_img_path, return_tensors="pt")["pixel_values"].to(self.device)
|
|
||||||
else:
|
|
||||||
raise TypeError("The type of parameter img_path is illegal.")
|
|
||||||
scores.append(self._calculate_score(image, prompt))
|
|
||||||
return scores
|
|
||||||
else:
|
else:
|
||||||
raise TypeError("The type of parameter img_path is illegal.")
|
image = self.image_processor(images, return_tensors="pt")["pixel_values"].to(self.device)
|
||||||
except Exception as e:
|
return [self._calculate_score(image, prompt)]
|
||||||
raise RuntimeError(f"Error in scoring images: {e}")
|
elif isinstance(images, list):
|
||||||
|
# Multiple images
|
||||||
|
scores = []
|
||||||
|
for one_images in images:
|
||||||
|
if isinstance(one_images, str):
|
||||||
|
image = self.image_processor(Image.open(one_images), return_tensors="pt")["pixel_values"].to(self.device)
|
||||||
|
elif isinstance(one_images, Image.Image):
|
||||||
|
image = self.image_processor(one_images, return_tensors="pt")["pixel_values"].to(self.device)
|
||||||
|
else:
|
||||||
|
raise TypeError("The type of parameter images is illegal.")
|
||||||
|
scores.append(self._calculate_score(image, prompt))
|
||||||
|
return scores
|
||||||
|
else:
|
||||||
|
raise TypeError("The type of parameter images is illegal.")
|
||||||
@@ -9,6 +9,6 @@ from .openai import load_openai_model, list_openai_models
|
|||||||
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
|
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
|
||||||
get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
|
get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
|
||||||
from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
|
from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
|
||||||
from .tokenizer import SimpleTokenizer, tokenize, decode
|
from .tokenizer import SimpleTokenizer
|
||||||
from .transform import image_transform, AugmentationCfg
|
from .transform import image_transform, AugmentationCfg
|
||||||
from .utils import freeze_batch_norm_2d
|
from .utils import freeze_batch_norm_2d
|
||||||
@@ -18,7 +18,7 @@ from .loss import ClipLoss, DistillClipLoss, CoCaLoss
|
|||||||
from .openai import load_openai_model
|
from .openai import load_openai_model
|
||||||
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf
|
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf
|
||||||
from .transform import image_transform, AugmentationCfg
|
from .transform import image_transform, AugmentationCfg
|
||||||
from .tokenizer import HFTokenizer, tokenize
|
from .tokenizer import HFTokenizer, SimpleTokenizer
|
||||||
|
|
||||||
|
|
||||||
HF_HUB_PREFIX = 'hf-hub:'
|
HF_HUB_PREFIX = 'hf-hub:'
|
||||||
@@ -74,13 +74,13 @@ def get_model_config(model_name):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_tokenizer(model_name):
|
def get_tokenizer(model_name, open_clip_bpe_path=None):
|
||||||
if model_name.startswith(HF_HUB_PREFIX):
|
if model_name.startswith(HF_HUB_PREFIX):
|
||||||
tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):])
|
tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):])
|
||||||
else:
|
else:
|
||||||
config = get_model_config(model_name)
|
config = get_model_config(model_name)
|
||||||
tokenizer = HFTokenizer(
|
tokenizer = HFTokenizer(
|
||||||
config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize
|
config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else SimpleTokenizer(open_clip_bpe_path)
|
||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
@@ -152,43 +152,37 @@ class SimpleTokenizer(object):
|
|||||||
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
|
||||||
|
"""
|
||||||
|
Returns the tokenized representation of given input string(s)
|
||||||
|
|
||||||
_tokenizer = SimpleTokenizer()
|
Parameters
|
||||||
|
----------
|
||||||
|
texts : Union[str, List[str]]
|
||||||
|
An input string or a list of input strings to tokenize
|
||||||
|
context_length : int
|
||||||
|
The context length to use; all CLIP models use 77 as the context length
|
||||||
|
|
||||||
def decode(output_ids: torch.Tensor):
|
Returns
|
||||||
output_ids = output_ids.cpu().numpy()
|
-------
|
||||||
return _tokenizer.decode(output_ids)
|
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
|
||||||
|
"""
|
||||||
|
if isinstance(texts, str):
|
||||||
|
texts = [texts]
|
||||||
|
|
||||||
def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
|
sot_token = self.encoder["<start_of_text>"]
|
||||||
"""
|
eot_token = self.encoder["<end_of_text>"]
|
||||||
Returns the tokenized representation of given input string(s)
|
all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
|
||||||
|
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
||||||
|
|
||||||
Parameters
|
for i, tokens in enumerate(all_tokens):
|
||||||
----------
|
if len(tokens) > context_length:
|
||||||
texts : Union[str, List[str]]
|
tokens = tokens[:context_length] # Truncate
|
||||||
An input string or a list of input strings to tokenize
|
tokens[-1] = eot_token
|
||||||
context_length : int
|
result[i, :len(tokens)] = torch.tensor(tokens)
|
||||||
The context length to use; all CLIP models use 77 as the context length
|
|
||||||
|
|
||||||
Returns
|
return result
|
||||||
-------
|
|
||||||
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
|
|
||||||
"""
|
|
||||||
if isinstance(texts, str):
|
|
||||||
texts = [texts]
|
|
||||||
|
|
||||||
sot_token = _tokenizer.encoder["<start_of_text>"]
|
|
||||||
eot_token = _tokenizer.encoder["<end_of_text>"]
|
|
||||||
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
|
|
||||||
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
|
||||||
|
|
||||||
for i, tokens in enumerate(all_tokens):
|
|
||||||
if len(tokens) > context_length:
|
|
||||||
tokens = tokens[:context_length] # Truncate
|
|
||||||
tokens[-1] = eot_token
|
|
||||||
result[i, :len(tokens)] = torch.tensor(tokens)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class HFTokenizer:
|
class HFTokenizer:
|
||||||
@@ -5,8 +5,9 @@ from typing import List, Union
|
|||||||
import os
|
import os
|
||||||
from .config import MODEL_PATHS
|
from .config import MODEL_PATHS
|
||||||
|
|
||||||
class PickScore:
|
class PickScore(torch.nn.Module):
|
||||||
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS):
|
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS):
|
||||||
|
super().__init__()
|
||||||
"""Initialize the Selector with a processor and model.
|
"""Initialize the Selector with a processor and model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -53,6 +54,7 @@ class PickScore:
|
|||||||
|
|
||||||
return score.cpu().item()
|
return score.cpu().item()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str, softmax: bool = False) -> List[float]:
|
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str, softmax: bool = False) -> List[float]:
|
||||||
"""Score the images based on the prompt.
|
"""Score the images based on the prompt.
|
||||||
|
|
||||||
@@ -12,7 +12,7 @@ import torch
|
|||||||
|
|
||||||
from .cross_modeling import Cross_model
|
from .cross_modeling import Cross_model
|
||||||
|
|
||||||
import gc
|
import json, os
|
||||||
|
|
||||||
class XCLIPModel(HFCLIPModel):
|
class XCLIPModel(HFCLIPModel):
|
||||||
def __init__(self, config: CLIPConfig):
|
def __init__(self, config: CLIPConfig):
|
||||||
@@ -96,9 +96,15 @@ class ClipModelConfig(BaseModelConfig):
|
|||||||
|
|
||||||
|
|
||||||
class CLIPModel(nn.Module):
|
class CLIPModel(nn.Module):
|
||||||
def __init__(self, ckpt, config_file=False):
    """Build the CLIP backbone and the cross-attention module.

    Args:
        ckpt: Checkpoint location. Passed to ``XCLIPModel.from_pretrained``
            or, when ``config_file`` is truthy, treated as a directory that
            contains ``config.json``.
        config_file: When falsy (the default ``False``, or ``None``) the
            backbone is loaded with ``XCLIPModel.from_pretrained(ckpt)``.
            When truthy, only ``<ckpt>/config.json`` is read and the
            backbone is built from that config via
            ``XCLIPModel._from_config``.
    """
    super().__init__()
    # The previous check was `config_file is None`, which the default value
    # `False` never satisfies, so a plain `CLIPModel(ckpt)` silently skipped
    # `from_pretrained`. Testing truthiness keeps `None` and `True` behaving
    # as before and makes the default consistent with the parameter name.
    if not config_file:
        self.model = XCLIPModel.from_pretrained(ckpt)
    else:
        with open(os.path.join(ckpt, "config.json"), "r", encoding="utf-8") as f:
            config = json.load(f)
        config = CLIPConfig(**config)
        # NOTE(review): this path only instantiates the architecture;
        # presumably the caller loads the weights afterwards - confirm.
        self.model = XCLIPModel._from_config(config)
    self.cross_model = Cross_model(dim=1024, layer_num=4, heads=16)
|
||||||
|
|
||||||
def get_text_features(self, *args, **kwargs):
|
def get_text_features(self, *args, **kwargs):
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
from .aesthetic import *
|
|
||||||
from .clip import *
|
|
||||||
from .config import *
|
|
||||||
from .hps import *
|
|
||||||
from .imagereward import *
|
|
||||||
from .mps import *
|
|
||||||
from .pickscore import *
|
|
||||||
@@ -1,49 +0,0 @@
|
|||||||
# Image Quality Metric
|
|
||||||
|
|
||||||
The image quality assessment functionality has now been integrated into Diffsynth.
|
|
||||||
|
|
||||||
## Usage
|
|
||||||
|
|
||||||
### Step 1: Download pretrained reward models
|
|
||||||
|
|
||||||
```
|
|
||||||
modelscope download --model 'DiffSynth-Studio/QualityMetric_reward_pretrained'
|
|
||||||
```
|
|
||||||
|
|
||||||
The file directory is shown below.
|
|
||||||
|
|
||||||
```
|
|
||||||
DiffSynth-Studio/
|
|
||||||
└── models/
|
|
||||||
└── QualityMetric/
|
|
||||||
├── HPS_v2/
|
|
||||||
│ ├── HPS_v2_compressed.safetensors
|
|
||||||
│ ├── HPS_v2.1_compressed.safetensors
|
|
||||||
└── ...
|
|
||||||
```
|
|
||||||
|
|
||||||
### Step 2: Test image quality metric
|
|
||||||
|
|
||||||
Prompt: "a painting of an ocean with clouds and birds, day time, low depth field effect"
|
|
||||||
|
|
||||||
|1.webp|2.webp|3.webp|4.webp|
|
|
||||||
|-|-|-|-|
|
|
||||||
|||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
```
|
|
||||||
CUDA_VISIBLE_DEVICES=0 python testreward.py
|
|
||||||
```
|
|
||||||
|
|
||||||
### Output:
|
|
||||||
|
|
||||||
```
|
|
||||||
ImageReward: [0.5811904668807983, 0.2745198607444763, -1.4158903360366821, -2.032487154006958]
|
|
||||||
Aesthetic [5.900862693786621, 5.776571273803711, 5.799864292144775, 5.05204963684082]
|
|
||||||
PickScore: [0.20737126469612122, 0.20443597435951233, 0.20660750567913055, 0.19426065683364868]
|
|
||||||
CLIPScore: [0.3894640803337097, 0.3544551134109497, 0.33861416578292847, 0.32878392934799194]
|
|
||||||
HPScorev2: [0.2672519087791443, 0.25495243072509766, 0.24888549745082855, 0.24302822351455688]
|
|
||||||
HPScorev21: [0.2321144938468933, 0.20233657956123352, 0.1978294551372528, 0.19230154156684875]
|
|
||||||
MPS_score: [10.921875, 10.71875, 10.578125, 9.25]
|
|
||||||
```
|
|
||||||
Binary file not shown.
|
Before Width: | Height: | Size: 329 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 250 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 275 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 311 KiB |
@@ -1,80 +0,0 @@
|
|||||||
import os
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
from diffsynth.extensions.QualityMetric.imagereward import ImageRewardScore
|
|
||||||
from diffsynth.extensions.QualityMetric.pickscore import PickScore
|
|
||||||
from diffsynth.extensions.QualityMetric.aesthetic import AestheticScore
|
|
||||||
from diffsynth.extensions.QualityMetric.clip import CLIPScore
|
|
||||||
from diffsynth.extensions.QualityMetric.hps import HPScore_v2
|
|
||||||
from diffsynth.extensions.QualityMetric.mps import MPScore
|
|
||||||
|
|
||||||
# download model from modelscope
|
|
||||||
from modelscope.hub.snapshot_download import snapshot_download
|
|
||||||
|
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
project_root = os.path.abspath(os.path.join(current_dir, '../../'))
|
|
||||||
model_folder = os.path.join(project_root, 'models', 'QualityMetric')
|
|
||||||
|
|
||||||
# download HPS_v2 to your folder
|
|
||||||
# model_id = "DiffSynth-Studio/QualityMetric_reward_pretrained"
|
|
||||||
# downloaded_path = snapshot_download(
|
|
||||||
# model_id,
|
|
||||||
# cache_dir=os.path.join(model_folder, 'HPS_v2'),
|
|
||||||
# allow_patterns=["HPS_v2/*"],
|
|
||||||
# )
|
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
|
|
||||||
def get_model_path(model_folder, model_name):
|
|
||||||
return os.path.join(model_folder, model_name)
|
|
||||||
|
|
||||||
# your model path
|
|
||||||
model_path = {
|
|
||||||
"aesthetic_predictor": get_model_path(model_folder, "aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors"),
|
|
||||||
"open_clip": get_model_path(model_folder, "CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin"),
|
|
||||||
"hpsv2": get_model_path(model_folder, "HPS_v2/HPS_v2_compressed.safetensors"),
|
|
||||||
"hpsv2.1": get_model_path(model_folder, "HPS_v2/HPS_v2.1_compressed.safetensors"),
|
|
||||||
"imagereward": get_model_path(model_folder, "ImageReward/ImageReward.safetensors"),
|
|
||||||
"med_config": get_model_path(model_folder, "ImageReward/med_config.json"),
|
|
||||||
"clip": get_model_path(model_folder, "CLIP-ViT-H-14-laion2B-s32B-b79K"),
|
|
||||||
"clip-large": get_model_path(model_folder, "clip-vit-large-patch14"),
|
|
||||||
"mps": get_model_path(model_folder, "MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors"),
|
|
||||||
"pickscore": get_model_path(model_folder, "PickScore_v1")
|
|
||||||
}
|
|
||||||
|
|
||||||
# load reward models
|
|
||||||
mps_score = MPScore(device,path = model_path)
|
|
||||||
image_reward = ImageRewardScore(device, path = model_path)
|
|
||||||
aesthetic_score = AestheticScore(device, path = model_path)
|
|
||||||
pick_score = PickScore(device, path = model_path)
|
|
||||||
clip_score = CLIPScore(device, path = model_path)
|
|
||||||
hps_score = HPScore_v2(device, path = model_path, model_version = 'v2')
|
|
||||||
hps2_score = HPScore_v2(device, path = model_path, model_version = 'v21')
|
|
||||||
|
|
||||||
prompt = "a painting of an ocean with clouds and birds, day time, low depth field effect"
|
|
||||||
img_prefix = "images"
|
|
||||||
generations = [f"{pic_id}.webp" for pic_id in range(1, 5)]
|
|
||||||
|
|
||||||
img_list = [Image.open(os.path.join(img_prefix, img)) for img in generations]
|
|
||||||
#img_list = [os.path.join(img_prefix, img) for img in generations]
|
|
||||||
|
|
||||||
imre_scores = image_reward.score(img_list, prompt)
|
|
||||||
print("ImageReward:", imre_scores)
|
|
||||||
|
|
||||||
aes_scores = aesthetic_score.score(img_list)
|
|
||||||
print("Aesthetic", aes_scores)
|
|
||||||
|
|
||||||
p_scores = pick_score.score(img_list, prompt)
|
|
||||||
print("PickScore:", p_scores)
|
|
||||||
|
|
||||||
c_scores = clip_score.score(img_list, prompt)
|
|
||||||
print("CLIPScore:", c_scores)
|
|
||||||
|
|
||||||
h_scores = hps_score.score(img_list,prompt)
|
|
||||||
print("HPScorev2:", h_scores)
|
|
||||||
|
|
||||||
h2_scores = hps2_score.score(img_list,prompt)
|
|
||||||
print("HPScorev21:", h2_scores)
|
|
||||||
|
|
||||||
m_scores = mps_score.score(img_list, prompt)
|
|
||||||
print("MPS_score:", m_scores)
|
|
||||||
15
examples/image_quality_metric/README.md
Normal file
15
examples/image_quality_metric/README.md
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# Image Quality Metric
|
||||||
|
|
||||||
|
Image quality assessment has been integrated into DiffSynth-Studio. The following preference models are supported:
|
||||||
|
|
||||||
|
* [ImageReward](https://github.com/THUDM/ImageReward)
|
||||||
|
* [Aesthetic](https://github.com/christophschuhmann/improved-aesthetic-predictor)
|
||||||
|
* [PickScore](https://github.com/yuvalkirstain/pickscore)
|
||||||
|
* [CLIP](https://github.com/openai/CLIP)
|
||||||
|
* [HPSv2](https://github.com/tgxs002/HPSv2)
|
||||||
|
* [HPSv2.1](https://github.com/tgxs002/HPSv2)
|
||||||
|
* [MPS](https://github.com/Kwai-Kolors/MPS)
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
See [`./image_quality_evaluation.py`](./image_quality_evaluation.py) for more details.
|
||||||
23
examples/image_quality_metric/image_quality_evaluation.py
Normal file
23
examples/image_quality_metric/image_quality_evaluation.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
from diffsynth.extensions.ImageQualityMetric import download_preference_model, load_preference_model
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
# Download the example image from the ModelScope dataset into the current
# working directory (local_dir="./").
dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
    allow_file_pattern="data/examples/ImageQualityMetric/image.jpg",
    local_dir="./"
)

# Parameters
prompt = "an orange cat"
# Forward slashes are used on purpose: the previous literal used backslashes,
# which are invalid escape sequences ("\e", "\I", "\i") in a normal string and
# only resolve as a path on Windows. This path mirrors `allow_file_pattern`
# above.
image = Image.open("data/examples/ImageQualityMetric/image.jpg")
device = "cuda"
cache_dir = "./models"

# Run preference models: fetch each model into `cache_dir`, load it on
# `device`, and print its score for the (image, prompt) pair.
for model_name in ["ImageReward", "Aesthetic", "PickScore", "CLIP", "HPSv2", "HPSv2.1", "MPS"]:
    path = download_preference_model(model_name, cache_dir=cache_dir)
    preference_model = load_preference_model(model_name, device=device, path=path)
    print(model_name, preference_model.score(image, prompt))
|
||||||
Reference in New Issue
Block a user