image quality metric

This commit is contained in:
hongzhang.hz
2025-02-14 12:39:06 +08:00
parent f47de78b59
commit acda7d891a
100 changed files with 8336 additions and 0 deletions

View File

@@ -0,0 +1,54 @@
# Image Quality Metric
The image quality assessment functionality is now integrated into DiffSynth.
## Usage
### Step 1: Download pretrained reward models
```
modelscope download --model 'DiffSynth-Studio/QualityMetric_reward_pretrained'
```
The resulting directory structure after the download is shown below.
```
DiffSynth-Studio/
└── diffsynth/
└── extensions/
└── QualityMetric/
├── __init__.py
├── hps.py
├── reward_pretrained/
│ ├── HPS_v2/
│ │ ├── HPS_v2_compressed.safetensors
│ │ ├── HPS_v2.1_compressed.safetensors
│ └── ...
└── ...
```
### Step 2: Test image quality metric
Prompt: "a painting of an ocean with clouds and birds, day time, low depth field effect"
|1.webp|2.webp|3.webp|4.webp|
|-|-|-|-|
|![0](images/1.webp)|![1](images/2.webp)|![2](images/3.webp)|![3](images/4.webp)|
```
CUDA_VISIBLE_DEVICES=0 python testreward.py
```
### Output:
```
ImageReward: [0.5811904668807983, 0.2745198607444763, -1.4158903360366821, -2.032487154006958]
Aesthetic [5.900862693786621, 5.776571273803711, 5.799864292144775, 5.05204963684082]
PickScore: [0.20737126469612122, 0.20443597435951233, 0.20660750567913055, 0.19426065683364868]
CLIPScore: [0.3894640803337097, 0.3544551134109497, 0.33861416578292847, 0.32878392934799194]
HPScorev2: [0.2672519087791443, 0.25495243072509766, 0.24888549745082855, 0.24302822351455688]
HPScorev21: [0.2321144938468933, 0.20233657956123352, 0.1978294551372528, 0.19230154156684875]
MPS_score: [10.921875, 10.71875, 10.578125, 9.25]
```

Binary file not shown.

After

Width:  |  Height:  |  Size: 329 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 250 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 275 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 311 KiB

View File

@@ -0,0 +1,49 @@
"""Smoke-test the DiffSynth image-quality reward models.

Loads each reward model once, scores the same four sample images against a
fixed prompt, and prints one line per metric. The printed labels match the
sample output in the README.
"""
import os
import torch
from PIL import Image
from diffsynth.extensions.QualityMetric.imagereward import ImageRewardScore
from diffsynth.extensions.QualityMetric.pickscore import PickScore
from diffsynth.extensions.QualityMetric.aesthetic import AestheticScore
from diffsynth.extensions.QualityMetric.clip import CLIPScore
from diffsynth.extensions.QualityMetric.hps import HPScore_v2
from diffsynth.extensions.QualityMetric.mps import MPScore

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load every reward model up front so scoring below is pure inference.
mps_score = MPScore(device)
image_reward = ImageRewardScore(device)
aesthetic_score = AestheticScore(device)
pick_score = PickScore(device)
clip_score = CLIPScore(device)
hps_score = HPScore_v2(device, model_version='v2')
hps2_score = HPScore_v2(device, model_version='v21')

prompt = "a painting of an ocean with clouds and birds, day time, low depth field effect"
img_prefix = "images"
generations = [f"{pic_id}.webp" for pic_id in range(1, 5)]
# Scorers accept PIL images (file paths would need each scorer to open them).
img_list = [Image.open(os.path.join(img_prefix, img)) for img in generations]

imre_scores = image_reward.score(img_list, prompt)
print("ImageReward:", imre_scores)
aes_scores = aesthetic_score.score(img_list)
print("Aesthetic", aes_scores)
p_scores = pick_score.score(img_list, prompt)
print("PickScore:", p_scores)
c_scores = clip_score.score(img_list, prompt)
print("CLIPScore:", c_scores)
h_scores = hps_score.score(img_list, prompt)
print("HPScorev2:", h_scores)
h2_scores = hps2_score.score(img_list, prompt)
print("HPScorev21:", h2_scores)
m_scores = mps_score.score(img_list, prompt)
print("MPS_score:", m_scores)

