add new quality metric

This commit is contained in:
YunhongLu-ZJU
2025-02-17 14:42:20 +08:00
parent 77d0f4d297
commit 991ba162bd
69 changed files with 88 additions and 1461 deletions

View File

@@ -8,17 +8,48 @@ from diffsynth.extensions.QualityMetric.clip import CLIPScore
from diffsynth.extensions.QualityMetric.hps import HPScore_v2
from diffsynth.extensions.QualityMetric.mps import MPScore
# download model from modelscope
from modelscope.hub.snapshot_download import snapshot_download
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, '../../'))
model_folder = os.path.join(project_root, 'models', 'QualityMetric')
# download HPS_v2 to your folder
# model_id = "DiffSynth-Studio/QualityMetric_reward_pretrained"
# downloaded_path = snapshot_download(
# model_id,
# cache_dir=os.path.join(model_folder, 'HPS_v2'),
# allow_patterns=["HPS_v2/*"],
# )
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def get_model_path(model_folder, model_name):
return os.path.join(model_folder, model_name)
# your model path
model_path = {
"aesthetic_predictor": get_model_path(model_folder, "aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors"),
"open_clip": get_model_path(model_folder, "CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin"),
"hpsv2": get_model_path(model_folder, "HPS_v2/HPS_v2_compressed.safetensors"),
"hpsv2.1": get_model_path(model_folder, "HPS_v2/HPS_v2.1_compressed.safetensors"),
"imagereward": get_model_path(model_folder, "ImageReward/ImageReward.safetensors"),
"med_config": get_model_path(model_folder, "ImageReward/med_config.json"),
"clip": get_model_path(model_folder, "CLIP-ViT-H-14-laion2B-s32B-b79K"),
"clip-large": get_model_path(model_folder, "clip-vit-large-patch14"),
"mps": get_model_path(model_folder, "MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors"),
"pickscore": get_model_path(model_folder, "PickScore_v1")
}
# load reward models
mps_score = MPScore(device)
image_reward = ImageRewardScore(device)
aesthetic_score = AestheticScore(device)
pick_score = PickScore(device)
clip_score = CLIPScore(device)
hps_score = HPScore_v2(device, model_version = 'v2')
hps2_score = HPScore_v2(device, model_version = 'v21')
mps_score = MPScore(device,path = model_path)
image_reward = ImageRewardScore(device, path = model_path)
aesthetic_score = AestheticScore(device, path = model_path)
pick_score = PickScore(device, path = model_path)
clip_score = CLIPScore(device, path = model_path)
hps_score = HPScore_v2(device, path = model_path, model_version = 'v2')
hps2_score = HPScore_v2(device, path = model_path, model_version = 'v21')
prompt = "a painting of an ocean with clouds and birds, day time, low depth field effect"
img_prefix = "images"