mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-25 10:48:11 +00:00
update preference models
This commit is contained in:
@@ -19,9 +19,8 @@ def default_bert():
|
||||
model_path = os.path.join(project_root, 'models', 'QualityMetric')
|
||||
return os.path.join(model_path, "bert-base-uncased")
|
||||
|
||||
bert_model_path = default_bert()
|
||||
|
||||
def init_tokenizer():
|
||||
def init_tokenizer(bert_model_path):
|
||||
tokenizer = BertTokenizer.from_pretrained(bert_model_path)
|
||||
tokenizer.add_special_tokens({'bos_token':'[DEC]'})
|
||||
tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
|
||||
@@ -20,6 +20,7 @@ class BLIP_Pretrain(nn.Module):
|
||||
embed_dim = 256,
|
||||
queue_size = 57600,
|
||||
momentum = 0.995,
|
||||
bert_model_path = ""
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -31,7 +32,7 @@ class BLIP_Pretrain(nn.Module):
|
||||
|
||||
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0)
|
||||
|
||||
self.tokenizer = init_tokenizer()
|
||||
self.tokenizer = init_tokenizer(bert_model_path)
|
||||
encoder_config = BertConfig.from_json_file(med_config)
|
||||
encoder_config.encoder_width = vision_width
|
||||
self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)
|
||||
@@ -14,7 +14,7 @@ from timm.models.registry import register_model
|
||||
from timm.models.layers import trunc_normal_, DropPath
|
||||
from timm.models.helpers import named_apply, adapt_input_conv
|
||||
|
||||
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
|
||||
# from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
|
||||
|
||||
class Mlp(nn.Module):
|
||||
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
|
||||
@@ -96,9 +96,9 @@ class Block(nn.Module):
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||
|
||||
if use_grad_checkpointing:
|
||||
self.attn = checkpoint_wrapper(self.attn)
|
||||
self.mlp = checkpoint_wrapper(self.mlp)
|
||||
# if use_grad_checkpointing:
|
||||
# self.attn = checkpoint_wrapper(self.attn)
|
||||
# self.mlp = checkpoint_wrapper(self.mlp)
|
||||
|
||||
def forward(self, x, register_hook=False):
|
||||
x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
|
||||
148
diffsynth/extensions/ImageQualityMetric/__init__.py
Normal file
148
diffsynth/extensions/ImageQualityMetric/__init__.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from modelscope import snapshot_download
|
||||
from typing_extensions import Literal, TypeAlias
|
||||
import os
|
||||
from diffsynth.extensions.ImageQualityMetric.aesthetic import AestheticScore
|
||||
from diffsynth.extensions.ImageQualityMetric.imagereward import ImageRewardScore
|
||||
from diffsynth.extensions.ImageQualityMetric.pickscore import PickScore
|
||||
from diffsynth.extensions.ImageQualityMetric.clip import CLIPScore
|
||||
from diffsynth.extensions.ImageQualityMetric.hps import HPScore_v2
|
||||
from diffsynth.extensions.ImageQualityMetric.mps import MPScore
|
||||
|
||||
|
||||
preference_model_id: TypeAlias = Literal[
|
||||
"ImageReward",
|
||||
"Aesthetic",
|
||||
"PickScore",
|
||||
"CLIP",
|
||||
"HPSv2",
|
||||
"HPSv2.1",
|
||||
"MPS",
|
||||
]
|
||||
model_dict = {
|
||||
"ImageReward": {
|
||||
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||
"allow_file_pattern": [
|
||||
"ImageReward/ImageReward.safetensors",
|
||||
"ImageReward/med_config.json",
|
||||
"bert-base-uncased/config.json",
|
||||
"bert-base-uncased/model.safetensors",
|
||||
"bert-base-uncased/tokenizer.json",
|
||||
"bert-base-uncased/tokenizer_config.json",
|
||||
"bert-base-uncased/vocab.txt",
|
||||
],
|
||||
"load_path": {
|
||||
"imagereward": "ImageReward/ImageReward.safetensors",
|
||||
"med_config": "ImageReward/med_config.json",
|
||||
"bert_model_path": "bert-base-uncased",
|
||||
},
|
||||
"model_class": ImageRewardScore
|
||||
},
|
||||
"Aesthetic": {
|
||||
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||
"allow_file_pattern": [
|
||||
"aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors",
|
||||
"clip-vit-large-patch14/config.json",
|
||||
"clip-vit-large-patch14/merges.txt",
|
||||
"clip-vit-large-patch14/model.safetensors",
|
||||
"clip-vit-large-patch14/preprocessor_config.json",
|
||||
"clip-vit-large-patch14/special_tokens_map.json",
|
||||
"clip-vit-large-patch14/tokenizer.json",
|
||||
"clip-vit-large-patch14/tokenizer_config.json",
|
||||
"clip-vit-large-patch14/vocab.json",
|
||||
],
|
||||
"load_path": {
|
||||
"aesthetic_predictor": "aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors",
|
||||
"clip-large": "clip-vit-large-patch14",
|
||||
},
|
||||
"model_class": AestheticScore
|
||||
},
|
||||
"PickScore": {
|
||||
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||
"allow_file_pattern": [
|
||||
"PickScore_v1/*",
|
||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/config.json",
|
||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/merges.txt",
|
||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/preprocessor_config.json",
|
||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/special_tokens_map.json",
|
||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer.json",
|
||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer_config.json",
|
||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/vocab.json",
|
||||
],
|
||||
"load_path": {
|
||||
"pickscore": "PickScore_v1",
|
||||
"clip": "CLIP-ViT-H-14-laion2B-s32B-b79K",
|
||||
},
|
||||
"model_class": PickScore
|
||||
},
|
||||
"CLIP": {
|
||||
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||
"allow_file_pattern": [
|
||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin",
|
||||
"bpe_simple_vocab_16e6.txt.gz",
|
||||
],
|
||||
"load_path": {
|
||||
"open_clip": "CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin",
|
||||
"open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
|
||||
},
|
||||
"model_class": CLIPScore
|
||||
},
|
||||
"HPSv2": {
|
||||
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||
"allow_file_pattern": [
|
||||
"HPS_v2/HPS_v2_compressed.safetensors",
|
||||
"bpe_simple_vocab_16e6.txt.gz",
|
||||
],
|
||||
"load_path": {
|
||||
"hpsv2": "HPS_v2/HPS_v2_compressed.safetensors",
|
||||
"open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
|
||||
},
|
||||
"model_class": HPScore_v2,
|
||||
"extra_kwargs": {"model_version": "v2"}
|
||||
},
|
||||
"HPSv2.1": {
|
||||
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||
"allow_file_pattern": [
|
||||
"HPS_v2/HPS_v2.1_compressed.safetensors",
|
||||
"bpe_simple_vocab_16e6.txt.gz",
|
||||
],
|
||||
"load_path": {
|
||||
"hpsv2.1": "HPS_v2/HPS_v2.1_compressed.safetensors",
|
||||
"open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
|
||||
},
|
||||
"model_class": HPScore_v2,
|
||||
"extra_kwargs": {"model_version": "v21"}
|
||||
},
|
||||
"MPS": {
|
||||
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
||||
"allow_file_pattern": [
|
||||
"MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors",
|
||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/config.json",
|
||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/merges.txt",
|
||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/preprocessor_config.json",
|
||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/special_tokens_map.json",
|
||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer.json",
|
||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer_config.json",
|
||||
"CLIP-ViT-H-14-laion2B-s32B-b79K/vocab.json",
|
||||
],
|
||||
"load_path": {
|
||||
"mps": "MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors",
|
||||
"clip": "CLIP-ViT-H-14-laion2B-s32B-b79K",
|
||||
},
|
||||
"model_class": MPScore
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def download_preference_model(model_name: preference_model_id, cache_dir="models"):
|
||||
metadata = model_dict[model_name]
|
||||
snapshot_download(model_id=metadata["model_id"], allow_file_pattern=metadata["allow_file_pattern"], cache_dir=cache_dir)
|
||||
load_path = metadata["load_path"]
|
||||
load_path = {key: os.path.join(cache_dir, metadata["model_id"], path) for key, path in load_path.items()}
|
||||
return load_path
|
||||
|
||||
|
||||
def load_preference_model(model_name: preference_model_id, device = "cuda", path = None):
|
||||
model_class = model_dict[model_name]["model_class"]
|
||||
extra_kwargs = model_dict[model_name].get("extra_kwargs", {})
|
||||
preference_model = model_class(device=device, path=path, **extra_kwargs)
|
||||
return preference_model
|
||||
@@ -49,13 +49,9 @@ class MLP(torch.nn.Module):
|
||||
return torch.optim.Adam(self.parameters(), lr=1e-3)
|
||||
|
||||
|
||||
class AestheticScore:
|
||||
class AestheticScore(torch.nn.Module):
|
||||
def __init__(self, device: torch.device, path: str = MODEL_PATHS):
|
||||
"""Initialize the Selector with a model and processor.
|
||||
|
||||
Args:
|
||||
device (torch.device): The device to load the model on.
|
||||
"""
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.aes_model_path = path.get("aesthetic_predictor")
|
||||
# Load the MLP model
|
||||
@@ -96,7 +92,8 @@ class AestheticScore:
|
||||
|
||||
return score
|
||||
|
||||
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]]) -> List[float]:
|
||||
@torch.no_grad()
|
||||
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str = "") -> List[float]:
|
||||
"""Score the images based on their aesthetic quality.
|
||||
|
||||
Args:
|
||||
@@ -4,8 +4,9 @@ import torch
|
||||
from .open_clip import create_model_and_transforms, get_tokenizer
|
||||
from .config import MODEL_PATHS
|
||||
|
||||
class CLIPScore:
|
||||
class CLIPScore(torch.nn.Module):
|
||||
def __init__(self, device: torch.device, path: str = MODEL_PATHS):
|
||||
super().__init__()
|
||||
"""Initialize the CLIPScore with a model and tokenizer.
|
||||
|
||||
Args:
|
||||
@@ -36,7 +37,7 @@ class CLIPScore:
|
||||
)
|
||||
|
||||
# Initialize tokenizer
|
||||
self.tokenizer = get_tokenizer("ViT-H-14")
|
||||
self.tokenizer = get_tokenizer("ViT-H-14", path["open_clip_bpe"])
|
||||
self.model = self.model.to(device)
|
||||
self.model.eval()
|
||||
|
||||
@@ -62,37 +63,35 @@ class CLIPScore:
|
||||
|
||||
return clip_score[0].item()
|
||||
|
||||
def score(self, img_path: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
||||
@torch.no_grad()
|
||||
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
||||
"""Score the images based on the prompt.
|
||||
|
||||
Args:
|
||||
img_path (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
||||
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
||||
prompt (str): The prompt text.
|
||||
|
||||
Returns:
|
||||
List[float]: List of CLIP scores for the images.
|
||||
"""
|
||||
try:
|
||||
if isinstance(img_path, (str, Image.Image)):
|
||||
# Single image
|
||||
if isinstance(img_path, str):
|
||||
image = self.preprocess_val(Image.open(img_path)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||
else:
|
||||
image = self.preprocess_val(img_path).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||
return [self._calculate_score(image, prompt)]
|
||||
elif isinstance(img_path, list):
|
||||
# Multiple images
|
||||
scores = []
|
||||
for one_img_path in img_path:
|
||||
if isinstance(one_img_path, str):
|
||||
image = self.preprocess_val(Image.open(one_img_path)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||
elif isinstance(one_img_path, Image.Image):
|
||||
image = self.preprocess_val(one_img_path).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||
else:
|
||||
raise TypeError("The type of parameter img_path is illegal.")
|
||||
scores.append(self._calculate_score(image, prompt))
|
||||
return scores
|
||||
if isinstance(images, (str, Image.Image)):
|
||||
# Single image
|
||||
if isinstance(images, str):
|
||||
image = self.preprocess_val(Image.open(images)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||
else:
|
||||
raise TypeError("The type of parameter img_path is illegal.")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error in scoring images: {e}")
|
||||
image = self.preprocess_val(images).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||
return [self._calculate_score(image, prompt)]
|
||||
elif isinstance(images, list):
|
||||
# Multiple images
|
||||
scores = []
|
||||
for one_images in images:
|
||||
if isinstance(one_images, str):
|
||||
image = self.preprocess_val(Image.open(one_images)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||
elif isinstance(one_images, Image.Image):
|
||||
image = self.preprocess_val(one_images).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||
else:
|
||||
raise TypeError("The type of parameter images is illegal.")
|
||||
scores.append(self._calculate_score(image, prompt))
|
||||
return scores
|
||||
else:
|
||||
raise TypeError("The type of parameter images is illegal.")
|
||||
@@ -6,8 +6,9 @@ from safetensors.torch import load_file
|
||||
import os
|
||||
from .config import MODEL_PATHS
|
||||
|
||||
class HPScore_v2:
|
||||
class HPScore_v2(torch.nn.Module):
|
||||
def __init__(self, device: torch.device, path: str = MODEL_PATHS, model_version: str = "v2"):
|
||||
super().__init__()
|
||||
"""Initialize the Selector with a model and tokenizer.
|
||||
|
||||
Args:
|
||||
@@ -53,7 +54,7 @@ class HPScore_v2:
|
||||
raise ValueError(f"Error loading model weights from {safetensors_path}: {e}")
|
||||
|
||||
# Initialize tokenizer and model
|
||||
self.tokenizer = get_tokenizer("ViT-H-14")
|
||||
self.tokenizer = get_tokenizer("ViT-H-14", path["open_clip_bpe"])
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
self.model = model
|
||||
@@ -80,37 +81,38 @@ class HPScore_v2:
|
||||
|
||||
return hps_score[0].item()
|
||||
|
||||
def score(self, img_path: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
||||
@torch.no_grad()
|
||||
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
||||
"""Score the images based on the prompt.
|
||||
|
||||
Args:
|
||||
img_path (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
||||
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
||||
prompt (str): The prompt text.
|
||||
|
||||
Returns:
|
||||
List[float]: List of HPS scores for the images.
|
||||
"""
|
||||
try:
|
||||
if isinstance(img_path, (str, Image.Image)):
|
||||
if isinstance(images, (str, Image.Image)):
|
||||
# Single image
|
||||
if isinstance(img_path, str):
|
||||
image = self.preprocess_val(Image.open(img_path)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||
if isinstance(images, str):
|
||||
image = self.preprocess_val(Image.open(images)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||
else:
|
||||
image = self.preprocess_val(img_path).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||
image = self.preprocess_val(images).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||
return [self._calculate_score(image, prompt)]
|
||||
elif isinstance(img_path, list):
|
||||
elif isinstance(images, list):
|
||||
# Multiple images
|
||||
scores = []
|
||||
for one_img_path in img_path:
|
||||
if isinstance(one_img_path, str):
|
||||
image = self.preprocess_val(Image.open(one_img_path)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||
elif isinstance(one_img_path, Image.Image):
|
||||
image = self.preprocess_val(one_img_path).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||
for one_images in images:
|
||||
if isinstance(one_images, str):
|
||||
image = self.preprocess_val(Image.open(one_images)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||
elif isinstance(one_images, Image.Image):
|
||||
image = self.preprocess_val(one_images).unsqueeze(0).to(device=self.device, non_blocking=True)
|
||||
else:
|
||||
raise TypeError("The type of parameter img_path is illegal.")
|
||||
raise TypeError("The type of parameter images is illegal.")
|
||||
scores.append(self._calculate_score(image, prompt))
|
||||
return scores
|
||||
else:
|
||||
raise TypeError("The type of parameter img_path is illegal.")
|
||||
raise TypeError("The type of parameter images is illegal.")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error in scoring images: {e}")
|
||||
@@ -52,11 +52,11 @@ class MLP(torch.nn.Module):
|
||||
return self.layers(input)
|
||||
|
||||
class ImageReward(torch.nn.Module):
|
||||
def __init__(self, med_config, device='cpu'):
|
||||
def __init__(self, med_config, device='cpu', bert_model_path=""):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
|
||||
self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config)
|
||||
self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config, bert_model_path=bert_model_path)
|
||||
self.preprocess = _transform(224)
|
||||
self.mlp = MLP(768)
|
||||
|
||||
@@ -88,7 +88,7 @@ class ImageReward(torch.nn.Module):
|
||||
rewards = (rewards - self.mean) / self.std
|
||||
return rewards
|
||||
|
||||
def score(self, prompt: str, images: Union[str, List[str], Image.Image, List[Image.Image]]) -> List[float]:
|
||||
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str = "") -> List[float]:
|
||||
"""Score the images based on the prompt.
|
||||
|
||||
Args:
|
||||
@@ -187,21 +187,18 @@ class ImageReward(torch.nn.Module):
|
||||
return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
|
||||
|
||||
|
||||
class ImageRewardScore:
|
||||
class ImageRewardScore(torch.nn.Module):
|
||||
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS):
|
||||
"""Initialize the Selector with a processor and model.
|
||||
|
||||
Args:
|
||||
device (Union[str, torch.device]): The device to load the model on.
|
||||
"""
|
||||
super().__init__()
|
||||
self.device = device if isinstance(device, torch.device) else torch.device(device)
|
||||
model_path = path.get("imagereward")
|
||||
med_config = path.get("med_config")
|
||||
state_dict = load_file(model_path)
|
||||
self.model = ImageReward(device=self.device, med_config=med_config).to(self.device)
|
||||
self.model = ImageReward(device=self.device, med_config=med_config, bert_model_path=path.get("bert_model_path")).to(self.device)
|
||||
self.model.load_state_dict(state_dict, strict=False)
|
||||
self.model.eval()
|
||||
|
||||
@torch.no_grad()
|
||||
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
||||
"""Score the images based on the prompt.
|
||||
|
||||
@@ -212,4 +209,4 @@ class ImageRewardScore:
|
||||
Returns:
|
||||
List[float]: List of scores for the images.
|
||||
"""
|
||||
return self.model.score(prompt, images)
|
||||
return self.model.score(images, prompt)
|
||||
@@ -24,8 +24,9 @@ import gc
|
||||
import json
|
||||
from .config import MODEL_PATHS
|
||||
|
||||
class MPScore:
|
||||
class MPScore(torch.nn.Module):
|
||||
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS, condition: str = 'overall'):
|
||||
super().__init__()
|
||||
"""Initialize the MPSModel with a processor, tokenizer, and model.
|
||||
|
||||
Args:
|
||||
@@ -35,7 +36,7 @@ class MPScore:
|
||||
processor_name_or_path = path.get("clip")
|
||||
self.image_processor = CLIPImageProcessor.from_pretrained(processor_name_or_path)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True)
|
||||
self.model = clip_model.CLIPModel(processor_name_or_path)
|
||||
self.model = clip_model.CLIPModel(processor_name_or_path, config_file=True)
|
||||
state_dict = load_file(path.get("mps"))
|
||||
self.model.load_state_dict(state_dict, strict=False)
|
||||
self.model.to(device)
|
||||
@@ -94,37 +95,35 @@ class MPScore:
|
||||
|
||||
return image_score[0].cpu().numpy().item()
|
||||
|
||||
def score(self, img_path: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
||||
@torch.no_grad()
|
||||
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
||||
"""Score the images based on the prompt.
|
||||
|
||||
Args:
|
||||
img_path (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
||||
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
||||
prompt (str): The prompt text.
|
||||
|
||||
Returns:
|
||||
List[float]: List of reward scores for the images.
|
||||
"""
|
||||
try:
|
||||
if isinstance(img_path, (str, Image.Image)):
|
||||
# Single image
|
||||
if isinstance(img_path, str):
|
||||
image = self.image_processor(Image.open(img_path), return_tensors="pt")["pixel_values"].to(self.device)
|
||||
else:
|
||||
image = self.image_processor(img_path, return_tensors="pt")["pixel_values"].to(self.device)
|
||||
return [self._calculate_score(image, prompt)]
|
||||
elif isinstance(img_path, list):
|
||||
# Multiple images
|
||||
scores = []
|
||||
for one_img_path in img_path:
|
||||
if isinstance(one_img_path, str):
|
||||
image = self.image_processor(Image.open(one_img_path), return_tensors="pt")["pixel_values"].to(self.device)
|
||||
elif isinstance(one_img_path, Image.Image):
|
||||
image = self.image_processor(one_img_path, return_tensors="pt")["pixel_values"].to(self.device)
|
||||
else:
|
||||
raise TypeError("The type of parameter img_path is illegal.")
|
||||
scores.append(self._calculate_score(image, prompt))
|
||||
return scores
|
||||
if isinstance(images, (str, Image.Image)):
|
||||
# Single image
|
||||
if isinstance(images, str):
|
||||
image = self.image_processor(Image.open(images), return_tensors="pt")["pixel_values"].to(self.device)
|
||||
else:
|
||||
raise TypeError("The type of parameter img_path is illegal.")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error in scoring images: {e}")
|
||||
image = self.image_processor(images, return_tensors="pt")["pixel_values"].to(self.device)
|
||||
return [self._calculate_score(image, prompt)]
|
||||
elif isinstance(images, list):
|
||||
# Multiple images
|
||||
scores = []
|
||||
for one_images in images:
|
||||
if isinstance(one_images, str):
|
||||
image = self.image_processor(Image.open(one_images), return_tensors="pt")["pixel_values"].to(self.device)
|
||||
elif isinstance(one_images, Image.Image):
|
||||
image = self.image_processor(one_images, return_tensors="pt")["pixel_values"].to(self.device)
|
||||
else:
|
||||
raise TypeError("The type of parameter images is illegal.")
|
||||
scores.append(self._calculate_score(image, prompt))
|
||||
return scores
|
||||
else:
|
||||
raise TypeError("The type of parameter images is illegal.")
|
||||
@@ -9,6 +9,6 @@ from .openai import load_openai_model, list_openai_models
|
||||
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
|
||||
get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
|
||||
from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
|
||||
from .tokenizer import SimpleTokenizer, tokenize, decode
|
||||
from .tokenizer import SimpleTokenizer
|
||||
from .transform import image_transform, AugmentationCfg
|
||||
from .utils import freeze_batch_norm_2d
|
||||
@@ -18,7 +18,7 @@ from .loss import ClipLoss, DistillClipLoss, CoCaLoss
|
||||
from .openai import load_openai_model
|
||||
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf
|
||||
from .transform import image_transform, AugmentationCfg
|
||||
from .tokenizer import HFTokenizer, tokenize
|
||||
from .tokenizer import HFTokenizer, SimpleTokenizer
|
||||
|
||||
|
||||
HF_HUB_PREFIX = 'hf-hub:'
|
||||
@@ -74,13 +74,13 @@ def get_model_config(model_name):
|
||||
return None
|
||||
|
||||
|
||||
def get_tokenizer(model_name):
|
||||
def get_tokenizer(model_name, open_clip_bpe_path=None):
|
||||
if model_name.startswith(HF_HUB_PREFIX):
|
||||
tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):])
|
||||
else:
|
||||
config = get_model_config(model_name)
|
||||
tokenizer = HFTokenizer(
|
||||
config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize
|
||||
config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else SimpleTokenizer(open_clip_bpe_path)
|
||||
return tokenizer
|
||||
|
||||
|
||||
@@ -152,43 +152,37 @@ class SimpleTokenizer(object):
|
||||
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
||||
return text
|
||||
|
||||
def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
|
||||
"""
|
||||
Returns the tokenized representation of given input string(s)
|
||||
|
||||
_tokenizer = SimpleTokenizer()
|
||||
Parameters
|
||||
----------
|
||||
texts : Union[str, List[str]]
|
||||
An input string or a list of input strings to tokenize
|
||||
context_length : int
|
||||
The context length to use; all CLIP models use 77 as the context length
|
||||
|
||||
def decode(output_ids: torch.Tensor):
|
||||
output_ids = output_ids.cpu().numpy()
|
||||
return _tokenizer.decode(output_ids)
|
||||
Returns
|
||||
-------
|
||||
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
|
||||
"""
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
|
||||
def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
|
||||
"""
|
||||
Returns the tokenized representation of given input string(s)
|
||||
sot_token = self.encoder["<start_of_text>"]
|
||||
eot_token = self.encoder["<end_of_text>"]
|
||||
all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
|
||||
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
texts : Union[str, List[str]]
|
||||
An input string or a list of input strings to tokenize
|
||||
context_length : int
|
||||
The context length to use; all CLIP models use 77 as the context length
|
||||
for i, tokens in enumerate(all_tokens):
|
||||
if len(tokens) > context_length:
|
||||
tokens = tokens[:context_length] # Truncate
|
||||
tokens[-1] = eot_token
|
||||
result[i, :len(tokens)] = torch.tensor(tokens)
|
||||
|
||||
Returns
|
||||
-------
|
||||
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
|
||||
"""
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
return result
|
||||
|
||||
sot_token = _tokenizer.encoder["<start_of_text>"]
|
||||
eot_token = _tokenizer.encoder["<end_of_text>"]
|
||||
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
|
||||
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
||||
|
||||
for i, tokens in enumerate(all_tokens):
|
||||
if len(tokens) > context_length:
|
||||
tokens = tokens[:context_length] # Truncate
|
||||
tokens[-1] = eot_token
|
||||
result[i, :len(tokens)] = torch.tensor(tokens)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class HFTokenizer:
|
||||
@@ -5,8 +5,9 @@ from typing import List, Union
|
||||
import os
|
||||
from .config import MODEL_PATHS
|
||||
|
||||
class PickScore:
|
||||
class PickScore(torch.nn.Module):
|
||||
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS):
|
||||
super().__init__()
|
||||
"""Initialize the Selector with a processor and model.
|
||||
|
||||
Args:
|
||||
@@ -53,6 +54,7 @@ class PickScore:
|
||||
|
||||
return score.cpu().item()
|
||||
|
||||
@torch.no_grad()
|
||||
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str, softmax: bool = False) -> List[float]:
|
||||
"""Score the images based on the prompt.
|
||||
|
||||
@@ -12,7 +12,7 @@ import torch
|
||||
|
||||
from .cross_modeling import Cross_model
|
||||
|
||||
import gc
|
||||
import json, os
|
||||
|
||||
class XCLIPModel(HFCLIPModel):
|
||||
def __init__(self, config: CLIPConfig):
|
||||
@@ -96,9 +96,15 @@ class ClipModelConfig(BaseModelConfig):
|
||||
|
||||
|
||||
class CLIPModel(nn.Module):
|
||||
def __init__(self, ckpt):
|
||||
def __init__(self, ckpt, config_file=False):
|
||||
super().__init__()
|
||||
self.model = XCLIPModel.from_pretrained(ckpt)
|
||||
if config_file is None:
|
||||
self.model = XCLIPModel.from_pretrained(ckpt)
|
||||
else:
|
||||
with open(os.path.join(ckpt, "config.json"), "r", encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
config = CLIPConfig(**config)
|
||||
self.model = XCLIPModel._from_config(config)
|
||||
self.cross_model = Cross_model(dim=1024, layer_num=4, heads=16)
|
||||
|
||||
def get_text_features(self, *args, **kwargs):
|
||||
@@ -1,7 +0,0 @@
|
||||
from .aesthetic import *
|
||||
from .clip import *
|
||||
from .config import *
|
||||
from .hps import *
|
||||
from .imagereward import *
|
||||
from .mps import *
|
||||
from .pickscore import *
|
||||
@@ -1,49 +0,0 @@
|
||||
# Image Quality Metric
|
||||
|
||||
The image quality assessment functionality has now been integrated into Diffsynth.
|
||||
|
||||
## Usage
|
||||
|
||||
### Step 1: Download pretrained reward models
|
||||
|
||||
```
|
||||
modelscope download --model 'DiffSynth-Studio/QualityMetric_reward_pretrained'
|
||||
```
|
||||
|
||||
The file directory is shown below.
|
||||
|
||||
```
|
||||
DiffSynth-Studio/
|
||||
└── models/
|
||||
└── QualityMetric/
|
||||
├── HPS_v2/
|
||||
│ ├── HPS_v2_compressed.safetensors
|
||||
│ ├── HPS_v2.1_compressed.safetensors
|
||||
└── ...
|
||||
```
|
||||
|
||||
### Step 2: Test image quality metric
|
||||
|
||||
Prompt: "a painting of an ocean with clouds and birds, day time, low depth field effect"
|
||||
|
||||
|1.webp|2.webp|3.webp|4.webp|
|
||||
|-|-|-|-|
|
||||
|||||
|
||||
|
||||
|
||||
|
||||
```
|
||||
CUDA_VISIBLE_DEVICES=0 python testreward.py
|
||||
```
|
||||
|
||||
### Output:
|
||||
|
||||
```
|
||||
ImageReward: [0.5811904668807983, 0.2745198607444763, -1.4158903360366821, -2.032487154006958]
|
||||
Aesthetic [5.900862693786621, 5.776571273803711, 5.799864292144775, 5.05204963684082]
|
||||
PickScore: [0.20737126469612122, 0.20443597435951233, 0.20660750567913055, 0.19426065683364868]
|
||||
CLIPScore: [0.3894640803337097, 0.3544551134109497, 0.33861416578292847, 0.32878392934799194]
|
||||
HPScorev2: [0.2672519087791443, 0.25495243072509766, 0.24888549745082855, 0.24302822351455688]
|
||||
HPScorev21: [0.2321144938468933, 0.20233657956123352, 0.1978294551372528, 0.19230154156684875]
|
||||
MPS_score: [10.921875, 10.71875, 10.578125, 9.25]
|
||||
```
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 329 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 250 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 275 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 311 KiB |
@@ -1,80 +0,0 @@
|
||||
import os
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.extensions.QualityMetric.imagereward import ImageRewardScore
|
||||
from diffsynth.extensions.QualityMetric.pickscore import PickScore
|
||||
from diffsynth.extensions.QualityMetric.aesthetic import AestheticScore
|
||||
from diffsynth.extensions.QualityMetric.clip import CLIPScore
|
||||
from diffsynth.extensions.QualityMetric.hps import HPScore_v2
|
||||
from diffsynth.extensions.QualityMetric.mps import MPScore
|
||||
|
||||
# download model from modelscope
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
|
||||
# Locate the repository root relative to this script, then the folder that
# holds all locally downloaded QualityMetric checkpoints.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, os.pardir, os.pardir))
model_folder = os.path.join(project_root, 'models', 'QualityMetric')

# Optional: pull the HPS_v2 weights from ModelScope into `model_folder`.
# model_id = "DiffSynth-Studio/QualityMetric_reward_pretrained"
# downloaded_path = snapshot_download(
#     model_id,
#     cache_dir=os.path.join(model_folder, 'HPS_v2'),
#     allow_patterns=["HPS_v2/*"],
# )

# Prefer a visible GPU; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
def get_model_path(model_folder, model_name):
    """Return the absolute/relative location of one checkpoint inside the shared model folder."""
    full_path = os.path.join(model_folder, model_name)
    return full_path
|
||||
|
||||
# your model path
|
||||
# Relative checkpoint locations for every supported preference metric,
# keyed by the name each scorer looks up in its `path` mapping.
_checkpoint_files = {
    "aesthetic_predictor": "aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors",
    "open_clip": "CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin",
    "hpsv2": "HPS_v2/HPS_v2_compressed.safetensors",
    "hpsv2.1": "HPS_v2/HPS_v2.1_compressed.safetensors",
    "imagereward": "ImageReward/ImageReward.safetensors",
    "med_config": "ImageReward/med_config.json",
    "clip": "CLIP-ViT-H-14-laion2B-s32B-b79K",
    "clip-large": "clip-vit-large-patch14",
    "mps": "MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors",
    "pickscore": "PickScore_v1",
}
# Resolve every entry against the local model folder.
model_path = {name: get_model_path(model_folder, rel) for name, rel in _checkpoint_files.items()}
|
||||
|
||||
# load reward models
|
||||
# Instantiate each preference/reward model once on the chosen device; every
# constructor reads its weights from the locations declared in `model_path`.
mps_model = MPScore(device, path=model_path)
reward_model = ImageRewardScore(device, path=model_path)
aesthetic_model = AestheticScore(device, path=model_path)
pick_model = PickScore(device, path=model_path)
clip_model = CLIPScore(device, path=model_path)
hps_v2_model = HPScore_v2(device, path=model_path, model_version='v2')
hps_v21_model = HPScore_v2(device, path=model_path, model_version='v21')

# Four candidate generations (1.webp .. 4.webp) for the same prompt.
prompt = "a painting of an ocean with clouds and birds, day time, low depth field effect"
image_dir = "images"
images = [Image.open(os.path.join(image_dir, f"{idx}.webp")) for idx in range(1, 5)]

# Rank the candidates with every metric; each scorer returns one value per image.
print("ImageReward:", reward_model.score(images, prompt))
print("Aesthetic", aesthetic_model.score(images))
print("PickScore:", pick_model.score(images, prompt))
print("CLIPScore:", clip_model.score(images, prompt))
print("HPScorev2:", hps_v2_model.score(images, prompt))
print("HPScorev21:", hps_v21_model.score(images, prompt))
print("MPS_score:", mps_model.score(images, prompt))
|
||||
15
examples/image_quality_metric/README.md
Normal file
15
examples/image_quality_metric/README.md
Normal file
@@ -0,0 +1,15 @@
|
||||
# Image Quality Metric
|
||||
|
||||
The image quality assessment functionality has been integrated into DiffSynth-Studio. We support the following models:
|
||||
|
||||
* [ImageReward](https://github.com/THUDM/ImageReward)
|
||||
* [Aesthetic](https://github.com/christophschuhmann/improved-aesthetic-predictor)
|
||||
* [PickScore](https://github.com/yuvalkirstain/pickscore)
|
||||
* [CLIP](https://github.com/openai/CLIP)
|
||||
* [HPSv2](https://github.com/tgxs002/HPSv2)
|
||||
* [HPSv2.1](https://github.com/tgxs002/HPSv2)
|
||||
* [MPS](https://github.com/Kwai-Kolors/MPS)
|
||||
|
||||
## Usage
|
||||
|
||||
See [`./image_quality_evaluation.py`](./image_quality_evaluation.py) for more details.
|
||||
23
examples/image_quality_metric/image_quality_evaluation.py
Normal file
23
examples/image_quality_metric/image_quality_evaluation.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from diffsynth.extensions.ImageQualityMetric import download_preference_model, load_preference_model
from modelscope import dataset_snapshot_download
from PIL import Image


# Download the example image used for scoring.
dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
    allow_file_pattern="data/examples/ImageQualityMetric/image.jpg",
    local_dir="./"
)

# Parameters
prompt = "an orange cat"
# Fix: the original used a backslash path ("data\examples\..."), which only
# resolves on Windows and does not match the forward-slash pattern the file
# was downloaded under; forward slashes work on every platform.
image = Image.open("data/examples/ImageQualityMetric/image.jpg")
device = "cuda"
cache_dir = "./models"

# Run every supported preference model on the (image, prompt) pair and print
# its score; weights are fetched into `cache_dir` on first use.
for model_name in ["ImageReward", "Aesthetic", "PickScore", "CLIP", "HPSv2", "HPSv2.1", "MPS"]:
    path = download_preference_model(model_name, cache_dir=cache_dir)
    preference_model = load_preference_model(model_name, device=device, path=path)
    print(model_name, preference_model.score(image, prompt))
|
||||
Reference in New Issue
Block a user