View File

@@ -0,0 +1,7 @@
from dataclasses import dataclass
@dataclass
class BaseModelConfig:
    """Empty base for model configuration dataclasses; concrete configs add fields."""

View File

@@ -0,0 +1,140 @@
from dataclasses import dataclass
from transformers import CLIPModel as HFCLIPModel
from transformers import AutoTokenizer
from torch import nn, einsum
from trainer.models.base_model import BaseModelConfig
from transformers import CLIPConfig
from typing import Any, Optional, Tuple, Union
import torch
from trainer.models.cross_modeling import Cross_model
import gc
class XCLIPModel(HFCLIPModel):
    """HF CLIP variant whose feature getters return token-level projections.

    Unlike the stock ``CLIPModel``, ``get_text_features`` projects every
    token's hidden state (and also the pooled EOS state), and
    ``get_image_features`` projects every patch token, so downstream
    cross-attention can consume full sequences.
    """

    def __init__(self, config: CLIPConfig):
        super().__init__(config)

    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        """Return (per-token projected features, projected pooled/EOS feature)."""
        # Fall back to the model config when a flag is not given explicitly.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoded = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # encoded[0]: last hidden state (B, T, D); encoded[1]: pooled output (B, D)
        per_token = self.text_projection(encoded[0])
        eos_feature = self.text_projection(encoded[1])
        return per_token, eos_feature

    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        """Return per-patch projected image features (no pooling)."""
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoded = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Project the full patch sequence, not just the pooled output.
        return self.visual_projection(encoded[0])
@dataclass
class ClipModelConfig(BaseModelConfig):
    """Config for the CLIP reward model: instantiation target plus checkpoint path."""

    _target_: str = "trainer.models.clip_model.CLIPModel"
    pretrained_model_name_or_path: str = "checkpoints/clip-vit-base-patch32"
class CLIPModel(nn.Module):
    """Reward model: an XCLIPModel backbone plus a cross-attention scoring head.

    ``forward`` scores two packed image halves against a prompt, with a
    condition prompt used only to build an attention mask over text tokens.
    """
    def __init__(self, ckpt):
        # ckpt: checkpoint path/name for the pretrained XCLIPModel backbone.
        super().__init__()
        self.model = XCLIPModel.from_pretrained(ckpt)
        # dim=1024 must match the backbone's projection width
        # (assumes a correspondingly sized CLIP checkpoint — TODO confirm).
        self.cross_model = Cross_model(dim=1024, layer_num=4, heads=16)
    def get_text_features(self, *args, **kwargs):
        # Pass-through: returns (per-token features, pooled EOS feature).
        return self.model.get_text_features(*args, **kwargs)
    def get_image_features(self, *args, **kwargs):
        # Pass-through: returns per-patch image features.
        return self.model.get_image_features(*args, **kwargs)
    def forward(self, text_inputs=None, image_inputs=None, condition_inputs=None):
        """Return (text EOS features, score tokens for the first image half,
        score tokens for the second image half)."""
        outputs = ()
        text_f, text_EOS = self.model.get_text_features(text_inputs) # B*77*1024
        outputs += text_EOS,
        # Images are cast to fp16 before encoding; batch packs two images per text.
        image_f = self.model.get_image_features(image_inputs.half()) # 2B*257*1024
        condition_f, _ = self.model.get_text_features(condition_inputs) # B*5*1024
        # For each text token, its best similarity to any condition token.
        sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f)
        sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]
        # Normalize by the global max so the 0.01 threshold below is relative.
        sim_text_condition = sim_text_condition / sim_text_condition.max()
        # Additive mask: 0 keeps condition-related text tokens, -inf drops the rest.
        mask = torch.where(sim_text_condition > 0.01, 0, float('-inf')) # B*1*77
        mask = mask.repeat(1,image_f.shape[1],1) # B*257*77
        # Split the packed image batch back into its two halves and score each.
        bc = int(image_f.shape[0]/2)
        sim0 = self.cross_model(image_f[:bc,:,:], text_f,mask.half())
        sim1 = self.cross_model(image_f[bc:,:,:], text_f,mask.half())
        # Keep only the first query token of each similarity output.
        outputs += sim0[:,0,:],
        outputs += sim1[:,0,:],
        return outputs
    @property
    def logit_scale(self):
        # Expose the backbone's learned temperature parameter.
        return self.model.logit_scale
    def save(self, path):
        # Persists only the backbone; the cross_model head is NOT saved here.
        self.model.save_pretrained(path)

