mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 18:28:10 +00:00
add new quality metric
This commit is contained in:
@@ -13,8 +13,16 @@ from transformers import BertTokenizer
|
|||||||
from .vit import VisionTransformer, interpolate_pos_embed
|
from .vit import VisionTransformer, interpolate_pos_embed
|
||||||
|
|
||||||
|
|
||||||
|
def default_bert():
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
project_root = os.path.abspath(os.path.join(current_dir, '../../../../'))
|
||||||
|
model_path = os.path.join(project_root, 'models', 'QualityMetric')
|
||||||
|
return os.path.join(model_path, "bert-base-uncased")
|
||||||
|
|
||||||
|
bert_model_path = default_bert()
|
||||||
|
|
||||||
def init_tokenizer():
|
def init_tokenizer():
|
||||||
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
tokenizer = BertTokenizer.from_pretrained(bert_model_path)
|
||||||
tokenizer.add_special_tokens({'bos_token':'[DEC]'})
|
tokenizer.add_special_tokens({'bos_token':'[DEC]'})
|
||||||
tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
|
tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
|
||||||
tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
|
tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
|
||||||
|
|||||||
@@ -50,31 +50,30 @@ class MLP(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class AestheticScore:
|
class AestheticScore:
|
||||||
def __init__(self, device: torch.device, model_path: str = MODEL_PATHS.get("aesthetic_predictor")):
|
def __init__(self, device: torch.device, path: str = MODEL_PATHS):
|
||||||
"""Initialize the Selector with a model and processor.
|
"""Initialize the Selector with a model and processor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
device (torch.device): The device to load the model on.
|
device (torch.device): The device to load the model on.
|
||||||
model_path (str): Path to the model weights file.
|
|
||||||
"""
|
"""
|
||||||
self.device = device
|
self.device = device
|
||||||
|
self.aes_model_path = path.get("aesthetic_predictor")
|
||||||
# Load the MLP model
|
# Load the MLP model
|
||||||
self.model = MLP(768)
|
self.model = MLP(768)
|
||||||
try:
|
try:
|
||||||
if model_path.endswith(".safetensors"):
|
if self.aes_model_path.endswith(".safetensors"):
|
||||||
state_dict = load_file(model_path)
|
state_dict = load_file(self.aes_model_path)
|
||||||
else:
|
else:
|
||||||
state_dict = torch.load(model_path)
|
state_dict = torch.load(self.aes_model_path)
|
||||||
self.model.load_state_dict(state_dict)
|
self.model.load_state_dict(state_dict)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Error loading model weights from {model_path}: {e}")
|
raise ValueError(f"Error loading model weights from {self.aes_model_path}: {e}")
|
||||||
|
|
||||||
self.model.to(device)
|
self.model.to(device)
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
|
||||||
# Load the CLIP model and processor
|
# Load the CLIP model and processor
|
||||||
clip_model_name = MODEL_PATHS.get('clip-large')
|
clip_model_name = path.get('clip-large')
|
||||||
self.model2 = AutoModel.from_pretrained(clip_model_name).eval().to(device)
|
self.model2 = AutoModel.from_pretrained(clip_model_name).eval().to(device)
|
||||||
self.processor = AutoProcessor.from_pretrained(clip_model_name)
|
self.processor = AutoProcessor.from_pretrained(clip_model_name)
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from .open_clip import create_model_and_transforms, get_tokenizer
|
|||||||
from .config import MODEL_PATHS
|
from .config import MODEL_PATHS
|
||||||
|
|
||||||
class CLIPScore:
|
class CLIPScore:
|
||||||
def __init__(self, device: torch.device):
|
def __init__(self, device: torch.device, path: str = MODEL_PATHS):
|
||||||
"""Initialize the CLIPScore with a model and tokenizer.
|
"""Initialize the CLIPScore with a model and tokenizer.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -17,7 +17,7 @@ class CLIPScore:
|
|||||||
self.model, _, self.preprocess_val = create_model_and_transforms(
|
self.model, _, self.preprocess_val = create_model_and_transforms(
|
||||||
"ViT-H-14",
|
"ViT-H-14",
|
||||||
# "laion2B-s32B-b79K",
|
# "laion2B-s32B-b79K",
|
||||||
pretrained=MODEL_PATHS.get("open_clip"),
|
pretrained=path.get("open_clip"),
|
||||||
precision="amp",
|
precision="amp",
|
||||||
device=device,
|
device=device,
|
||||||
jit=False,
|
jit=False,
|
||||||
|
|||||||
@@ -2,11 +2,11 @@ import os
|
|||||||
|
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
project_root = os.path.abspath(os.path.join(current_dir, '../../../'))
|
project_root = os.path.abspath(os.path.join(current_dir, '../../../'))
|
||||||
quality_metric_path = os.path.join(project_root, 'models', 'QualityMetric')
|
model_path = os.path.join(project_root, 'models', 'QualityMetric')
|
||||||
|
|
||||||
|
|
||||||
def get_model_path(model_name):
|
def get_model_path(model_name):
|
||||||
return os.path.join(quality_metric_path, model_name)
|
return os.path.join(model_path, model_name)
|
||||||
|
|
||||||
|
|
||||||
MODEL_PATHS = {
|
MODEL_PATHS = {
|
||||||
@@ -18,6 +18,6 @@ MODEL_PATHS = {
|
|||||||
"med_config": get_model_path("ImageReward/med_config.json"),
|
"med_config": get_model_path("ImageReward/med_config.json"),
|
||||||
"clip": get_model_path("CLIP-ViT-H-14-laion2B-s32B-b79K"),
|
"clip": get_model_path("CLIP-ViT-H-14-laion2B-s32B-b79K"),
|
||||||
"clip-large": get_model_path("clip-vit-large-patch14"),
|
"clip-large": get_model_path("clip-vit-large-patch14"),
|
||||||
"mps": get_model_path("MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.pth"),
|
"mps": get_model_path("MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors"),
|
||||||
"pickscore": get_model_path("PickScore_v1")
|
"pickscore": get_model_path("PickScore_v1")
|
||||||
}
|
}
|
||||||
@@ -7,7 +7,7 @@ import os
|
|||||||
from .config import MODEL_PATHS
|
from .config import MODEL_PATHS
|
||||||
|
|
||||||
class HPScore_v2:
|
class HPScore_v2:
|
||||||
def __init__(self, device: torch.device, model_version: str = "v2"):
|
def __init__(self, device: torch.device, path: str = MODEL_PATHS, model_version: str = "v2"):
|
||||||
"""Initialize the Selector with a model and tokenizer.
|
"""Initialize the Selector with a model and tokenizer.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -17,9 +17,9 @@ class HPScore_v2:
|
|||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
if model_version == "v2":
|
if model_version == "v2":
|
||||||
safetensors_path = MODEL_PATHS.get("hpsv2")
|
safetensors_path = path.get("hpsv2")
|
||||||
elif model_version == "v21":
|
elif model_version == "v21":
|
||||||
safetensors_path = MODEL_PATHS.get("hpsv2.1")
|
safetensors_path = path.get("hpsv2.1")
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported model version: {model_version}. Choose 'v2' or 'v21'.")
|
raise ValueError(f"Unsupported model version: {model_version}. Choose 'v2' or 'v21'.")
|
||||||
|
|
||||||
@@ -27,7 +27,7 @@ class HPScore_v2:
|
|||||||
model, _, self.preprocess_val = create_model_and_transforms(
|
model, _, self.preprocess_val = create_model_and_transforms(
|
||||||
"ViT-H-14",
|
"ViT-H-14",
|
||||||
# "laion2B-s32B-b79K",
|
# "laion2B-s32B-b79K",
|
||||||
pretrained=MODEL_PATHS.get("open_clip"),
|
pretrained=path.get("open_clip"),
|
||||||
precision="amp",
|
precision="amp",
|
||||||
device=device,
|
device=device,
|
||||||
jit=False,
|
jit=False,
|
||||||
|
|||||||
@@ -188,15 +188,15 @@ class ImageReward(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class ImageRewardScore:
|
class ImageRewardScore:
|
||||||
def __init__(self, device: Union[str, torch.device]):
|
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS):
|
||||||
"""Initialize the Selector with a processor and model.
|
"""Initialize the Selector with a processor and model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
device (Union[str, torch.device]): The device to load the model on.
|
device (Union[str, torch.device]): The device to load the model on.
|
||||||
"""
|
"""
|
||||||
self.device = device if isinstance(device, torch.device) else torch.device(device)
|
self.device = device if isinstance(device, torch.device) else torch.device(device)
|
||||||
model_path = MODEL_PATHS.get("imagereward")
|
model_path = path.get("imagereward")
|
||||||
med_config = MODEL_PATHS.get("med_config")
|
med_config = path.get("med_config")
|
||||||
state_dict = load_file(model_path)
|
state_dict = load_file(model_path)
|
||||||
self.model = ImageReward(device=self.device, med_config=med_config).to(self.device)
|
self.model = ImageReward(device=self.device, med_config=med_config).to(self.device)
|
||||||
self.model.load_state_dict(state_dict, strict=False)
|
self.model.load_state_dict(state_dict, strict=False)
|
||||||
|
|||||||
@@ -4,10 +4,10 @@ from PIL import Image
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
from transformers import CLIPFeatureExtractor, CLIPImageProcessor
|
from transformers import CLIPFeatureExtractor, CLIPImageProcessor
|
||||||
|
from transformers import CLIPConfig
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from transformers import CLIPModel as HFCLIPModel
|
from transformers import CLIPModel as HFCLIPModel
|
||||||
|
from safetensors.torch import load_file
|
||||||
from torch import nn, einsum
|
from torch import nn, einsum
|
||||||
|
|
||||||
from .trainer.models.base_model import BaseModelConfig
|
from .trainer.models.base_model import BaseModelConfig
|
||||||
@@ -18,26 +18,27 @@ from typing import Any, Optional, Tuple, Union, List
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .trainer.models.cross_modeling import Cross_model
|
from .trainer.models.cross_modeling import Cross_model
|
||||||
|
from .trainer.models import clip_model
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
import json
|
import json
|
||||||
from .config import MODEL_PATHS
|
from .config import MODEL_PATHS
|
||||||
|
|
||||||
class MPScore:
|
class MPScore:
|
||||||
def __init__(self, device: Union[str, torch.device], condition: str = 'overall'):
|
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS, condition: str = 'overall'):
|
||||||
"""Initialize the MPSModel with a processor, tokenizer, and model.
|
"""Initialize the MPSModel with a processor, tokenizer, and model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
device (Union[str, torch.device]): The device to load the model on.
|
device (Union[str, torch.device]): The device to load the model on.
|
||||||
"""
|
"""
|
||||||
self.device = device
|
self.device = device
|
||||||
processor_name_or_path = MODEL_PATHS.get("clip")
|
processor_name_or_path = path.get("clip")
|
||||||
self.image_processor = CLIPImageProcessor.from_pretrained(processor_name_or_path)
|
self.image_processor = CLIPImageProcessor.from_pretrained(processor_name_or_path)
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True)
|
self.tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True)
|
||||||
|
self.model = clip_model.CLIPModel(processor_name_or_path)
|
||||||
model_ckpt_path = MODEL_PATHS.get("mps")
|
state_dict = load_file(path.get("mps"))
|
||||||
self.model = torch.load(model_ckpt_path).eval().to(device)
|
self.model.load_state_dict(state_dict, strict=False)
|
||||||
|
self.model.to(device)
|
||||||
self.condition = condition
|
self.condition = condition
|
||||||
|
|
||||||
def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
|
def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
|
||||||
|
|||||||
Binary file not shown.
@@ -1,22 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 512,
|
|
||||||
"quick_gelu": true,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 224,
|
|
||||||
"layers": [
|
|
||||||
3,
|
|
||||||
4,
|
|
||||||
23,
|
|
||||||
3
|
|
||||||
],
|
|
||||||
"width": 64,
|
|
||||||
"patch_size": null
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 512,
|
|
||||||
"heads": 8,
|
|
||||||
"layers": 12
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 512,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 224,
|
|
||||||
"layers": [
|
|
||||||
3,
|
|
||||||
4,
|
|
||||||
23,
|
|
||||||
3
|
|
||||||
],
|
|
||||||
"width": 64,
|
|
||||||
"patch_size": null
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 512,
|
|
||||||
"heads": 8,
|
|
||||||
"layers": 12
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 1024,
|
|
||||||
"quick_gelu": true,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 224,
|
|
||||||
"layers": [
|
|
||||||
3,
|
|
||||||
4,
|
|
||||||
6,
|
|
||||||
3
|
|
||||||
],
|
|
||||||
"width": 64,
|
|
||||||
"patch_size": null
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 512,
|
|
||||||
"heads": 8,
|
|
||||||
"layers": 12
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 1024,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 224,
|
|
||||||
"layers": [
|
|
||||||
3,
|
|
||||||
4,
|
|
||||||
6,
|
|
||||||
3
|
|
||||||
],
|
|
||||||
"width": 64,
|
|
||||||
"patch_size": null
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 512,
|
|
||||||
"heads": 8,
|
|
||||||
"layers": 12
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 768,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 384,
|
|
||||||
"layers": [
|
|
||||||
6,
|
|
||||||
8,
|
|
||||||
18,
|
|
||||||
8
|
|
||||||
],
|
|
||||||
"width": 96,
|
|
||||||
"patch_size": null
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 768,
|
|
||||||
"heads": 12,
|
|
||||||
"layers": 12
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 640,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 288,
|
|
||||||
"layers": [
|
|
||||||
4,
|
|
||||||
6,
|
|
||||||
10,
|
|
||||||
6
|
|
||||||
],
|
|
||||||
"width": 80,
|
|
||||||
"patch_size": null
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 640,
|
|
||||||
"heads": 10,
|
|
||||||
"layers": 12
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 1024,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 448,
|
|
||||||
"layers": [
|
|
||||||
3,
|
|
||||||
15,
|
|
||||||
36,
|
|
||||||
10
|
|
||||||
],
|
|
||||||
"width": 128,
|
|
||||||
"patch_size": null
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 1024,
|
|
||||||
"heads": 16,
|
|
||||||
"layers": 12
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 640,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 240,
|
|
||||||
"layers": 12,
|
|
||||||
"width": 896,
|
|
||||||
"patch_size": 16
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 640,
|
|
||||||
"heads": 10,
|
|
||||||
"layers": 12
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 640,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 224,
|
|
||||||
"layers": 12,
|
|
||||||
"width": 896,
|
|
||||||
"patch_size": 16
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 640,
|
|
||||||
"heads": 10,
|
|
||||||
"layers": 12
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 512,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 224,
|
|
||||||
"layers": 12,
|
|
||||||
"width": 768,
|
|
||||||
"patch_size": 16
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 512,
|
|
||||||
"heads": 8,
|
|
||||||
"layers": 12
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 640,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 256,
|
|
||||||
"layers": 12,
|
|
||||||
"width": 896,
|
|
||||||
"patch_size": 32
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 640,
|
|
||||||
"heads": 10,
|
|
||||||
"layers": 12
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 512,
|
|
||||||
"quick_gelu": true,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 224,
|
|
||||||
"layers": 12,
|
|
||||||
"width": 768,
|
|
||||||
"patch_size": 32
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 512,
|
|
||||||
"heads": 8,
|
|
||||||
"layers": 12
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 512,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 224,
|
|
||||||
"layers": 12,
|
|
||||||
"width": 768,
|
|
||||||
"patch_size": 32
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 512,
|
|
||||||
"heads": 8,
|
|
||||||
"layers": 12
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 1024,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 224,
|
|
||||||
"layers": 32,
|
|
||||||
"width": 1280,
|
|
||||||
"head_width": 80,
|
|
||||||
"patch_size": 16
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 1024,
|
|
||||||
"heads": 16,
|
|
||||||
"layers": 24
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 768,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 280,
|
|
||||||
"layers": 24,
|
|
||||||
"width": 1024,
|
|
||||||
"patch_size": 14
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 768,
|
|
||||||
"heads": 12,
|
|
||||||
"layers": 12
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 768,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 336,
|
|
||||||
"layers": 24,
|
|
||||||
"width": 1024,
|
|
||||||
"patch_size": 14
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 768,
|
|
||||||
"heads": 12,
|
|
||||||
"layers": 12
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 768,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 224,
|
|
||||||
"layers": 24,
|
|
||||||
"width": 1024,
|
|
||||||
"patch_size": 14
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 768,
|
|
||||||
"heads": 12,
|
|
||||||
"layers": 12
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 768,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 320,
|
|
||||||
"layers": 24,
|
|
||||||
"width": 1024,
|
|
||||||
"patch_size": 16
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 768,
|
|
||||||
"heads": 12,
|
|
||||||
"layers": 12
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 768,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 224,
|
|
||||||
"layers": 24,
|
|
||||||
"width": 1024,
|
|
||||||
"patch_size": 16
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 768,
|
|
||||||
"heads": 12,
|
|
||||||
"layers": 12
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 384,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 224,
|
|
||||||
"layers": 12,
|
|
||||||
"width": 512,
|
|
||||||
"patch_size": 16,
|
|
||||||
"ls_init_value": 1e-4
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 384,
|
|
||||||
"heads": 6,
|
|
||||||
"layers": 12
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 512,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 224,
|
|
||||||
"layers": 12,
|
|
||||||
"width": 512,
|
|
||||||
"patch_size": 16
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 512,
|
|
||||||
"heads": 8,
|
|
||||||
"layers": 12
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 384,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 224,
|
|
||||||
"layers": 12,
|
|
||||||
"width": 512,
|
|
||||||
"patch_size": 32
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 384,
|
|
||||||
"heads": 6,
|
|
||||||
"layers": 12
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 512,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 224,
|
|
||||||
"layers": 12,
|
|
||||||
"width": 512,
|
|
||||||
"patch_size": 32
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 512,
|
|
||||||
"heads": 8,
|
|
||||||
"layers": 12
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 256,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 224,
|
|
||||||
"layers": 12,
|
|
||||||
"width": 384,
|
|
||||||
"patch_size": 16
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 256,
|
|
||||||
"heads": 4,
|
|
||||||
"layers": 10
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 384,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 224,
|
|
||||||
"layers": 12,
|
|
||||||
"width": 384,
|
|
||||||
"patch_size": 16
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 384,
|
|
||||||
"heads": 6,
|
|
||||||
"layers": 12
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 256,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 224,
|
|
||||||
"layers": 12,
|
|
||||||
"width": 384,
|
|
||||||
"patch_size": 32
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 256,
|
|
||||||
"heads": 4,
|
|
||||||
"layers": 10
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 384,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 224,
|
|
||||||
"layers": 12,
|
|
||||||
"width": 384,
|
|
||||||
"patch_size": 32
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 384,
|
|
||||||
"heads": 6,
|
|
||||||
"layers": 12
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 1280,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 224,
|
|
||||||
"layers": 48,
|
|
||||||
"width": 1664,
|
|
||||||
"head_width": 104,
|
|
||||||
"mlp_ratio": 4.9231,
|
|
||||||
"patch_size": 14
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 1280,
|
|
||||||
"heads": 20,
|
|
||||||
"layers": 32
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 1280,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 224,
|
|
||||||
"layers": 56,
|
|
||||||
"width": 1792,
|
|
||||||
"head_width": 112,
|
|
||||||
"mlp_ratio": 8.5715,
|
|
||||||
"patch_size": 14
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 1280,
|
|
||||||
"heads": 20,
|
|
||||||
"layers": 36
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 1024,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 224,
|
|
||||||
"layers": 40,
|
|
||||||
"width": 1408,
|
|
||||||
"head_width": 88,
|
|
||||||
"mlp_ratio": 4.3637,
|
|
||||||
"patch_size": 14
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 1024,
|
|
||||||
"heads": 16,
|
|
||||||
"layers": 24
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,30 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 512,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 224,
|
|
||||||
"layers": 12,
|
|
||||||
"width": 768,
|
|
||||||
"patch_size": 32,
|
|
||||||
"attentional_pool": true,
|
|
||||||
"attn_pooler_heads": 8,
|
|
||||||
"output_tokens": true
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 76,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 512,
|
|
||||||
"heads": 8,
|
|
||||||
"layers": 12,
|
|
||||||
"embed_cls": true,
|
|
||||||
"output_tokens": true
|
|
||||||
},
|
|
||||||
"multimodal_cfg": {
|
|
||||||
"context_length": 76,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 512,
|
|
||||||
"heads": 8,
|
|
||||||
"layers": 12,
|
|
||||||
"attn_pooler_heads": 8
|
|
||||||
},
|
|
||||||
"custom_text": true
|
|
||||||
}
|
|
||||||
@@ -1,30 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 768,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 224,
|
|
||||||
"layers": 24,
|
|
||||||
"width": 1024,
|
|
||||||
"patch_size": 14,
|
|
||||||
"attentional_pool": true,
|
|
||||||
"attn_pooler_heads": 8,
|
|
||||||
"output_tokens": true
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 76,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 768,
|
|
||||||
"heads": 12,
|
|
||||||
"layers": 12,
|
|
||||||
"embed_cls": true,
|
|
||||||
"output_tokens": true
|
|
||||||
},
|
|
||||||
"multimodal_cfg": {
|
|
||||||
"context_length": 76,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 768,
|
|
||||||
"heads": 12,
|
|
||||||
"layers": 12,
|
|
||||||
"attn_pooler_heads": 12
|
|
||||||
},
|
|
||||||
"custom_text": true
|
|
||||||
}
|
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 512,
|
|
||||||
"multimodal_cfg": {
|
|
||||||
"width": 768,
|
|
||||||
"context_length": 76,
|
|
||||||
"vocab_size": 64000,
|
|
||||||
"mlp_ratio": 4,
|
|
||||||
"layers": 12,
|
|
||||||
"dim_head": 64,
|
|
||||||
"heads": 12,
|
|
||||||
"n_queries": 256,
|
|
||||||
"attn_pooler_heads": 8
|
|
||||||
},
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 288,
|
|
||||||
"layers": 12,
|
|
||||||
"width": 768,
|
|
||||||
"patch_size": 18,
|
|
||||||
"output_tokens": true
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 76,
|
|
||||||
"vocab_size": 64000,
|
|
||||||
"layers": 12,
|
|
||||||
"heads": 12,
|
|
||||||
"width": 768,
|
|
||||||
"embed_cls": true,
|
|
||||||
"output_tokens": true
|
|
||||||
},
|
|
||||||
"custom_text": true
|
|
||||||
}
|
|
||||||
@@ -1,24 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 512,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 224,
|
|
||||||
"layers": 12,
|
|
||||||
"width": 768,
|
|
||||||
"patch_size": 32,
|
|
||||||
"output_tokens": true
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"hf_model_name": "roberta-base",
|
|
||||||
"hf_tokenizer_name": "roberta-base",
|
|
||||||
"proj": "linear",
|
|
||||||
"width": 768,
|
|
||||||
"output_tokens": true
|
|
||||||
},
|
|
||||||
"multimodal_cfg": {
|
|
||||||
"context_length": 76,
|
|
||||||
"width": 768,
|
|
||||||
"heads": 8,
|
|
||||||
"layers": 12
|
|
||||||
},
|
|
||||||
"custom_text": true
|
|
||||||
}
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 512,
|
|
||||||
"vision_cfg": {
|
|
||||||
"timm_model_name": "convnext_base",
|
|
||||||
"timm_model_pretrained": false,
|
|
||||||
"timm_pool": "",
|
|
||||||
"timm_proj": "linear",
|
|
||||||
"timm_drop": 0.0,
|
|
||||||
"timm_drop_path": 0.1,
|
|
||||||
"image_size": 224
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 512,
|
|
||||||
"heads": 8,
|
|
||||||
"layers": 12
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 640,
|
|
||||||
"vision_cfg": {
|
|
||||||
"timm_model_name": "convnext_base",
|
|
||||||
"timm_model_pretrained": false,
|
|
||||||
"timm_pool": "",
|
|
||||||
"timm_proj": "linear",
|
|
||||||
"timm_drop": 0.0,
|
|
||||||
"timm_drop_path": 0.1,
|
|
||||||
"image_size": 256
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 640,
|
|
||||||
"heads": 10,
|
|
||||||
"layers": 12
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 640,
|
|
||||||
"vision_cfg": {
|
|
||||||
"timm_model_name": "convnext_base",
|
|
||||||
"timm_model_pretrained": false,
|
|
||||||
"timm_pool": "",
|
|
||||||
"timm_proj": "linear",
|
|
||||||
"timm_drop": 0.0,
|
|
||||||
"timm_drop_path": 0.1,
|
|
||||||
"image_size": 320
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 640,
|
|
||||||
"heads": 10,
|
|
||||||
"layers": 12
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 768,
|
|
||||||
"vision_cfg": {
|
|
||||||
"timm_model_name": "convnext_large",
|
|
||||||
"timm_model_pretrained": false,
|
|
||||||
"timm_pool": "",
|
|
||||||
"timm_proj": "linear",
|
|
||||||
"timm_drop": 0.0,
|
|
||||||
"timm_drop_path": 0.1,
|
|
||||||
"image_size": 224
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 768,
|
|
||||||
"heads": 12,
|
|
||||||
"layers": 12
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 768,
|
|
||||||
"vision_cfg": {
|
|
||||||
"timm_model_name": "convnext_large",
|
|
||||||
"timm_model_pretrained": false,
|
|
||||||
"timm_pool": "",
|
|
||||||
"timm_proj": "mlp",
|
|
||||||
"timm_drop": 0.0,
|
|
||||||
"timm_drop_path": 0.1,
|
|
||||||
"image_size": 256
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 768,
|
|
||||||
"heads": 12,
|
|
||||||
"layers": 16
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 768,
|
|
||||||
"vision_cfg": {
|
|
||||||
"timm_model_name": "convnext_large",
|
|
||||||
"timm_model_pretrained": false,
|
|
||||||
"timm_pool": "",
|
|
||||||
"timm_proj": "mlp",
|
|
||||||
"timm_drop": 0.0,
|
|
||||||
"timm_drop_path": 0.1,
|
|
||||||
"image_size": 320
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 768,
|
|
||||||
"heads": 12,
|
|
||||||
"layers": 16
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 512,
|
|
||||||
"vision_cfg": {
|
|
||||||
"timm_model_name": "convnext_small",
|
|
||||||
"timm_model_pretrained": false,
|
|
||||||
"timm_pool": "",
|
|
||||||
"timm_proj": "linear",
|
|
||||||
"timm_drop": 0.0,
|
|
||||||
"timm_drop_path": 0.1,
|
|
||||||
"image_size": 224
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 512,
|
|
||||||
"heads": 8,
|
|
||||||
"layers": 12
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 1024,
|
|
||||||
"vision_cfg": {
|
|
||||||
"timm_model_name": "convnext_tiny",
|
|
||||||
"timm_model_pretrained": false,
|
|
||||||
"timm_pool": "",
|
|
||||||
"timm_proj": "linear",
|
|
||||||
"timm_drop": 0.0,
|
|
||||||
"timm_drop_path": 0.1,
|
|
||||||
"image_size": 224
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 512,
|
|
||||||
"heads": 8,
|
|
||||||
"layers": 12
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 1024,
|
|
||||||
"vision_cfg": {
|
|
||||||
"timm_model_name": "convnext_xlarge",
|
|
||||||
"timm_model_pretrained": false,
|
|
||||||
"timm_pool": "",
|
|
||||||
"timm_proj": "linear",
|
|
||||||
"timm_drop": 0.0,
|
|
||||||
"timm_drop_path": 0.1,
|
|
||||||
"image_size": 256
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 1024,
|
|
||||||
"heads": 16,
|
|
||||||
"layers": 20
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 1024,
|
|
||||||
"vision_cfg": {
|
|
||||||
"timm_model_name": "convnext_xxlarge",
|
|
||||||
"timm_model_pretrained": false,
|
|
||||||
"timm_pool": "",
|
|
||||||
"timm_proj": "linear",
|
|
||||||
"timm_drop": 0.0,
|
|
||||||
"timm_drop_path": 0.1,
|
|
||||||
"image_size": 256
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 1024,
|
|
||||||
"heads": 16,
|
|
||||||
"layers": 24
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 1024,
|
|
||||||
"vision_cfg": {
|
|
||||||
"timm_model_name": "convnext_xxlarge",
|
|
||||||
"timm_model_pretrained": false,
|
|
||||||
"timm_pool": "",
|
|
||||||
"timm_proj": "linear",
|
|
||||||
"timm_drop": 0.0,
|
|
||||||
"timm_drop_path": 0.1,
|
|
||||||
"image_size": 320
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 1024,
|
|
||||||
"heads": 16,
|
|
||||||
"layers": 24
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 512,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 224,
|
|
||||||
"layers": 12,
|
|
||||||
"width": 768,
|
|
||||||
"patch_size": 32
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"hf_model_name": "google/mt5-base",
|
|
||||||
"hf_tokenizer_name": "google/mt5-base",
|
|
||||||
"proj": "mlp",
|
|
||||||
"pooler_type": "mean_pooler"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 1024,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 224,
|
|
||||||
"layers": 32,
|
|
||||||
"width": 1280,
|
|
||||||
"head_width": 80,
|
|
||||||
"patch_size": 14
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"hf_model_name": "google/mt5-xl",
|
|
||||||
"hf_tokenizer_name": "google/mt5-xl",
|
|
||||||
"proj": "mlp",
|
|
||||||
"pooler_type": "mean_pooler"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 512,
|
|
||||||
"quick_gelu": true,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 224,
|
|
||||||
"layers": 12,
|
|
||||||
"width": 768,
|
|
||||||
"patch_size": 32
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"hf_model_name": "roberta-base",
|
|
||||||
"hf_tokenizer_name": "roberta-base",
|
|
||||||
"proj": "mlp",
|
|
||||||
"pooler_type": "mean_pooler"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 640,
|
|
||||||
"vision_cfg": {
|
|
||||||
"timm_model_name": "swin_base_patch4_window7_224",
|
|
||||||
"timm_model_pretrained": false,
|
|
||||||
"timm_pool": "",
|
|
||||||
"timm_proj": "linear",
|
|
||||||
"image_size": 224
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 640,
|
|
||||||
"heads": 10,
|
|
||||||
"layers": 12
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 512,
|
|
||||||
"vision_cfg": {
|
|
||||||
"timm_model_name": "vit_medium_patch16_gap_256",
|
|
||||||
"timm_model_pretrained": false,
|
|
||||||
"timm_pool": "",
|
|
||||||
"timm_proj": "linear",
|
|
||||||
"image_size": 256
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 512,
|
|
||||||
"heads": 8,
|
|
||||||
"layers": 12
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 512,
|
|
||||||
"vision_cfg": {
|
|
||||||
"timm_model_name": "vit_relpos_medium_patch16_cls_224",
|
|
||||||
"timm_model_pretrained": false,
|
|
||||||
"timm_pool": "",
|
|
||||||
"timm_proj": "linear",
|
|
||||||
"image_size": 224
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"context_length": 77,
|
|
||||||
"vocab_size": 49408,
|
|
||||||
"width": 512,
|
|
||||||
"heads": 8,
|
|
||||||
"layers": 12
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 512,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 224,
|
|
||||||
"layers": 12,
|
|
||||||
"width": 768,
|
|
||||||
"patch_size": 32
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"hf_model_name": "xlm-roberta-base",
|
|
||||||
"hf_tokenizer_name": "xlm-roberta-base",
|
|
||||||
"proj": "mlp",
|
|
||||||
"pooler_type": "mean_pooler"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
{
|
|
||||||
"embed_dim": 1024,
|
|
||||||
"vision_cfg": {
|
|
||||||
"image_size": 224,
|
|
||||||
"layers": 32,
|
|
||||||
"width": 1280,
|
|
||||||
"head_width": 80,
|
|
||||||
"patch_size": 14
|
|
||||||
},
|
|
||||||
"text_cfg": {
|
|
||||||
"hf_model_name": "xlm-roberta-large",
|
|
||||||
"hf_tokenizer_name": "xlm-roberta-large",
|
|
||||||
"proj": "mlp",
|
|
||||||
"pooler_type": "mean_pooler"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -19,7 +19,10 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def default_bpe():
|
def default_bpe():
|
||||||
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
project_root = os.path.abspath(os.path.join(current_dir, '../../../../'))
|
||||||
|
quality_metric_path = os.path.join(project_root, 'models', 'QualityMetric')
|
||||||
|
return os.path.join(quality_metric_path, "bpe_simple_vocab_16e6.txt.gz")
|
||||||
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
|
|||||||
@@ -6,15 +6,15 @@ import os
|
|||||||
from .config import MODEL_PATHS
|
from .config import MODEL_PATHS
|
||||||
|
|
||||||
class PickScore:
|
class PickScore:
|
||||||
def __init__(self, device: Union[str, torch.device]):
|
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS):
|
||||||
"""Initialize the Selector with a processor and model.
|
"""Initialize the Selector with a processor and model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
device (Union[str, torch.device]): The device to load the model on.
|
device (Union[str, torch.device]): The device to load the model on.
|
||||||
"""
|
"""
|
||||||
self.device = device if isinstance(device, torch.device) else torch.device(device)
|
self.device = device if isinstance(device, torch.device) else torch.device(device)
|
||||||
processor_name_or_path = MODEL_PATHS.get("clip")
|
processor_name_or_path = path.get("clip")
|
||||||
model_pretrained_name_or_path = MODEL_PATHS.get("pickscore")
|
model_pretrained_name_or_path = path.get("pickscore")
|
||||||
self.processor = AutoProcessor.from_pretrained(processor_name_or_path)
|
self.processor = AutoProcessor.from_pretrained(processor_name_or_path)
|
||||||
self.model = AutoModel.from_pretrained(model_pretrained_name_or_path).eval().to(self.device)
|
self.model = AutoModel.from_pretrained(model_pretrained_name_or_path).eval().to(self.device)
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,3 @@
|
|||||||
|
from .base_model import *
|
||||||
|
from .clip_model import *
|
||||||
|
from .cross_modeling import *
|
||||||
@@ -4,13 +4,13 @@ from transformers import AutoTokenizer
|
|||||||
|
|
||||||
from torch import nn, einsum
|
from torch import nn, einsum
|
||||||
|
|
||||||
from trainer.models.base_model import BaseModelConfig
|
from .base_model import BaseModelConfig
|
||||||
|
|
||||||
from transformers import CLIPConfig
|
from transformers import CLIPConfig
|
||||||
from typing import Any, Optional, Tuple, Union
|
from typing import Any, Optional, Tuple, Union
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from trainer.models.cross_modeling import Cross_model
|
from .cross_modeling import Cross_model
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
|
|
||||||
@@ -91,7 +91,7 @@ class XCLIPModel(HFCLIPModel):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ClipModelConfig(BaseModelConfig):
|
class ClipModelConfig(BaseModelConfig):
|
||||||
_target_: str = "trainer.models.clip_model.CLIPModel"
|
_target_: str = "diffsynth.extensions.QualityMetric.trainer.models.clip_model.CLIPModel"
|
||||||
pretrained_model_name_or_path: str ="checkpoints/clip-vit-base-patch32"
|
pretrained_model_name_or_path: str ="checkpoints/clip-vit-base-patch32"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -8,17 +8,48 @@ from diffsynth.extensions.QualityMetric.clip import CLIPScore
|
|||||||
from diffsynth.extensions.QualityMetric.hps import HPScore_v2
|
from diffsynth.extensions.QualityMetric.hps import HPScore_v2
|
||||||
from diffsynth.extensions.QualityMetric.mps import MPScore
|
from diffsynth.extensions.QualityMetric.mps import MPScore
|
||||||
|
|
||||||
|
# download model from modelscope
|
||||||
|
from modelscope.hub.snapshot_download import snapshot_download
|
||||||
|
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
project_root = os.path.abspath(os.path.join(current_dir, '../../'))
|
||||||
|
model_folder = os.path.join(project_root, 'models', 'QualityMetric')
|
||||||
|
|
||||||
|
# download HPS_v2 to your folder
|
||||||
|
# model_id = "DiffSynth-Studio/QualityMetric_reward_pretrained"
|
||||||
|
# downloaded_path = snapshot_download(
|
||||||
|
# model_id,
|
||||||
|
# cache_dir=os.path.join(model_folder, 'HPS_v2'),
|
||||||
|
# allow_patterns=["HPS_v2/*"],
|
||||||
|
# )
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
def get_model_path(model_folder, model_name):
|
||||||
|
return os.path.join(model_folder, model_name)
|
||||||
|
|
||||||
|
# your model path
|
||||||
|
model_path = {
|
||||||
|
"aesthetic_predictor": get_model_path(model_folder, "aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors"),
|
||||||
|
"open_clip": get_model_path(model_folder, "CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin"),
|
||||||
|
"hpsv2": get_model_path(model_folder, "HPS_v2/HPS_v2_compressed.safetensors"),
|
||||||
|
"hpsv2.1": get_model_path(model_folder, "HPS_v2/HPS_v2.1_compressed.safetensors"),
|
||||||
|
"imagereward": get_model_path(model_folder, "ImageReward/ImageReward.safetensors"),
|
||||||
|
"med_config": get_model_path(model_folder, "ImageReward/med_config.json"),
|
||||||
|
"clip": get_model_path(model_folder, "CLIP-ViT-H-14-laion2B-s32B-b79K"),
|
||||||
|
"clip-large": get_model_path(model_folder, "clip-vit-large-patch14"),
|
||||||
|
"mps": get_model_path(model_folder, "MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors"),
|
||||||
|
"pickscore": get_model_path(model_folder, "PickScore_v1")
|
||||||
|
}
|
||||||
|
|
||||||
# load reward models
|
# load reward models
|
||||||
mps_score = MPScore(device)
|
mps_score = MPScore(device,path = model_path)
|
||||||
image_reward = ImageRewardScore(device)
|
image_reward = ImageRewardScore(device, path = model_path)
|
||||||
aesthetic_score = AestheticScore(device)
|
aesthetic_score = AestheticScore(device, path = model_path)
|
||||||
pick_score = PickScore(device)
|
pick_score = PickScore(device, path = model_path)
|
||||||
clip_score = CLIPScore(device)
|
clip_score = CLIPScore(device, path = model_path)
|
||||||
hps_score = HPScore_v2(device, model_version = 'v2')
|
hps_score = HPScore_v2(device, path = model_path, model_version = 'v2')
|
||||||
hps2_score = HPScore_v2(device, model_version = 'v21')
|
hps2_score = HPScore_v2(device, path = model_path, model_version = 'v21')
|
||||||
|
|
||||||
prompt = "a painting of an ocean with clouds and birds, day time, low depth field effect"
|
prompt = "a painting of an ocean with clouds and birds, day time, low depth field effect"
|
||||||
img_prefix = "images"
|
img_prefix = "images"
|
||||||
|
|||||||
@@ -1,7 +0,0 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BaseModelConfig:
|
|
||||||
pass
|
|
||||||
@@ -1,140 +0,0 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
from transformers import CLIPModel as HFCLIPModel
|
|
||||||
from transformers import AutoTokenizer
|
|
||||||
|
|
||||||
from torch import nn, einsum
|
|
||||||
|
|
||||||
from trainer.models.base_model import BaseModelConfig
|
|
||||||
|
|
||||||
from transformers import CLIPConfig
|
|
||||||
from typing import Any, Optional, Tuple, Union
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from trainer.models.cross_modeling import Cross_model
|
|
||||||
|
|
||||||
import gc
|
|
||||||
|
|
||||||
class XCLIPModel(HFCLIPModel):
|
|
||||||
def __init__(self, config: CLIPConfig):
|
|
||||||
super().__init__(config)
|
|
||||||
|
|
||||||
def get_text_features(
|
|
||||||
self,
|
|
||||||
input_ids: Optional[torch.Tensor] = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.Tensor] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
) -> torch.FloatTensor:
|
|
||||||
|
|
||||||
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
|
|
||||||
text_outputs = self.text_model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
)
|
|
||||||
|
|
||||||
# pooled_output = text_outputs[1]
|
|
||||||
# text_features = self.text_projection(pooled_output)
|
|
||||||
last_hidden_state = text_outputs[0]
|
|
||||||
text_features = self.text_projection(last_hidden_state)
|
|
||||||
|
|
||||||
pooled_output = text_outputs[1]
|
|
||||||
text_features_EOS = self.text_projection(pooled_output)
|
|
||||||
|
|
||||||
|
|
||||||
# del last_hidden_state, text_outputs
|
|
||||||
# gc.collect()
|
|
||||||
|
|
||||||
return text_features, text_features_EOS
|
|
||||||
|
|
||||||
def get_image_features(
|
|
||||||
self,
|
|
||||||
pixel_values: Optional[torch.FloatTensor] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
) -> torch.FloatTensor:
|
|
||||||
|
|
||||||
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
|
|
||||||
vision_outputs = self.vision_model(
|
|
||||||
pixel_values=pixel_values,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
)
|
|
||||||
|
|
||||||
# pooled_output = vision_outputs[1] # pooled_output
|
|
||||||
# image_features = self.visual_projection(pooled_output)
|
|
||||||
last_hidden_state = vision_outputs[0]
|
|
||||||
image_features = self.visual_projection(last_hidden_state)
|
|
||||||
|
|
||||||
return image_features
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ClipModelConfig(BaseModelConfig):
|
|
||||||
_target_: str = "trainer.models.clip_model.CLIPModel"
|
|
||||||
pretrained_model_name_or_path: str ="checkpoints/clip-vit-base-patch32"
|
|
||||||
|
|
||||||
|
|
||||||
class CLIPModel(nn.Module):
|
|
||||||
def __init__(self, ckpt):
|
|
||||||
super().__init__()
|
|
||||||
self.model = XCLIPModel.from_pretrained(ckpt)
|
|
||||||
self.cross_model = Cross_model(dim=1024, layer_num=4, heads=16)
|
|
||||||
|
|
||||||
def get_text_features(self, *args, **kwargs):
|
|
||||||
return self.model.get_text_features(*args, **kwargs)
|
|
||||||
|
|
||||||
def get_image_features(self, *args, **kwargs):
|
|
||||||
return self.model.get_image_features(*args, **kwargs)
|
|
||||||
|
|
||||||
def forward(self, text_inputs=None, image_inputs=None, condition_inputs=None):
|
|
||||||
outputs = ()
|
|
||||||
|
|
||||||
text_f, text_EOS = self.model.get_text_features(text_inputs) # B*77*1024
|
|
||||||
outputs += text_EOS,
|
|
||||||
|
|
||||||
image_f = self.model.get_image_features(image_inputs.half()) # 2B*257*1024
|
|
||||||
condition_f, _ = self.model.get_text_features(condition_inputs) # B*5*1024
|
|
||||||
|
|
||||||
sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f)
|
|
||||||
sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]
|
|
||||||
sim_text_condition = sim_text_condition / sim_text_condition.max()
|
|
||||||
mask = torch.where(sim_text_condition > 0.01, 0, float('-inf')) # B*1*77
|
|
||||||
|
|
||||||
mask = mask.repeat(1,image_f.shape[1],1) # B*257*77
|
|
||||||
bc = int(image_f.shape[0]/2)
|
|
||||||
|
|
||||||
sim0 = self.cross_model(image_f[:bc,:,:], text_f,mask.half())
|
|
||||||
sim1 = self.cross_model(image_f[bc:,:,:], text_f,mask.half())
|
|
||||||
outputs += sim0[:,0,:],
|
|
||||||
outputs += sim1[:,0,:],
|
|
||||||
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
@property
|
|
||||||
def logit_scale(self):
|
|
||||||
return self.model.logit_scale
|
|
||||||
|
|
||||||
def save(self, path):
|
|
||||||
self.model.save_pretrained(path)
|
|
||||||
|
|
||||||
@@ -1,292 +0,0 @@
|
|||||||
import torch
|
|
||||||
from torch import einsum, nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
|
|
||||||
# helper functions
|
|
||||||
|
|
||||||
def exists(val):
|
|
||||||
return val is not None
|
|
||||||
|
|
||||||
def default(val, d):
|
|
||||||
return val if exists(val) else d
|
|
||||||
|
|
||||||
# normalization
|
|
||||||
# they use layernorm without bias, something that pytorch does not offer
|
|
||||||
|
|
||||||
|
|
||||||
class LayerNorm(nn.Module):
|
|
||||||
def __init__(self, dim):
|
|
||||||
super().__init__()
|
|
||||||
self.weight = nn.Parameter(torch.ones(dim))
|
|
||||||
self.register_buffer("bias", torch.zeros(dim))
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return F.layer_norm(x, x.shape[-1:], self.weight, self.bias)
|
|
||||||
|
|
||||||
# residual
|
|
||||||
|
|
||||||
|
|
||||||
class Residual(nn.Module):
|
|
||||||
def __init__(self, fn):
|
|
||||||
super().__init__()
|
|
||||||
self.fn = fn
|
|
||||||
|
|
||||||
def forward(self, x, *args, **kwargs):
|
|
||||||
return self.fn(x, *args, **kwargs) + x
|
|
||||||
|
|
||||||
|
|
||||||
# rotary positional embedding
|
|
||||||
# https://arxiv.org/abs/2104.09864
|
|
||||||
|
|
||||||
|
|
||||||
class RotaryEmbedding(nn.Module):
|
|
||||||
def __init__(self, dim):
|
|
||||||
super().__init__()
|
|
||||||
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
|
|
||||||
self.register_buffer("inv_freq", inv_freq)
|
|
||||||
|
|
||||||
def forward(self, max_seq_len, *, device):
|
|
||||||
seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
|
|
||||||
freqs = einsum("i , j -> i j", seq, self.inv_freq)
|
|
||||||
return torch.cat((freqs, freqs), dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
def rotate_half(x):
|
|
||||||
x = rearrange(x, "... (j d) -> ... j d", j=2)
|
|
||||||
x1, x2 = x.unbind(dim=-2)
|
|
||||||
return torch.cat((-x2, x1), dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_pos_emb(pos, t):
|
|
||||||
return (t * pos.cos()) + (rotate_half(t) * pos.sin())
|
|
||||||
|
|
||||||
|
|
||||||
# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward
|
|
||||||
# https://arxiv.org/abs/2002.05202
|
|
||||||
|
|
||||||
|
|
||||||
class SwiGLU(nn.Module):
|
|
||||||
def forward(self, x):
|
|
||||||
x, gate = x.chunk(2, dim=-1)
|
|
||||||
return F.silu(gate) * x
|
|
||||||
|
|
||||||
|
|
||||||
# parallel attention and feedforward with residual
|
|
||||||
# discovered by Wang et al + EleutherAI from GPT-J fame
|
|
||||||
|
|
||||||
class ParallelTransformerBlock(nn.Module):
|
|
||||||
def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
|
|
||||||
super().__init__()
|
|
||||||
self.norm = LayerNorm(dim)
|
|
||||||
|
|
||||||
attn_inner_dim = dim_head * heads
|
|
||||||
ff_inner_dim = dim * ff_mult
|
|
||||||
self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))
|
|
||||||
|
|
||||||
self.heads = heads
|
|
||||||
self.scale = dim_head**-0.5
|
|
||||||
self.rotary_emb = RotaryEmbedding(dim_head)
|
|
||||||
|
|
||||||
self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
|
|
||||||
self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)
|
|
||||||
|
|
||||||
self.ff_out = nn.Sequential(
|
|
||||||
SwiGLU(),
|
|
||||||
nn.Linear(ff_inner_dim, dim, bias=False)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.register_buffer("pos_emb", None, persistent=False)
|
|
||||||
|
|
||||||
|
|
||||||
def get_rotary_embedding(self, n, device):
|
|
||||||
if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
|
|
||||||
return self.pos_emb[:n]
|
|
||||||
|
|
||||||
pos_emb = self.rotary_emb(n, device=device)
|
|
||||||
self.register_buffer("pos_emb", pos_emb, persistent=False)
|
|
||||||
return pos_emb
|
|
||||||
|
|
||||||
def forward(self, x, attn_mask=None):
|
|
||||||
"""
|
|
||||||
einstein notation
|
|
||||||
b - batch
|
|
||||||
h - heads
|
|
||||||
n, i, j - sequence length (base sequence length, source, target)
|
|
||||||
d - feature dimension
|
|
||||||
"""
|
|
||||||
|
|
||||||
n, device, h = x.shape[1], x.device, self.heads
|
|
||||||
|
|
||||||
# pre layernorm
|
|
||||||
|
|
||||||
x = self.norm(x)
|
|
||||||
|
|
||||||
# attention queries, keys, values, and feedforward inner
|
|
||||||
|
|
||||||
q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)
|
|
||||||
|
|
||||||
# split heads
|
|
||||||
# they use multi-query single-key-value attention, yet another Noam Shazeer paper
|
|
||||||
# they found no performance loss past a certain scale, and more efficient decoding obviously
|
|
||||||
# https://arxiv.org/abs/1911.02150
|
|
||||||
|
|
||||||
q = rearrange(q, "b n (h d) -> b h n d", h=h)
|
|
||||||
|
|
||||||
# rotary embeddings
|
|
||||||
|
|
||||||
positions = self.get_rotary_embedding(n, device)
|
|
||||||
q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))
|
|
||||||
|
|
||||||
# scale
|
|
||||||
|
|
||||||
q = q * self.scale
|
|
||||||
|
|
||||||
# similarity
|
|
||||||
|
|
||||||
sim = einsum("b h i d, b j d -> b h i j", q, k)
|
|
||||||
|
|
||||||
|
|
||||||
# extra attention mask - for masking out attention from text CLS token to padding
|
|
||||||
|
|
||||||
if exists(attn_mask):
|
|
||||||
attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j')
|
|
||||||
sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)
|
|
||||||
|
|
||||||
# attention
|
|
||||||
|
|
||||||
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
|
|
||||||
attn = sim.softmax(dim=-1)
|
|
||||||
|
|
||||||
# aggregate values
|
|
||||||
|
|
||||||
out = einsum("b h i j, b j d -> b h i d", attn, v)
|
|
||||||
|
|
||||||
# merge heads
|
|
||||||
|
|
||||||
out = rearrange(out, "b h n d -> b n (h d)")
|
|
||||||
return self.attn_out(out) + self.ff_out(ff)
|
|
||||||
|
|
||||||
# cross attention - using multi-query + one-headed key / values as in PaLM w/ optional parallel feedforward
|
|
||||||
|
|
||||||
class CrossAttention(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim,
|
|
||||||
*,
|
|
||||||
context_dim=None,
|
|
||||||
dim_head=64,
|
|
||||||
heads=12,
|
|
||||||
parallel_ff=False,
|
|
||||||
ff_mult=4,
|
|
||||||
norm_context=False
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.heads = heads
|
|
||||||
self.scale = dim_head ** -0.5
|
|
||||||
inner_dim = heads * dim_head
|
|
||||||
context_dim = default(context_dim, dim)
|
|
||||||
|
|
||||||
self.norm = LayerNorm(dim)
|
|
||||||
self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity()
|
|
||||||
|
|
||||||
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
|
||||||
self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)
|
|
||||||
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
|
||||||
|
|
||||||
# whether to have parallel feedforward
|
|
||||||
|
|
||||||
ff_inner_dim = ff_mult * dim
|
|
||||||
|
|
||||||
self.ff = nn.Sequential(
|
|
||||||
nn.Linear(dim, ff_inner_dim * 2, bias=False),
|
|
||||||
SwiGLU(),
|
|
||||||
nn.Linear(ff_inner_dim, dim, bias=False)
|
|
||||||
) if parallel_ff else None
|
|
||||||
|
|
||||||
def forward(self, x, context, mask):
|
|
||||||
"""
|
|
||||||
einstein notation
|
|
||||||
b - batch
|
|
||||||
h - heads
|
|
||||||
n, i, j - sequence length (base sequence length, source, target)
|
|
||||||
d - feature dimension
|
|
||||||
"""
|
|
||||||
|
|
||||||
# pre-layernorm, for queries and context
|
|
||||||
|
|
||||||
x = self.norm(x)
|
|
||||||
context = self.context_norm(context)
|
|
||||||
|
|
||||||
# get queries
|
|
||||||
|
|
||||||
q = self.to_q(x)
|
|
||||||
q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
|
|
||||||
|
|
||||||
# scale
|
|
||||||
|
|
||||||
q = q * self.scale
|
|
||||||
|
|
||||||
# get key / values
|
|
||||||
|
|
||||||
k, v = self.to_kv(context).chunk(2, dim=-1)
|
|
||||||
|
|
||||||
# query / key similarity
|
|
||||||
|
|
||||||
sim = einsum('b h i d, b j d -> b h i j', q, k)
|
|
||||||
|
|
||||||
# attention
|
|
||||||
mask = mask.unsqueeze(1).repeat(1,self.heads,1,1)
|
|
||||||
sim = sim + mask # context mask
|
|
||||||
sim = sim - sim.amax(dim=-1, keepdim=True)
|
|
||||||
attn = sim.softmax(dim=-1)
|
|
||||||
|
|
||||||
# aggregate
|
|
||||||
|
|
||||||
out = einsum('b h i j, b j d -> b h i d', attn, v)
|
|
||||||
|
|
||||||
# merge and combine heads
|
|
||||||
|
|
||||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
|
||||||
out = self.to_out(out)
|
|
||||||
|
|
||||||
# add parallel feedforward (for multimodal layers)
|
|
||||||
|
|
||||||
if exists(self.ff):
|
|
||||||
out = out + self.ff(x)
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class Cross_model(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim=512,
|
|
||||||
layer_num=4,
|
|
||||||
dim_head=64,
|
|
||||||
heads=8,
|
|
||||||
ff_mult=4
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.layers = nn.ModuleList([])
|
|
||||||
|
|
||||||
|
|
||||||
for ind in range(layer_num):
|
|
||||||
self.layers.append(nn.ModuleList([
|
|
||||||
Residual(CrossAttention(dim=dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult)),
|
|
||||||
Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult))
|
|
||||||
]))
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
query_tokens,
|
|
||||||
context_tokens,
|
|
||||||
mask
|
|
||||||
):
|
|
||||||
|
|
||||||
for cross_attn, self_attn_ff in self.layers:
|
|
||||||
query_tokens = cross_attn(query_tokens, context_tokens,mask)
|
|
||||||
query_tokens = self_attn_ff(query_tokens)
|
|
||||||
|
|
||||||
return query_tokens
|
|
||||||
Reference in New Issue
Block a user