mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
add new quality metric
This commit is contained in:
@@ -13,8 +13,16 @@ from transformers import BertTokenizer
|
||||
from .vit import VisionTransformer, interpolate_pos_embed
|
||||
|
||||
|
||||
def default_bert():
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
project_root = os.path.abspath(os.path.join(current_dir, '../../../../'))
|
||||
model_path = os.path.join(project_root, 'models', 'QualityMetric')
|
||||
return os.path.join(model_path, "bert-base-uncased")
|
||||
|
||||
bert_model_path = default_bert()
|
||||
|
||||
def init_tokenizer():
|
||||
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||
tokenizer = BertTokenizer.from_pretrained(bert_model_path)
|
||||
tokenizer.add_special_tokens({'bos_token':'[DEC]'})
|
||||
tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
|
||||
tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
|
||||
|
||||
@@ -50,31 +50,30 @@ class MLP(torch.nn.Module):
|
||||
|
||||
|
||||
class AestheticScore:
|
||||
def __init__(self, device: torch.device, model_path: str = MODEL_PATHS.get("aesthetic_predictor")):
|
||||
def __init__(self, device: torch.device, path: str = MODEL_PATHS):
|
||||
"""Initialize the Selector with a model and processor.
|
||||
|
||||
Args:
|
||||
device (torch.device): The device to load the model on.
|
||||
model_path (str): Path to the model weights file.
|
||||
"""
|
||||
self.device = device
|
||||
|
||||
self.aes_model_path = path.get("aesthetic_predictor")
|
||||
# Load the MLP model
|
||||
self.model = MLP(768)
|
||||
try:
|
||||
if model_path.endswith(".safetensors"):
|
||||
state_dict = load_file(model_path)
|
||||
if self.aes_model_path.endswith(".safetensors"):
|
||||
state_dict = load_file(self.aes_model_path)
|
||||
else:
|
||||
state_dict = torch.load(model_path)
|
||||
state_dict = torch.load(self.aes_model_path)
|
||||
self.model.load_state_dict(state_dict)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error loading model weights from {model_path}: {e}")
|
||||
raise ValueError(f"Error loading model weights from {self.aes_model_path}: {e}")
|
||||
|
||||
self.model.to(device)
|
||||
self.model.eval()
|
||||
|
||||
# Load the CLIP model and processor
|
||||
clip_model_name = MODEL_PATHS.get('clip-large')
|
||||
clip_model_name = path.get('clip-large')
|
||||
self.model2 = AutoModel.from_pretrained(clip_model_name).eval().to(device)
|
||||
self.processor = AutoProcessor.from_pretrained(clip_model_name)
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from .open_clip import create_model_and_transforms, get_tokenizer
|
||||
from .config import MODEL_PATHS
|
||||
|
||||
class CLIPScore:
|
||||
def __init__(self, device: torch.device):
|
||||
def __init__(self, device: torch.device, path: str = MODEL_PATHS):
|
||||
"""Initialize the CLIPScore with a model and tokenizer.
|
||||
|
||||
Args:
|
||||
@@ -17,7 +17,7 @@ class CLIPScore:
|
||||
self.model, _, self.preprocess_val = create_model_and_transforms(
|
||||
"ViT-H-14",
|
||||
# "laion2B-s32B-b79K",
|
||||
pretrained=MODEL_PATHS.get("open_clip"),
|
||||
pretrained=path.get("open_clip"),
|
||||
precision="amp",
|
||||
device=device,
|
||||
jit=False,
|
||||
|
||||
@@ -2,11 +2,11 @@ import os
|
||||
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
project_root = os.path.abspath(os.path.join(current_dir, '../../../'))
|
||||
quality_metric_path = os.path.join(project_root, 'models', 'QualityMetric')
|
||||
model_path = os.path.join(project_root, 'models', 'QualityMetric')
|
||||
|
||||
|
||||
def get_model_path(model_name):
|
||||
return os.path.join(quality_metric_path, model_name)
|
||||
return os.path.join(model_path, model_name)
|
||||
|
||||
|
||||
MODEL_PATHS = {
|
||||
@@ -18,6 +18,6 @@ MODEL_PATHS = {
|
||||
"med_config": get_model_path("ImageReward/med_config.json"),
|
||||
"clip": get_model_path("CLIP-ViT-H-14-laion2B-s32B-b79K"),
|
||||
"clip-large": get_model_path("clip-vit-large-patch14"),
|
||||
"mps": get_model_path("MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.pth"),
|
||||
"mps": get_model_path("MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors"),
|
||||
"pickscore": get_model_path("PickScore_v1")
|
||||
}
|
||||
@@ -7,7 +7,7 @@ import os
|
||||
from .config import MODEL_PATHS
|
||||
|
||||
class HPScore_v2:
|
||||
def __init__(self, device: torch.device, model_version: str = "v2"):
|
||||
def __init__(self, device: torch.device, path: str = MODEL_PATHS, model_version: str = "v2"):
|
||||
"""Initialize the Selector with a model and tokenizer.
|
||||
|
||||
Args:
|
||||
@@ -17,9 +17,9 @@ class HPScore_v2:
|
||||
self.device = device
|
||||
|
||||
if model_version == "v2":
|
||||
safetensors_path = MODEL_PATHS.get("hpsv2")
|
||||
safetensors_path = path.get("hpsv2")
|
||||
elif model_version == "v21":
|
||||
safetensors_path = MODEL_PATHS.get("hpsv2.1")
|
||||
safetensors_path = path.get("hpsv2.1")
|
||||
else:
|
||||
raise ValueError(f"Unsupported model version: {model_version}. Choose 'v2' or 'v21'.")
|
||||
|
||||
@@ -27,7 +27,7 @@ class HPScore_v2:
|
||||
model, _, self.preprocess_val = create_model_and_transforms(
|
||||
"ViT-H-14",
|
||||
# "laion2B-s32B-b79K",
|
||||
pretrained=MODEL_PATHS.get("open_clip"),
|
||||
pretrained=path.get("open_clip"),
|
||||
precision="amp",
|
||||
device=device,
|
||||
jit=False,
|
||||
|
||||
@@ -188,15 +188,15 @@ class ImageReward(torch.nn.Module):
|
||||
|
||||
|
||||
class ImageRewardScore:
|
||||
def __init__(self, device: Union[str, torch.device]):
|
||||
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS):
|
||||
"""Initialize the Selector with a processor and model.
|
||||
|
||||
Args:
|
||||
device (Union[str, torch.device]): The device to load the model on.
|
||||
"""
|
||||
self.device = device if isinstance(device, torch.device) else torch.device(device)
|
||||
model_path = MODEL_PATHS.get("imagereward")
|
||||
med_config = MODEL_PATHS.get("med_config")
|
||||
model_path = path.get("imagereward")
|
||||
med_config = path.get("med_config")
|
||||
state_dict = load_file(model_path)
|
||||
self.model = ImageReward(device=self.device, med_config=med_config).to(self.device)
|
||||
self.model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
@@ -4,10 +4,10 @@ from PIL import Image
|
||||
from io import BytesIO
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPFeatureExtractor, CLIPImageProcessor
|
||||
|
||||
from transformers import CLIPConfig
|
||||
from dataclasses import dataclass
|
||||
from transformers import CLIPModel as HFCLIPModel
|
||||
|
||||
from safetensors.torch import load_file
|
||||
from torch import nn, einsum
|
||||
|
||||
from .trainer.models.base_model import BaseModelConfig
|
||||
@@ -18,26 +18,27 @@ from typing import Any, Optional, Tuple, Union, List
|
||||
import torch
|
||||
|
||||
from .trainer.models.cross_modeling import Cross_model
|
||||
from .trainer.models import clip_model
|
||||
import torch.nn.functional as F
|
||||
|
||||
import gc
|
||||
import json
|
||||
from .config import MODEL_PATHS
|
||||
|
||||
class MPScore:
|
||||
def __init__(self, device: Union[str, torch.device], condition: str = 'overall'):
|
||||
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS, condition: str = 'overall'):
|
||||
"""Initialize the MPSModel with a processor, tokenizer, and model.
|
||||
|
||||
Args:
|
||||
device (Union[str, torch.device]): The device to load the model on.
|
||||
"""
|
||||
self.device = device
|
||||
processor_name_or_path = MODEL_PATHS.get("clip")
|
||||
processor_name_or_path = path.get("clip")
|
||||
self.image_processor = CLIPImageProcessor.from_pretrained(processor_name_or_path)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True)
|
||||
|
||||
model_ckpt_path = MODEL_PATHS.get("mps")
|
||||
self.model = torch.load(model_ckpt_path).eval().to(device)
|
||||
self.model = clip_model.CLIPModel(processor_name_or_path)
|
||||
state_dict = load_file(path.get("mps"))
|
||||
self.model.load_state_dict(state_dict, strict=False)
|
||||
self.model.to(device)
|
||||
self.condition = condition
|
||||
|
||||
def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
|
||||
|
||||
Binary file not shown.
@@ -1,22 +0,0 @@
|
||||
{
|
||||
"embed_dim": 512,
|
||||
"quick_gelu": true,
|
||||
"vision_cfg": {
|
||||
"image_size": 224,
|
||||
"layers": [
|
||||
3,
|
||||
4,
|
||||
23,
|
||||
3
|
||||
],
|
||||
"width": 64,
|
||||
"patch_size": null
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 512,
|
||||
"heads": 8,
|
||||
"layers": 12
|
||||
}
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
{
|
||||
"embed_dim": 512,
|
||||
"vision_cfg": {
|
||||
"image_size": 224,
|
||||
"layers": [
|
||||
3,
|
||||
4,
|
||||
23,
|
||||
3
|
||||
],
|
||||
"width": 64,
|
||||
"patch_size": null
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 512,
|
||||
"heads": 8,
|
||||
"layers": 12
|
||||
}
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
{
|
||||
"embed_dim": 1024,
|
||||
"quick_gelu": true,
|
||||
"vision_cfg": {
|
||||
"image_size": 224,
|
||||
"layers": [
|
||||
3,
|
||||
4,
|
||||
6,
|
||||
3
|
||||
],
|
||||
"width": 64,
|
||||
"patch_size": null
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 512,
|
||||
"heads": 8,
|
||||
"layers": 12
|
||||
}
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
{
|
||||
"embed_dim": 1024,
|
||||
"vision_cfg": {
|
||||
"image_size": 224,
|
||||
"layers": [
|
||||
3,
|
||||
4,
|
||||
6,
|
||||
3
|
||||
],
|
||||
"width": 64,
|
||||
"patch_size": null
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 512,
|
||||
"heads": 8,
|
||||
"layers": 12
|
||||
}
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
{
|
||||
"embed_dim": 768,
|
||||
"vision_cfg": {
|
||||
"image_size": 384,
|
||||
"layers": [
|
||||
6,
|
||||
8,
|
||||
18,
|
||||
8
|
||||
],
|
||||
"width": 96,
|
||||
"patch_size": null
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 768,
|
||||
"heads": 12,
|
||||
"layers": 12
|
||||
}
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
{
|
||||
"embed_dim": 640,
|
||||
"vision_cfg": {
|
||||
"image_size": 288,
|
||||
"layers": [
|
||||
4,
|
||||
6,
|
||||
10,
|
||||
6
|
||||
],
|
||||
"width": 80,
|
||||
"patch_size": null
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 640,
|
||||
"heads": 10,
|
||||
"layers": 12
|
||||
}
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
{
|
||||
"embed_dim": 1024,
|
||||
"vision_cfg": {
|
||||
"image_size": 448,
|
||||
"layers": [
|
||||
3,
|
||||
15,
|
||||
36,
|
||||
10
|
||||
],
|
||||
"width": 128,
|
||||
"patch_size": null
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 1024,
|
||||
"heads": 16,
|
||||
"layers": 12
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
{
|
||||
"embed_dim": 640,
|
||||
"vision_cfg": {
|
||||
"image_size": 240,
|
||||
"layers": 12,
|
||||
"width": 896,
|
||||
"patch_size": 16
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 640,
|
||||
"heads": 10,
|
||||
"layers": 12
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
{
|
||||
"embed_dim": 640,
|
||||
"vision_cfg": {
|
||||
"image_size": 224,
|
||||
"layers": 12,
|
||||
"width": 896,
|
||||
"patch_size": 16
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 640,
|
||||
"heads": 10,
|
||||
"layers": 12
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
{
|
||||
"embed_dim": 512,
|
||||
"vision_cfg": {
|
||||
"image_size": 224,
|
||||
"layers": 12,
|
||||
"width": 768,
|
||||
"patch_size": 16
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 512,
|
||||
"heads": 8,
|
||||
"layers": 12
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
{
|
||||
"embed_dim": 640,
|
||||
"vision_cfg": {
|
||||
"image_size": 256,
|
||||
"layers": 12,
|
||||
"width": 896,
|
||||
"patch_size": 32
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 640,
|
||||
"heads": 10,
|
||||
"layers": 12
|
||||
}
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
{
|
||||
"embed_dim": 512,
|
||||
"quick_gelu": true,
|
||||
"vision_cfg": {
|
||||
"image_size": 224,
|
||||
"layers": 12,
|
||||
"width": 768,
|
||||
"patch_size": 32
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 512,
|
||||
"heads": 8,
|
||||
"layers": 12
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
{
|
||||
"embed_dim": 512,
|
||||
"vision_cfg": {
|
||||
"image_size": 224,
|
||||
"layers": 12,
|
||||
"width": 768,
|
||||
"patch_size": 32
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 512,
|
||||
"heads": 8,
|
||||
"layers": 12
|
||||
}
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
{
|
||||
"embed_dim": 1024,
|
||||
"vision_cfg": {
|
||||
"image_size": 224,
|
||||
"layers": 32,
|
||||
"width": 1280,
|
||||
"head_width": 80,
|
||||
"patch_size": 16
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 1024,
|
||||
"heads": 16,
|
||||
"layers": 24
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
{
|
||||
"embed_dim": 768,
|
||||
"vision_cfg": {
|
||||
"image_size": 280,
|
||||
"layers": 24,
|
||||
"width": 1024,
|
||||
"patch_size": 14
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 768,
|
||||
"heads": 12,
|
||||
"layers": 12
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
{
|
||||
"embed_dim": 768,
|
||||
"vision_cfg": {
|
||||
"image_size": 336,
|
||||
"layers": 24,
|
||||
"width": 1024,
|
||||
"patch_size": 14
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 768,
|
||||
"heads": 12,
|
||||
"layers": 12
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
{
|
||||
"embed_dim": 768,
|
||||
"vision_cfg": {
|
||||
"image_size": 224,
|
||||
"layers": 24,
|
||||
"width": 1024,
|
||||
"patch_size": 14
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 768,
|
||||
"heads": 12,
|
||||
"layers": 12
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
{
|
||||
"embed_dim": 768,
|
||||
"vision_cfg": {
|
||||
"image_size": 320,
|
||||
"layers": 24,
|
||||
"width": 1024,
|
||||
"patch_size": 16
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 768,
|
||||
"heads": 12,
|
||||
"layers": 12
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
{
|
||||
"embed_dim": 768,
|
||||
"vision_cfg": {
|
||||
"image_size": 224,
|
||||
"layers": 24,
|
||||
"width": 1024,
|
||||
"patch_size": 16
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 768,
|
||||
"heads": 12,
|
||||
"layers": 12
|
||||
}
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
{
|
||||
"embed_dim": 384,
|
||||
"vision_cfg": {
|
||||
"image_size": 224,
|
||||
"layers": 12,
|
||||
"width": 512,
|
||||
"patch_size": 16,
|
||||
"ls_init_value": 1e-4
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 384,
|
||||
"heads": 6,
|
||||
"layers": 12
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
{
|
||||
"embed_dim": 512,
|
||||
"vision_cfg": {
|
||||
"image_size": 224,
|
||||
"layers": 12,
|
||||
"width": 512,
|
||||
"patch_size": 16
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 512,
|
||||
"heads": 8,
|
||||
"layers": 12
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
{
|
||||
"embed_dim": 384,
|
||||
"vision_cfg": {
|
||||
"image_size": 224,
|
||||
"layers": 12,
|
||||
"width": 512,
|
||||
"patch_size": 32
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 384,
|
||||
"heads": 6,
|
||||
"layers": 12
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
{
|
||||
"embed_dim": 512,
|
||||
"vision_cfg": {
|
||||
"image_size": 224,
|
||||
"layers": 12,
|
||||
"width": 512,
|
||||
"patch_size": 32
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 512,
|
||||
"heads": 8,
|
||||
"layers": 12
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
{
|
||||
"embed_dim": 256,
|
||||
"vision_cfg": {
|
||||
"image_size": 224,
|
||||
"layers": 12,
|
||||
"width": 384,
|
||||
"patch_size": 16
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 256,
|
||||
"heads": 4,
|
||||
"layers": 10
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
{
|
||||
"embed_dim": 384,
|
||||
"vision_cfg": {
|
||||
"image_size": 224,
|
||||
"layers": 12,
|
||||
"width": 384,
|
||||
"patch_size": 16
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 384,
|
||||
"heads": 6,
|
||||
"layers": 12
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
{
|
||||
"embed_dim": 256,
|
||||
"vision_cfg": {
|
||||
"image_size": 224,
|
||||
"layers": 12,
|
||||
"width": 384,
|
||||
"patch_size": 32
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 256,
|
||||
"heads": 4,
|
||||
"layers": 10
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
{
|
||||
"embed_dim": 384,
|
||||
"vision_cfg": {
|
||||
"image_size": 224,
|
||||
"layers": 12,
|
||||
"width": 384,
|
||||
"patch_size": 32
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 384,
|
||||
"heads": 6,
|
||||
"layers": 12
|
||||
}
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
{
|
||||
"embed_dim": 1280,
|
||||
"vision_cfg": {
|
||||
"image_size": 224,
|
||||
"layers": 48,
|
||||
"width": 1664,
|
||||
"head_width": 104,
|
||||
"mlp_ratio": 4.9231,
|
||||
"patch_size": 14
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 1280,
|
||||
"heads": 20,
|
||||
"layers": 32
|
||||
}
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
{
|
||||
"embed_dim": 1280,
|
||||
"vision_cfg": {
|
||||
"image_size": 224,
|
||||
"layers": 56,
|
||||
"width": 1792,
|
||||
"head_width": 112,
|
||||
"mlp_ratio": 8.5715,
|
||||
"patch_size": 14
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 1280,
|
||||
"heads": 20,
|
||||
"layers": 36
|
||||
}
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
{
|
||||
"embed_dim": 1024,
|
||||
"vision_cfg": {
|
||||
"image_size": 224,
|
||||
"layers": 40,
|
||||
"width": 1408,
|
||||
"head_width": 88,
|
||||
"mlp_ratio": 4.3637,
|
||||
"patch_size": 14
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 1024,
|
||||
"heads": 16,
|
||||
"layers": 24
|
||||
}
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
{
|
||||
"embed_dim": 512,
|
||||
"vision_cfg": {
|
||||
"image_size": 224,
|
||||
"layers": 12,
|
||||
"width": 768,
|
||||
"patch_size": 32,
|
||||
"attentional_pool": true,
|
||||
"attn_pooler_heads": 8,
|
||||
"output_tokens": true
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 76,
|
||||
"vocab_size": 49408,
|
||||
"width": 512,
|
||||
"heads": 8,
|
||||
"layers": 12,
|
||||
"embed_cls": true,
|
||||
"output_tokens": true
|
||||
},
|
||||
"multimodal_cfg": {
|
||||
"context_length": 76,
|
||||
"vocab_size": 49408,
|
||||
"width": 512,
|
||||
"heads": 8,
|
||||
"layers": 12,
|
||||
"attn_pooler_heads": 8
|
||||
},
|
||||
"custom_text": true
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
{
|
||||
"embed_dim": 768,
|
||||
"vision_cfg": {
|
||||
"image_size": 224,
|
||||
"layers": 24,
|
||||
"width": 1024,
|
||||
"patch_size": 14,
|
||||
"attentional_pool": true,
|
||||
"attn_pooler_heads": 8,
|
||||
"output_tokens": true
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 76,
|
||||
"vocab_size": 49408,
|
||||
"width": 768,
|
||||
"heads": 12,
|
||||
"layers": 12,
|
||||
"embed_cls": true,
|
||||
"output_tokens": true
|
||||
},
|
||||
"multimodal_cfg": {
|
||||
"context_length": 76,
|
||||
"vocab_size": 49408,
|
||||
"width": 768,
|
||||
"heads": 12,
|
||||
"layers": 12,
|
||||
"attn_pooler_heads": 12
|
||||
},
|
||||
"custom_text": true
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
{
|
||||
"embed_dim": 512,
|
||||
"multimodal_cfg": {
|
||||
"width": 768,
|
||||
"context_length": 76,
|
||||
"vocab_size": 64000,
|
||||
"mlp_ratio": 4,
|
||||
"layers": 12,
|
||||
"dim_head": 64,
|
||||
"heads": 12,
|
||||
"n_queries": 256,
|
||||
"attn_pooler_heads": 8
|
||||
},
|
||||
"vision_cfg": {
|
||||
"image_size": 288,
|
||||
"layers": 12,
|
||||
"width": 768,
|
||||
"patch_size": 18,
|
||||
"output_tokens": true
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 76,
|
||||
"vocab_size": 64000,
|
||||
"layers": 12,
|
||||
"heads": 12,
|
||||
"width": 768,
|
||||
"embed_cls": true,
|
||||
"output_tokens": true
|
||||
},
|
||||
"custom_text": true
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
{
|
||||
"embed_dim": 512,
|
||||
"vision_cfg": {
|
||||
"image_size": 224,
|
||||
"layers": 12,
|
||||
"width": 768,
|
||||
"patch_size": 32,
|
||||
"output_tokens": true
|
||||
},
|
||||
"text_cfg": {
|
||||
"hf_model_name": "roberta-base",
|
||||
"hf_tokenizer_name": "roberta-base",
|
||||
"proj": "linear",
|
||||
"width": 768,
|
||||
"output_tokens": true
|
||||
},
|
||||
"multimodal_cfg": {
|
||||
"context_length": 76,
|
||||
"width": 768,
|
||||
"heads": 8,
|
||||
"layers": 12
|
||||
},
|
||||
"custom_text": true
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
{
|
||||
"embed_dim": 512,
|
||||
"vision_cfg": {
|
||||
"timm_model_name": "convnext_base",
|
||||
"timm_model_pretrained": false,
|
||||
"timm_pool": "",
|
||||
"timm_proj": "linear",
|
||||
"timm_drop": 0.0,
|
||||
"timm_drop_path": 0.1,
|
||||
"image_size": 224
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 512,
|
||||
"heads": 8,
|
||||
"layers": 12
|
||||
}
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
{
|
||||
"embed_dim": 640,
|
||||
"vision_cfg": {
|
||||
"timm_model_name": "convnext_base",
|
||||
"timm_model_pretrained": false,
|
||||
"timm_pool": "",
|
||||
"timm_proj": "linear",
|
||||
"timm_drop": 0.0,
|
||||
"timm_drop_path": 0.1,
|
||||
"image_size": 256
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 640,
|
||||
"heads": 10,
|
||||
"layers": 12
|
||||
}
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
{
|
||||
"embed_dim": 640,
|
||||
"vision_cfg": {
|
||||
"timm_model_name": "convnext_base",
|
||||
"timm_model_pretrained": false,
|
||||
"timm_pool": "",
|
||||
"timm_proj": "linear",
|
||||
"timm_drop": 0.0,
|
||||
"timm_drop_path": 0.1,
|
||||
"image_size": 320
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 640,
|
||||
"heads": 10,
|
||||
"layers": 12
|
||||
}
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
{
|
||||
"embed_dim": 768,
|
||||
"vision_cfg": {
|
||||
"timm_model_name": "convnext_large",
|
||||
"timm_model_pretrained": false,
|
||||
"timm_pool": "",
|
||||
"timm_proj": "linear",
|
||||
"timm_drop": 0.0,
|
||||
"timm_drop_path": 0.1,
|
||||
"image_size": 224
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 768,
|
||||
"heads": 12,
|
||||
"layers": 12
|
||||
}
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
{
|
||||
"embed_dim": 768,
|
||||
"vision_cfg": {
|
||||
"timm_model_name": "convnext_large",
|
||||
"timm_model_pretrained": false,
|
||||
"timm_pool": "",
|
||||
"timm_proj": "mlp",
|
||||
"timm_drop": 0.0,
|
||||
"timm_drop_path": 0.1,
|
||||
"image_size": 256
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 768,
|
||||
"heads": 12,
|
||||
"layers": 16
|
||||
}
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
{
|
||||
"embed_dim": 768,
|
||||
"vision_cfg": {
|
||||
"timm_model_name": "convnext_large",
|
||||
"timm_model_pretrained": false,
|
||||
"timm_pool": "",
|
||||
"timm_proj": "mlp",
|
||||
"timm_drop": 0.0,
|
||||
"timm_drop_path": 0.1,
|
||||
"image_size": 320
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 768,
|
||||
"heads": 12,
|
||||
"layers": 16
|
||||
}
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
{
|
||||
"embed_dim": 512,
|
||||
"vision_cfg": {
|
||||
"timm_model_name": "convnext_small",
|
||||
"timm_model_pretrained": false,
|
||||
"timm_pool": "",
|
||||
"timm_proj": "linear",
|
||||
"timm_drop": 0.0,
|
||||
"timm_drop_path": 0.1,
|
||||
"image_size": 224
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 512,
|
||||
"heads": 8,
|
||||
"layers": 12
|
||||
}
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
{
|
||||
"embed_dim": 1024,
|
||||
"vision_cfg": {
|
||||
"timm_model_name": "convnext_tiny",
|
||||
"timm_model_pretrained": false,
|
||||
"timm_pool": "",
|
||||
"timm_proj": "linear",
|
||||
"timm_drop": 0.0,
|
||||
"timm_drop_path": 0.1,
|
||||
"image_size": 224
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 512,
|
||||
"heads": 8,
|
||||
"layers": 12
|
||||
}
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
{
|
||||
"embed_dim": 1024,
|
||||
"vision_cfg": {
|
||||
"timm_model_name": "convnext_xlarge",
|
||||
"timm_model_pretrained": false,
|
||||
"timm_pool": "",
|
||||
"timm_proj": "linear",
|
||||
"timm_drop": 0.0,
|
||||
"timm_drop_path": 0.1,
|
||||
"image_size": 256
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 1024,
|
||||
"heads": 16,
|
||||
"layers": 20
|
||||
}
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
{
|
||||
"embed_dim": 1024,
|
||||
"vision_cfg": {
|
||||
"timm_model_name": "convnext_xxlarge",
|
||||
"timm_model_pretrained": false,
|
||||
"timm_pool": "",
|
||||
"timm_proj": "linear",
|
||||
"timm_drop": 0.0,
|
||||
"timm_drop_path": 0.1,
|
||||
"image_size": 256
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 1024,
|
||||
"heads": 16,
|
||||
"layers": 24
|
||||
}
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
{
|
||||
"embed_dim": 1024,
|
||||
"vision_cfg": {
|
||||
"timm_model_name": "convnext_xxlarge",
|
||||
"timm_model_pretrained": false,
|
||||
"timm_pool": "",
|
||||
"timm_proj": "linear",
|
||||
"timm_drop": 0.0,
|
||||
"timm_drop_path": 0.1,
|
||||
"image_size": 320
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 1024,
|
||||
"heads": 16,
|
||||
"layers": 24
|
||||
}
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
{
|
||||
"embed_dim": 512,
|
||||
"vision_cfg": {
|
||||
"image_size": 224,
|
||||
"layers": 12,
|
||||
"width": 768,
|
||||
"patch_size": 32
|
||||
},
|
||||
"text_cfg": {
|
||||
"hf_model_name": "google/mt5-base",
|
||||
"hf_tokenizer_name": "google/mt5-base",
|
||||
"proj": "mlp",
|
||||
"pooler_type": "mean_pooler"
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
{
|
||||
"embed_dim": 1024,
|
||||
"vision_cfg": {
|
||||
"image_size": 224,
|
||||
"layers": 32,
|
||||
"width": 1280,
|
||||
"head_width": 80,
|
||||
"patch_size": 14
|
||||
},
|
||||
"text_cfg": {
|
||||
"hf_model_name": "google/mt5-xl",
|
||||
"hf_tokenizer_name": "google/mt5-xl",
|
||||
"proj": "mlp",
|
||||
"pooler_type": "mean_pooler"
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
{
|
||||
"embed_dim": 512,
|
||||
"quick_gelu": true,
|
||||
"vision_cfg": {
|
||||
"image_size": 224,
|
||||
"layers": 12,
|
||||
"width": 768,
|
||||
"patch_size": 32
|
||||
},
|
||||
"text_cfg": {
|
||||
"hf_model_name": "roberta-base",
|
||||
"hf_tokenizer_name": "roberta-base",
|
||||
"proj": "mlp",
|
||||
"pooler_type": "mean_pooler"
|
||||
}
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
{
|
||||
"embed_dim": 640,
|
||||
"vision_cfg": {
|
||||
"timm_model_name": "swin_base_patch4_window7_224",
|
||||
"timm_model_pretrained": false,
|
||||
"timm_pool": "",
|
||||
"timm_proj": "linear",
|
||||
"image_size": 224
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 640,
|
||||
"heads": 10,
|
||||
"layers": 12
|
||||
}
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
{
|
||||
"embed_dim": 512,
|
||||
"vision_cfg": {
|
||||
"timm_model_name": "vit_medium_patch16_gap_256",
|
||||
"timm_model_pretrained": false,
|
||||
"timm_pool": "",
|
||||
"timm_proj": "linear",
|
||||
"image_size": 256
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 512,
|
||||
"heads": 8,
|
||||
"layers": 12
|
||||
}
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
{
|
||||
"embed_dim": 512,
|
||||
"vision_cfg": {
|
||||
"timm_model_name": "vit_relpos_medium_patch16_cls_224",
|
||||
"timm_model_pretrained": false,
|
||||
"timm_pool": "",
|
||||
"timm_proj": "linear",
|
||||
"image_size": 224
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"width": 512,
|
||||
"heads": 8,
|
||||
"layers": 12
|
||||
}
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
{
|
||||
"embed_dim": 512,
|
||||
"vision_cfg": {
|
||||
"image_size": 224,
|
||||
"layers": 12,
|
||||
"width": 768,
|
||||
"patch_size": 32
|
||||
},
|
||||
"text_cfg": {
|
||||
"hf_model_name": "xlm-roberta-base",
|
||||
"hf_tokenizer_name": "xlm-roberta-base",
|
||||
"proj": "mlp",
|
||||
"pooler_type": "mean_pooler"
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
{
|
||||
"embed_dim": 1024,
|
||||
"vision_cfg": {
|
||||
"image_size": 224,
|
||||
"layers": 32,
|
||||
"width": 1280,
|
||||
"head_width": 80,
|
||||
"patch_size": 14
|
||||
},
|
||||
"text_cfg": {
|
||||
"hf_model_name": "xlm-roberta-large",
|
||||
"hf_tokenizer_name": "xlm-roberta-large",
|
||||
"proj": "mlp",
|
||||
"pooler_type": "mean_pooler"
|
||||
}
|
||||
}
|
||||
@@ -19,7 +19,10 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
@lru_cache()
|
||||
def default_bpe():
|
||||
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
project_root = os.path.abspath(os.path.join(current_dir, '../../../../'))
|
||||
quality_metric_path = os.path.join(project_root, 'models', 'QualityMetric')
|
||||
return os.path.join(quality_metric_path, "bpe_simple_vocab_16e6.txt.gz")
|
||||
|
||||
|
||||
@lru_cache()
|
||||
|
||||
@@ -6,15 +6,15 @@ import os
|
||||
from .config import MODEL_PATHS
|
||||
|
||||
class PickScore:
|
||||
def __init__(self, device: Union[str, torch.device]):
|
||||
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS):
|
||||
"""Initialize the Selector with a processor and model.
|
||||
|
||||
Args:
|
||||
device (Union[str, torch.device]): The device to load the model on.
|
||||
"""
|
||||
self.device = device if isinstance(device, torch.device) else torch.device(device)
|
||||
processor_name_or_path = MODEL_PATHS.get("clip")
|
||||
model_pretrained_name_or_path = MODEL_PATHS.get("pickscore")
|
||||
processor_name_or_path = path.get("clip")
|
||||
model_pretrained_name_or_path = path.get("pickscore")
|
||||
self.processor = AutoProcessor.from_pretrained(processor_name_or_path)
|
||||
self.model = AutoModel.from_pretrained(model_pretrained_name_or_path).eval().to(self.device)
|
||||
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from .base_model import *
|
||||
from .clip_model import *
|
||||
from .cross_modeling import *
|
||||
@@ -4,13 +4,13 @@ from transformers import AutoTokenizer
|
||||
|
||||
from torch import nn, einsum
|
||||
|
||||
from trainer.models.base_model import BaseModelConfig
|
||||
from .base_model import BaseModelConfig
|
||||
|
||||
from transformers import CLIPConfig
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
import torch
|
||||
|
||||
from trainer.models.cross_modeling import Cross_model
|
||||
from .cross_modeling import Cross_model
|
||||
|
||||
import gc
|
||||
|
||||
@@ -91,7 +91,7 @@ class XCLIPModel(HFCLIPModel):
|
||||
|
||||
@dataclass
|
||||
class ClipModelConfig(BaseModelConfig):
|
||||
_target_: str = "trainer.models.clip_model.CLIPModel"
|
||||
_target_: str = "diffsynth.extensions.QualityMetric.trainer.models.clip_model.CLIPModel"
|
||||
pretrained_model_name_or_path: str ="checkpoints/clip-vit-base-patch32"
|
||||
|
||||
|
||||
|
||||
@@ -8,17 +8,48 @@ from diffsynth.extensions.QualityMetric.clip import CLIPScore
|
||||
from diffsynth.extensions.QualityMetric.hps import HPScore_v2
|
||||
from diffsynth.extensions.QualityMetric.mps import MPScore
|
||||
|
||||
# download model from modelscope
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
project_root = os.path.abspath(os.path.join(current_dir, '../../'))
|
||||
model_folder = os.path.join(project_root, 'models', 'QualityMetric')
|
||||
|
||||
# download HPS_v2 to your folder
|
||||
# model_id = "DiffSynth-Studio/QualityMetric_reward_pretrained"
|
||||
# downloaded_path = snapshot_download(
|
||||
# model_id,
|
||||
# cache_dir=os.path.join(model_folder, 'HPS_v2'),
|
||||
# allow_patterns=["HPS_v2/*"],
|
||||
# )
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
def get_model_path(model_folder, model_name):
|
||||
return os.path.join(model_folder, model_name)
|
||||
|
||||
# your model path
|
||||
model_path = {
|
||||
"aesthetic_predictor": get_model_path(model_folder, "aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors"),
|
||||
"open_clip": get_model_path(model_folder, "CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin"),
|
||||
"hpsv2": get_model_path(model_folder, "HPS_v2/HPS_v2_compressed.safetensors"),
|
||||
"hpsv2.1": get_model_path(model_folder, "HPS_v2/HPS_v2.1_compressed.safetensors"),
|
||||
"imagereward": get_model_path(model_folder, "ImageReward/ImageReward.safetensors"),
|
||||
"med_config": get_model_path(model_folder, "ImageReward/med_config.json"),
|
||||
"clip": get_model_path(model_folder, "CLIP-ViT-H-14-laion2B-s32B-b79K"),
|
||||
"clip-large": get_model_path(model_folder, "clip-vit-large-patch14"),
|
||||
"mps": get_model_path(model_folder, "MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors"),
|
||||
"pickscore": get_model_path(model_folder, "PickScore_v1")
|
||||
}
|
||||
|
||||
# load reward models
|
||||
mps_score = MPScore(device)
|
||||
image_reward = ImageRewardScore(device)
|
||||
aesthetic_score = AestheticScore(device)
|
||||
pick_score = PickScore(device)
|
||||
clip_score = CLIPScore(device)
|
||||
hps_score = HPScore_v2(device, model_version = 'v2')
|
||||
hps2_score = HPScore_v2(device, model_version = 'v21')
|
||||
mps_score = MPScore(device,path = model_path)
|
||||
image_reward = ImageRewardScore(device, path = model_path)
|
||||
aesthetic_score = AestheticScore(device, path = model_path)
|
||||
pick_score = PickScore(device, path = model_path)
|
||||
clip_score = CLIPScore(device, path = model_path)
|
||||
hps_score = HPScore_v2(device, path = model_path, model_version = 'v2')
|
||||
hps2_score = HPScore_v2(device, path = model_path, model_version = 'v21')
|
||||
|
||||
prompt = "a painting of an ocean with clouds and birds, day time, low depth field effect"
|
||||
img_prefix = "images"
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseModelConfig:
|
||||
pass
|
||||
@@ -1,140 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from transformers import CLIPModel as HFCLIPModel
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from torch import nn, einsum
|
||||
|
||||
from trainer.models.base_model import BaseModelConfig
|
||||
|
||||
from transformers import CLIPConfig
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
import torch
|
||||
|
||||
from trainer.models.cross_modeling import Cross_model
|
||||
|
||||
import gc
|
||||
|
||||
class XCLIPModel(HFCLIPModel):
|
||||
def __init__(self, config: CLIPConfig):
|
||||
super().__init__(config)
|
||||
|
||||
def get_text_features(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> torch.FloatTensor:
|
||||
|
||||
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
text_outputs = self.text_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
# pooled_output = text_outputs[1]
|
||||
# text_features = self.text_projection(pooled_output)
|
||||
last_hidden_state = text_outputs[0]
|
||||
text_features = self.text_projection(last_hidden_state)
|
||||
|
||||
pooled_output = text_outputs[1]
|
||||
text_features_EOS = self.text_projection(pooled_output)
|
||||
|
||||
|
||||
# del last_hidden_state, text_outputs
|
||||
# gc.collect()
|
||||
|
||||
return text_features, text_features_EOS
|
||||
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> torch.FloatTensor:
|
||||
|
||||
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
vision_outputs = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
# pooled_output = vision_outputs[1] # pooled_output
|
||||
# image_features = self.visual_projection(pooled_output)
|
||||
last_hidden_state = vision_outputs[0]
|
||||
image_features = self.visual_projection(last_hidden_state)
|
||||
|
||||
return image_features
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClipModelConfig(BaseModelConfig):
|
||||
_target_: str = "trainer.models.clip_model.CLIPModel"
|
||||
pretrained_model_name_or_path: str ="checkpoints/clip-vit-base-patch32"
|
||||
|
||||
|
||||
class CLIPModel(nn.Module):
|
||||
def __init__(self, ckpt):
|
||||
super().__init__()
|
||||
self.model = XCLIPModel.from_pretrained(ckpt)
|
||||
self.cross_model = Cross_model(dim=1024, layer_num=4, heads=16)
|
||||
|
||||
def get_text_features(self, *args, **kwargs):
|
||||
return self.model.get_text_features(*args, **kwargs)
|
||||
|
||||
def get_image_features(self, *args, **kwargs):
|
||||
return self.model.get_image_features(*args, **kwargs)
|
||||
|
||||
def forward(self, text_inputs=None, image_inputs=None, condition_inputs=None):
|
||||
outputs = ()
|
||||
|
||||
text_f, text_EOS = self.model.get_text_features(text_inputs) # B*77*1024
|
||||
outputs += text_EOS,
|
||||
|
||||
image_f = self.model.get_image_features(image_inputs.half()) # 2B*257*1024
|
||||
condition_f, _ = self.model.get_text_features(condition_inputs) # B*5*1024
|
||||
|
||||
sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f)
|
||||
sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]
|
||||
sim_text_condition = sim_text_condition / sim_text_condition.max()
|
||||
mask = torch.where(sim_text_condition > 0.01, 0, float('-inf')) # B*1*77
|
||||
|
||||
mask = mask.repeat(1,image_f.shape[1],1) # B*257*77
|
||||
bc = int(image_f.shape[0]/2)
|
||||
|
||||
sim0 = self.cross_model(image_f[:bc,:,:], text_f,mask.half())
|
||||
sim1 = self.cross_model(image_f[bc:,:,:], text_f,mask.half())
|
||||
outputs += sim0[:,0,:],
|
||||
outputs += sim1[:,0,:],
|
||||
|
||||
return outputs
|
||||
|
||||
@property
|
||||
def logit_scale(self):
|
||||
return self.model.logit_scale
|
||||
|
||||
def save(self, path):
|
||||
self.model.save_pretrained(path)
|
||||
|
||||
@@ -1,292 +0,0 @@
|
||||
import torch
|
||||
from torch import einsum, nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
# normalization
|
||||
# they use layernorm without bias, something that pytorch does not offer
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
self.register_buffer("bias", torch.zeros(dim))
|
||||
|
||||
def forward(self, x):
|
||||
return F.layer_norm(x, x.shape[-1:], self.weight, self.bias)
|
||||
|
||||
# residual
|
||||
|
||||
|
||||
class Residual(nn.Module):
|
||||
def __init__(self, fn):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
return self.fn(x, *args, **kwargs) + x
|
||||
|
||||
|
||||
# rotary positional embedding
|
||||
# https://arxiv.org/abs/2104.09864
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
|
||||
self.register_buffer("inv_freq", inv_freq)
|
||||
|
||||
def forward(self, max_seq_len, *, device):
|
||||
seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
|
||||
freqs = einsum("i , j -> i j", seq, self.inv_freq)
|
||||
return torch.cat((freqs, freqs), dim=-1)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
x = rearrange(x, "... (j d) -> ... j d", j=2)
|
||||
x1, x2 = x.unbind(dim=-2)
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(pos, t):
|
||||
return (t * pos.cos()) + (rotate_half(t) * pos.sin())
|
||||
|
||||
|
||||
# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward
|
||||
# https://arxiv.org/abs/2002.05202
|
||||
|
||||
|
||||
class SwiGLU(nn.Module):
|
||||
def forward(self, x):
|
||||
x, gate = x.chunk(2, dim=-1)
|
||||
return F.silu(gate) * x
|
||||
|
||||
|
||||
# parallel attention and feedforward with residual
|
||||
# discovered by Wang et al + EleutherAI from GPT-J fame
|
||||
|
||||
class ParallelTransformerBlock(nn.Module):
|
||||
def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
|
||||
super().__init__()
|
||||
self.norm = LayerNorm(dim)
|
||||
|
||||
attn_inner_dim = dim_head * heads
|
||||
ff_inner_dim = dim * ff_mult
|
||||
self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head**-0.5
|
||||
self.rotary_emb = RotaryEmbedding(dim_head)
|
||||
|
||||
self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
|
||||
self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)
|
||||
|
||||
self.ff_out = nn.Sequential(
|
||||
SwiGLU(),
|
||||
nn.Linear(ff_inner_dim, dim, bias=False)
|
||||
)
|
||||
|
||||
self.register_buffer("pos_emb", None, persistent=False)
|
||||
|
||||
|
||||
def get_rotary_embedding(self, n, device):
|
||||
if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
|
||||
return self.pos_emb[:n]
|
||||
|
||||
pos_emb = self.rotary_emb(n, device=device)
|
||||
self.register_buffer("pos_emb", pos_emb, persistent=False)
|
||||
return pos_emb
|
||||
|
||||
def forward(self, x, attn_mask=None):
|
||||
"""
|
||||
einstein notation
|
||||
b - batch
|
||||
h - heads
|
||||
n, i, j - sequence length (base sequence length, source, target)
|
||||
d - feature dimension
|
||||
"""
|
||||
|
||||
n, device, h = x.shape[1], x.device, self.heads
|
||||
|
||||
# pre layernorm
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
# attention queries, keys, values, and feedforward inner
|
||||
|
||||
q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)
|
||||
|
||||
# split heads
|
||||
# they use multi-query single-key-value attention, yet another Noam Shazeer paper
|
||||
# they found no performance loss past a certain scale, and more efficient decoding obviously
|
||||
# https://arxiv.org/abs/1911.02150
|
||||
|
||||
q = rearrange(q, "b n (h d) -> b h n d", h=h)
|
||||
|
||||
# rotary embeddings
|
||||
|
||||
positions = self.get_rotary_embedding(n, device)
|
||||
q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))
|
||||
|
||||
# scale
|
||||
|
||||
q = q * self.scale
|
||||
|
||||
# similarity
|
||||
|
||||
sim = einsum("b h i d, b j d -> b h i j", q, k)
|
||||
|
||||
|
||||
# extra attention mask - for masking out attention from text CLS token to padding
|
||||
|
||||
if exists(attn_mask):
|
||||
attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j')
|
||||
sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)
|
||||
|
||||
# attention
|
||||
|
||||
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
|
||||
attn = sim.softmax(dim=-1)
|
||||
|
||||
# aggregate values
|
||||
|
||||
out = einsum("b h i j, b j d -> b h i d", attn, v)
|
||||
|
||||
# merge heads
|
||||
|
||||
out = rearrange(out, "b h n d -> b n (h d)")
|
||||
return self.attn_out(out) + self.ff_out(ff)
|
||||
|
||||
# cross attention - using multi-query + one-headed key / values as in PaLM w/ optional parallel feedforward
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
context_dim=None,
|
||||
dim_head=64,
|
||||
heads=12,
|
||||
parallel_ff=False,
|
||||
ff_mult=4,
|
||||
norm_context=False
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
inner_dim = heads * dim_head
|
||||
context_dim = default(context_dim, dim)
|
||||
|
||||
self.norm = LayerNorm(dim)
|
||||
self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity()
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
||||
self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
||||
|
||||
# whether to have parallel feedforward
|
||||
|
||||
ff_inner_dim = ff_mult * dim
|
||||
|
||||
self.ff = nn.Sequential(
|
||||
nn.Linear(dim, ff_inner_dim * 2, bias=False),
|
||||
SwiGLU(),
|
||||
nn.Linear(ff_inner_dim, dim, bias=False)
|
||||
) if parallel_ff else None
|
||||
|
||||
def forward(self, x, context, mask):
|
||||
"""
|
||||
einstein notation
|
||||
b - batch
|
||||
h - heads
|
||||
n, i, j - sequence length (base sequence length, source, target)
|
||||
d - feature dimension
|
||||
"""
|
||||
|
||||
# pre-layernorm, for queries and context
|
||||
|
||||
x = self.norm(x)
|
||||
context = self.context_norm(context)
|
||||
|
||||
# get queries
|
||||
|
||||
q = self.to_q(x)
|
||||
q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
|
||||
|
||||
# scale
|
||||
|
||||
q = q * self.scale
|
||||
|
||||
# get key / values
|
||||
|
||||
k, v = self.to_kv(context).chunk(2, dim=-1)
|
||||
|
||||
# query / key similarity
|
||||
|
||||
sim = einsum('b h i d, b j d -> b h i j', q, k)
|
||||
|
||||
# attention
|
||||
mask = mask.unsqueeze(1).repeat(1,self.heads,1,1)
|
||||
sim = sim + mask # context mask
|
||||
sim = sim - sim.amax(dim=-1, keepdim=True)
|
||||
attn = sim.softmax(dim=-1)
|
||||
|
||||
# aggregate
|
||||
|
||||
out = einsum('b h i j, b j d -> b h i d', attn, v)
|
||||
|
||||
# merge and combine heads
|
||||
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
out = self.to_out(out)
|
||||
|
||||
# add parallel feedforward (for multimodal layers)
|
||||
|
||||
if exists(self.ff):
|
||||
out = out + self.ff(x)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Cross_model(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim=512,
|
||||
layer_num=4,
|
||||
dim_head=64,
|
||||
heads=8,
|
||||
ff_mult=4
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
|
||||
for ind in range(layer_num):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Residual(CrossAttention(dim=dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult)),
|
||||
Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult))
|
||||
]))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query_tokens,
|
||||
context_tokens,
|
||||
mask
|
||||
):
|
||||
|
||||
for cross_attn, self_attn_ff in self.layers:
|
||||
query_tokens = cross_attn(query_tokens, context_tokens,mask)
|
||||
query_tokens = self_attn_ff(query_tokens)
|
||||
|
||||
return query_tokens
|
||||
Reference in New Issue
Block a user