View File

@@ -0,0 +1,292 @@
import torch
from torch import einsum, nn
import torch.nn.functional as F
from einops import rearrange, repeat
# helper functions
def exists(val):
    """True when *val* is anything other than None."""
    return val is not None

def default(val, d):
    """Return *val* when it is set, otherwise the fallback *d*."""
    return d if val is None else val
# normalization
# they use layernorm without bias, something that pytorch does not offer
class LayerNorm(nn.Module):
    """Layer normalization with a learnable scale but a fixed all-zero bias.

    PyTorch's LayerNorm always trains its bias; here the bias is registered
    as a buffer so it stays zero.
    """

    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        # Buffer, not Parameter: the bias is frozen at zero.
        self.register_buffer("bias", torch.zeros(dim))

    def forward(self, x):
        normalized_shape = (x.size(-1),)
        return F.layer_norm(x, normalized_shape, self.weight, self.bias)
# residual
class Residual(nn.Module):
    """Wrap a callable and add its input to its output (skip connection)."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        inner = self.fn(x, *args, **kwargs)
        return inner + x
# rotary positional embedding
# https://arxiv.org/abs/2104.09864
class RotaryEmbedding(nn.Module):
    """Rotary position embedding angle table (RoPE, arXiv:2104.09864)."""

    def __init__(self, dim):
        super().__init__()
        # Standard RoPE inverse-frequency schedule over even channel indices.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, max_seq_len, *, device):
        positions = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
        angles = torch.outer(positions, self.inv_freq)
        # Duplicate so both feature halves share the same angle per pair.
        return torch.cat((angles, angles), dim=-1)
def rotate_half(x):
    """RoPE helper: split the last dim into halves (x1, x2) and return (-x2, x1)."""
    paired = x.reshape(*x.shape[:-1], 2, -1)
    x1, x2 = paired.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(pos, t):
    """Rotate tensor *t* by the RoPE angles *pos*: t*cos(pos) + rotate_half(t)*sin(pos)."""
    cos, sin = pos.cos(), pos.sin()
    return t * cos + rotate_half(t) * sin
# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward
# https://arxiv.org/abs/2002.05202
class SwiGLU(nn.Module):
    """SwiGLU activation (arXiv:2002.05202): silu(gate) * value on a halved input."""

    def forward(self, x):
        value, gate = x.chunk(2, dim=-1)
        return value * F.silu(gate)
# parallel attention and feedforward with residual
# discovered by Wang et al + EleutherAI from GPT-J fame
class ParallelTransformerBlock(nn.Module):
    """Transformer block that computes self-attention and feedforward in
    parallel from one fused projection (PaLM / GPT-J style), using
    multi-query attention and rotary position embeddings.
    """
    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
        super().__init__()
        self.norm = LayerNorm(dim)
        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult
        # q gets all heads; k and v each get one shared head (multi-query);
        # the FF slice is 2x wide because SwiGLU halves it again.
        self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))
        self.heads = heads
        self.scale = dim_head**-0.5
        self.rotary_emb = RotaryEmbedding(dim_head)
        # One matmul produces q, k, v and the feedforward input together.
        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
        self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)
        self.ff_out = nn.Sequential(
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        )
        # Cache for rotary angles; persistent=False keeps it out of state_dict.
        self.register_buffer("pos_emb", None, persistent=False)
    def get_rotary_embedding(self, n, device):
        # Reuse the cached angle table when it is long enough; else rebuild.
        if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
            return self.pos_emb[:n]
        pos_emb = self.rotary_emb(n, device=device)
        self.register_buffer("pos_emb", pos_emb, persistent=False)
        return pos_emb
    def forward(self, x, attn_mask=None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        n, device, h = x.shape[1], x.device, self.heads
        # pre layernorm
        x = self.norm(x)
        # attention queries, keys, values, and feedforward inner
        q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)
        # split heads
        # they use multi-query single-key-value attention, yet another Noam Shazeer paper
        # they found no performance loss past a certain scale, and more efficient decoding obviously
        # https://arxiv.org/abs/1911.02150
        q = rearrange(q, "b n (h d) -> b h n d", h=h)
        # rotary embeddings
        positions = self.get_rotary_embedding(n, device)
        q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))
        # scale
        q = q * self.scale
        # similarity (k has no head axis — shared across query heads)
        sim = einsum("b h i d, b j d -> b h i j", q, k)
        # extra attention mask - for masking out attention from text CLS token to padding
        # NOTE(review): expects a boolean mask; ~attn_mask marks positions to drop.
        if exists(attn_mask):
            attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j')
            sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)
        # attention
        # subtract the (detached) row max for numerical stability before softmax
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)
        # aggregate values (single shared value head)
        out = einsum("b h i j, b j d -> b h i d", attn, v)
        # merge heads
        out = rearrange(out, "b h n d -> b n (h d)")
        # attention and feedforward branches are summed (parallel block)
        return self.attn_out(out) + self.ff_out(ff)
# cross attention - using multi-query + one-headed key / values as in PaLM w/ optional parallel feedforward
class CrossAttention(nn.Module):
    """Multi-query cross attention (queries attend to *context*) with an
    optional parallel SwiGLU feedforward, as used in PaLM-style blocks.

    ``forward``'s *mask* is an additive float mask (0 keeps, -inf drops
    context positions), replicated across heads.
    """
    def __init__(
        self,
        dim,
        *,
        context_dim=None,
        dim_head=64,
        heads=12,
        parallel_ff=False,
        ff_mult=4,
        norm_context=False
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head
        # Context dimension defaults to the query dimension when not given.
        context_dim = default(context_dim, dim)
        self.norm = LayerNorm(dim)
        self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity()
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        # Single shared key/value head (multi-query): output is 2 * dim_head.
        self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)
        # whether to have parallel feedforward
        ff_inner_dim = ff_mult * dim
        self.ff = nn.Sequential(
            nn.Linear(dim, ff_inner_dim * 2, bias=False),
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        ) if parallel_ff else None
    def forward(self, x, context, mask):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        # pre-layernorm, for queries and context
        x = self.norm(x)
        context = self.context_norm(context)
        # get queries
        q = self.to_q(x)
        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
        # scale
        q = q * self.scale
        # get key / values (one shared head, split from a single projection)
        k, v = self.to_kv(context).chunk(2, dim=-1)
        # query / key similarity
        sim = einsum('b h i d, b j d -> b h i j', q, k)
        # attention
        # replicate the additive mask across heads, then apply it
        mask = mask.unsqueeze(1).repeat(1,self.heads,1,1)
        sim = sim + mask # context mask
        # NOTE(review): unlike ParallelTransformerBlock, the amax here is not
        # detached, so gradients flow through the max — confirm intentional.
        sim = sim - sim.amax(dim=-1, keepdim=True)
        attn = sim.softmax(dim=-1)
        # aggregate
        out = einsum('b h i j, b j d -> b h i d', attn, v)
        # merge and combine heads
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        # add parallel feedforward (for multimodal layers)
        if exists(self.ff):
            out = out + self.ff(x)
        return out
class Cross_model(nn.Module):
    """Stack of (cross-attention, parallel transformer) layer pairs, each
    wrapped in a residual connection."""

    def __init__(
        self,
        dim=512,
        layer_num=4,
        dim_head=64,
        heads=8,
        ff_mult=4
    ):
        super().__init__()

        def build_layer():
            # One stage: residual cross-attention followed by a residual
            # parallel attention+feedforward block.
            return nn.ModuleList([
                Residual(CrossAttention(dim=dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult)),
                Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult)),
            ])

        self.layers = nn.ModuleList([build_layer() for _ in range(layer_num)])

    def forward(self, query_tokens, context_tokens, mask):
        """Refine query tokens against context tokens under the given mask."""
        out = query_tokens
        for cross_attn, self_attn_ff in self.layers:
            out = cross_attn(out, context_tokens, mask)
            out = self_attn_ff(out)
        return out