add quality metric

This commit is contained in:
YunhongLu-ZJU
2025-02-14 13:59:56 +08:00
parent acda7d891a
commit a834371d50
3 changed files with 14 additions and 14 deletions

View File

@@ -1,11 +1,15 @@
import os
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, '../../../'))
quality_metric_path = os.path.join(project_root, 'models', 'QualityMetric')
# 打印生成的路径
print(quality_metric_path)
def get_model_path(model_name):
return os.path.join(CURRENT_DIR, MODEL_FOLDER, model_name)
return os.path.join(quality_metric_path, model_name)
MODEL_FOLDER = "reward_pretrained"
MODEL_PATHS = {
"aesthetic_predictor": get_model_path("aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors"),