diff --git a/diffsynth/extensions/QualityMetric/BLIP/__init__.py b/diffsynth/extensions/ImageQualityMetric/BLIP/__init__.py similarity index 100% rename from diffsynth/extensions/QualityMetric/BLIP/__init__.py rename to diffsynth/extensions/ImageQualityMetric/BLIP/__init__.py diff --git a/diffsynth/extensions/QualityMetric/BLIP/blip.py b/diffsynth/extensions/ImageQualityMetric/BLIP/blip.py similarity index 98% rename from diffsynth/extensions/QualityMetric/BLIP/blip.py rename to diffsynth/extensions/ImageQualityMetric/BLIP/blip.py index 2e100da..6b24c3c 100644 --- a/diffsynth/extensions/QualityMetric/BLIP/blip.py +++ b/diffsynth/extensions/ImageQualityMetric/BLIP/blip.py @@ -19,9 +19,8 @@ def default_bert(): model_path = os.path.join(project_root, 'models', 'QualityMetric') return os.path.join(model_path, "bert-base-uncased") -bert_model_path = default_bert() -def init_tokenizer(): +def init_tokenizer(bert_model_path): tokenizer = BertTokenizer.from_pretrained(bert_model_path) tokenizer.add_special_tokens({'bos_token':'[DEC]'}) tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']}) diff --git a/diffsynth/extensions/QualityMetric/BLIP/blip_pretrain.py b/diffsynth/extensions/ImageQualityMetric/BLIP/blip_pretrain.py similarity index 93% rename from diffsynth/extensions/QualityMetric/BLIP/blip_pretrain.py rename to diffsynth/extensions/ImageQualityMetric/BLIP/blip_pretrain.py index 793cb07..ba711e2 100644 --- a/diffsynth/extensions/QualityMetric/BLIP/blip_pretrain.py +++ b/diffsynth/extensions/ImageQualityMetric/BLIP/blip_pretrain.py @@ -20,6 +20,7 @@ class BLIP_Pretrain(nn.Module): embed_dim = 256, queue_size = 57600, momentum = 0.995, + bert_model_path = "" ): """ Args: @@ -31,7 +32,7 @@ class BLIP_Pretrain(nn.Module): self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0) - self.tokenizer = init_tokenizer() + self.tokenizer = init_tokenizer(bert_model_path) encoder_config = 
BertConfig.from_json_file(med_config) encoder_config.encoder_width = vision_width self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False) diff --git a/diffsynth/extensions/QualityMetric/BLIP/med.py b/diffsynth/extensions/ImageQualityMetric/BLIP/med.py similarity index 100% rename from diffsynth/extensions/QualityMetric/BLIP/med.py rename to diffsynth/extensions/ImageQualityMetric/BLIP/med.py diff --git a/diffsynth/extensions/QualityMetric/BLIP/vit.py b/diffsynth/extensions/ImageQualityMetric/BLIP/vit.py similarity index 98% rename from diffsynth/extensions/QualityMetric/BLIP/vit.py rename to diffsynth/extensions/ImageQualityMetric/BLIP/vit.py index 7e5cf43..cef7b65 100644 --- a/diffsynth/extensions/QualityMetric/BLIP/vit.py +++ b/diffsynth/extensions/ImageQualityMetric/BLIP/vit.py @@ -14,7 +14,7 @@ from timm.models.registry import register_model from timm.models.layers import trunc_normal_, DropPath from timm.models.helpers import named_apply, adapt_input_conv -from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper +# from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper class Mlp(nn.Module): """ MLP as used in Vision Transformer, MLP-Mixer and related networks @@ -96,9 +96,9 @@ class Block(nn.Module): mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - if use_grad_checkpointing: - self.attn = checkpoint_wrapper(self.attn) - self.mlp = checkpoint_wrapper(self.mlp) + # if use_grad_checkpointing: + # self.attn = checkpoint_wrapper(self.attn) + # self.mlp = checkpoint_wrapper(self.mlp) def forward(self, x, register_hook=False): x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook)) diff --git a/diffsynth/extensions/ImageQualityMetric/__init__.py b/diffsynth/extensions/ImageQualityMetric/__init__.py new file mode 100644 index 0000000..fcfb7c0 --- /dev/null +++ 
b/diffsynth/extensions/ImageQualityMetric/__init__.py @@ -0,0 +1,148 @@ +from modelscope import snapshot_download +from typing_extensions import Literal, TypeAlias +import os +from diffsynth.extensions.ImageQualityMetric.aesthetic import AestheticScore +from diffsynth.extensions.ImageQualityMetric.imagereward import ImageRewardScore +from diffsynth.extensions.ImageQualityMetric.pickscore import PickScore +from diffsynth.extensions.ImageQualityMetric.clip import CLIPScore +from diffsynth.extensions.ImageQualityMetric.hps import HPScore_v2 +from diffsynth.extensions.ImageQualityMetric.mps import MPScore + + +preference_model_id: TypeAlias = Literal[ + "ImageReward", + "Aesthetic", + "PickScore", + "CLIP", + "HPSv2", + "HPSv2.1", + "MPS", +] +model_dict = { + "ImageReward": { + "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained", + "allow_file_pattern": [ + "ImageReward/ImageReward.safetensors", + "ImageReward/med_config.json", + "bert-base-uncased/config.json", + "bert-base-uncased/model.safetensors", + "bert-base-uncased/tokenizer.json", + "bert-base-uncased/tokenizer_config.json", + "bert-base-uncased/vocab.txt", + ], + "load_path": { + "imagereward": "ImageReward/ImageReward.safetensors", + "med_config": "ImageReward/med_config.json", + "bert_model_path": "bert-base-uncased", + }, + "model_class": ImageRewardScore + }, + "Aesthetic": { + "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained", + "allow_file_pattern": [ + "aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors", + "clip-vit-large-patch14/config.json", + "clip-vit-large-patch14/merges.txt", + "clip-vit-large-patch14/model.safetensors", + "clip-vit-large-patch14/preprocessor_config.json", + "clip-vit-large-patch14/special_tokens_map.json", + "clip-vit-large-patch14/tokenizer.json", + "clip-vit-large-patch14/tokenizer_config.json", + "clip-vit-large-patch14/vocab.json", + ], + "load_path": { + "aesthetic_predictor": 
"aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors", + "clip-large": "clip-vit-large-patch14", + }, + "model_class": AestheticScore + }, + "PickScore": { + "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained", + "allow_file_pattern": [ + "PickScore_v1/*", + "CLIP-ViT-H-14-laion2B-s32B-b79K/config.json", + "CLIP-ViT-H-14-laion2B-s32B-b79K/merges.txt", + "CLIP-ViT-H-14-laion2B-s32B-b79K/preprocessor_config.json", + "CLIP-ViT-H-14-laion2B-s32B-b79K/special_tokens_map.json", + "CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer.json", + "CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer_config.json", + "CLIP-ViT-H-14-laion2B-s32B-b79K/vocab.json", + ], + "load_path": { + "pickscore": "PickScore_v1", + "clip": "CLIP-ViT-H-14-laion2B-s32B-b79K", + }, + "model_class": PickScore + }, + "CLIP": { + "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained", + "allow_file_pattern": [ + "CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin", + "bpe_simple_vocab_16e6.txt.gz", + ], + "load_path": { + "open_clip": "CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin", + "open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz", + }, + "model_class": CLIPScore + }, + "HPSv2": { + "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained", + "allow_file_pattern": [ + "HPS_v2/HPS_v2_compressed.safetensors", + "bpe_simple_vocab_16e6.txt.gz", + ], + "load_path": { + "hpsv2": "HPS_v2/HPS_v2_compressed.safetensors", + "open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz", + }, + "model_class": HPScore_v2, + "extra_kwargs": {"model_version": "v2"} + }, + "HPSv2.1": { + "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained", + "allow_file_pattern": [ + "HPS_v2/HPS_v2.1_compressed.safetensors", + "bpe_simple_vocab_16e6.txt.gz", + ], + "load_path": { + "hpsv2.1": "HPS_v2/HPS_v2.1_compressed.safetensors", + "open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz", + }, + "model_class": HPScore_v2, + "extra_kwargs": {"model_version": "v21"} + }, + "MPS": { + "model_id": 
"DiffSynth-Studio/QualityMetric_reward_pretrained", + "allow_file_pattern": [ + "MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors", + "CLIP-ViT-H-14-laion2B-s32B-b79K/config.json", + "CLIP-ViT-H-14-laion2B-s32B-b79K/merges.txt", + "CLIP-ViT-H-14-laion2B-s32B-b79K/preprocessor_config.json", + "CLIP-ViT-H-14-laion2B-s32B-b79K/special_tokens_map.json", + "CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer.json", + "CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer_config.json", + "CLIP-ViT-H-14-laion2B-s32B-b79K/vocab.json", + ], + "load_path": { + "mps": "MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors", + "clip": "CLIP-ViT-H-14-laion2B-s32B-b79K", + }, + "model_class": MPScore + }, +} + + +def download_preference_model(model_name: preference_model_id, cache_dir="models"): + metadata = model_dict[model_name] + snapshot_download(model_id=metadata["model_id"], allow_file_pattern=metadata["allow_file_pattern"], cache_dir=cache_dir) + load_path = metadata["load_path"] + load_path = {key: os.path.join(cache_dir, metadata["model_id"], path) for key, path in load_path.items()} + return load_path + + +def load_preference_model(model_name: preference_model_id, device = "cuda", path = None): + model_class = model_dict[model_name]["model_class"] + extra_kwargs = model_dict[model_name].get("extra_kwargs", {}) + preference_model = model_class(device=device, path=path, **extra_kwargs) + return preference_model diff --git a/diffsynth/extensions/QualityMetric/aesthetic.py b/diffsynth/extensions/ImageQualityMetric/aesthetic.py similarity index 96% rename from diffsynth/extensions/QualityMetric/aesthetic.py rename to diffsynth/extensions/ImageQualityMetric/aesthetic.py index 46e3c57..13da98a 100644 --- a/diffsynth/extensions/QualityMetric/aesthetic.py +++ b/diffsynth/extensions/ImageQualityMetric/aesthetic.py @@ -49,13 +49,9 @@ class MLP(torch.nn.Module): return torch.optim.Adam(self.parameters(), lr=1e-3) -class AestheticScore: +class 
AestheticScore(torch.nn.Module): def __init__(self, device: torch.device, path: str = MODEL_PATHS): - """Initialize the Selector with a model and processor. - - Args: - device (torch.device): The device to load the model on. - """ + super().__init__() self.device = device self.aes_model_path = path.get("aesthetic_predictor") # Load the MLP model @@ -96,7 +92,8 @@ class AestheticScore: return score - def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]]) -> List[float]: + @torch.no_grad() + def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str = "") -> List[float]: """Score the images based on their aesthetic quality. Args: diff --git a/diffsynth/extensions/QualityMetric/clip.py b/diffsynth/extensions/ImageQualityMetric/clip.py similarity index 58% rename from diffsynth/extensions/QualityMetric/clip.py rename to diffsynth/extensions/ImageQualityMetric/clip.py index ab6bdef..f70941e 100644 --- a/diffsynth/extensions/QualityMetric/clip.py +++ b/diffsynth/extensions/ImageQualityMetric/clip.py @@ -4,8 +4,9 @@ import torch from .open_clip import create_model_and_transforms, get_tokenizer from .config import MODEL_PATHS -class CLIPScore: +class CLIPScore(torch.nn.Module): def __init__(self, device: torch.device, path: str = MODEL_PATHS): + super().__init__() """Initialize the CLIPScore with a model and tokenizer. Args: @@ -36,7 +37,7 @@ class CLIPScore: ) # Initialize tokenizer - self.tokenizer = get_tokenizer("ViT-H-14") + self.tokenizer = get_tokenizer("ViT-H-14", path["open_clip_bpe"]) self.model = self.model.to(device) self.model.eval() @@ -62,37 +63,35 @@ class CLIPScore: return clip_score[0].item() - def score(self, img_path: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]: + @torch.no_grad() + def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]: """Score the images based on the prompt. 
Args: - img_path (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s). + images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s). prompt (str): The prompt text. Returns: List[float]: List of CLIP scores for the images. """ - try: - if isinstance(img_path, (str, Image.Image)): - # Single image - if isinstance(img_path, str): - image = self.preprocess_val(Image.open(img_path)).unsqueeze(0).to(device=self.device, non_blocking=True) - else: - image = self.preprocess_val(img_path).unsqueeze(0).to(device=self.device, non_blocking=True) - return [self._calculate_score(image, prompt)] - elif isinstance(img_path, list): - # Multiple images - scores = [] - for one_img_path in img_path: - if isinstance(one_img_path, str): - image = self.preprocess_val(Image.open(one_img_path)).unsqueeze(0).to(device=self.device, non_blocking=True) - elif isinstance(one_img_path, Image.Image): - image = self.preprocess_val(one_img_path).unsqueeze(0).to(device=self.device, non_blocking=True) - else: - raise TypeError("The type of parameter img_path is illegal.") - scores.append(self._calculate_score(image, prompt)) - return scores + if isinstance(images, (str, Image.Image)): + # Single image + if isinstance(images, str): + image = self.preprocess_val(Image.open(images)).unsqueeze(0).to(device=self.device, non_blocking=True) else: - raise TypeError("The type of parameter img_path is illegal.") - except Exception as e: - raise RuntimeError(f"Error in scoring images: {e}") + image = self.preprocess_val(images).unsqueeze(0).to(device=self.device, non_blocking=True) + return [self._calculate_score(image, prompt)] + elif isinstance(images, list): + # Multiple images + scores = [] + for one_images in images: + if isinstance(one_images, str): + image = self.preprocess_val(Image.open(one_images)).unsqueeze(0).to(device=self.device, non_blocking=True) + elif isinstance(one_images, Image.Image): + image = 
self.preprocess_val(one_images).unsqueeze(0).to(device=self.device, non_blocking=True) + else: + raise TypeError("The type of parameter images is illegal.") + scores.append(self._calculate_score(image, prompt)) + return scores + else: + raise TypeError("The type of parameter images is illegal.") diff --git a/diffsynth/extensions/QualityMetric/config.py b/diffsynth/extensions/ImageQualityMetric/config.py similarity index 100% rename from diffsynth/extensions/QualityMetric/config.py rename to diffsynth/extensions/ImageQualityMetric/config.py diff --git a/diffsynth/extensions/QualityMetric/hps.py b/diffsynth/extensions/ImageQualityMetric/hps.py similarity index 74% rename from diffsynth/extensions/QualityMetric/hps.py rename to diffsynth/extensions/ImageQualityMetric/hps.py index 5754956..a4b266b 100644 --- a/diffsynth/extensions/QualityMetric/hps.py +++ b/diffsynth/extensions/ImageQualityMetric/hps.py @@ -6,8 +6,9 @@ from safetensors.torch import load_file import os from .config import MODEL_PATHS -class HPScore_v2: +class HPScore_v2(torch.nn.Module): def __init__(self, device: torch.device, path: str = MODEL_PATHS, model_version: str = "v2"): + super().__init__() """Initialize the Selector with a model and tokenizer. Args: @@ -53,7 +54,7 @@ class HPScore_v2: raise ValueError(f"Error loading model weights from {safetensors_path}: {e}") # Initialize tokenizer and model - self.tokenizer = get_tokenizer("ViT-H-14") + self.tokenizer = get_tokenizer("ViT-H-14", path["open_clip_bpe"]) model = model.to(device) model.eval() self.model = model @@ -80,37 +81,38 @@ class HPScore_v2: return hps_score[0].item() - def score(self, img_path: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]: + @torch.no_grad() + def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]: """Score the images based on the prompt. 
Args: - img_path (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s). + images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s). prompt (str): The prompt text. Returns: List[float]: List of HPS scores for the images. """ try: - if isinstance(img_path, (str, Image.Image)): + if isinstance(images, (str, Image.Image)): # Single image - if isinstance(img_path, str): - image = self.preprocess_val(Image.open(img_path)).unsqueeze(0).to(device=self.device, non_blocking=True) + if isinstance(images, str): + image = self.preprocess_val(Image.open(images)).unsqueeze(0).to(device=self.device, non_blocking=True) else: - image = self.preprocess_val(img_path).unsqueeze(0).to(device=self.device, non_blocking=True) + image = self.preprocess_val(images).unsqueeze(0).to(device=self.device, non_blocking=True) return [self._calculate_score(image, prompt)] - elif isinstance(img_path, list): + elif isinstance(images, list): # Multiple images scores = [] - for one_img_path in img_path: - if isinstance(one_img_path, str): - image = self.preprocess_val(Image.open(one_img_path)).unsqueeze(0).to(device=self.device, non_blocking=True) - elif isinstance(one_img_path, Image.Image): - image = self.preprocess_val(one_img_path).unsqueeze(0).to(device=self.device, non_blocking=True) + for one_images in images: + if isinstance(one_images, str): + image = self.preprocess_val(Image.open(one_images)).unsqueeze(0).to(device=self.device, non_blocking=True) + elif isinstance(one_images, Image.Image): + image = self.preprocess_val(one_images).unsqueeze(0).to(device=self.device, non_blocking=True) else: - raise TypeError("The type of parameter img_path is illegal.") + raise TypeError("The type of parameter images is illegal.") scores.append(self._calculate_score(image, prompt)) return scores else: - raise TypeError("The type of parameter img_path is illegal.") + raise TypeError("The type of parameter images is 
illegal.") except Exception as e: raise RuntimeError(f"Error in scoring images: {e}") diff --git a/diffsynth/extensions/QualityMetric/imagereward.py b/diffsynth/extensions/ImageQualityMetric/imagereward.py similarity index 94% rename from diffsynth/extensions/QualityMetric/imagereward.py rename to diffsynth/extensions/ImageQualityMetric/imagereward.py index e0f5705..2760790 100644 --- a/diffsynth/extensions/QualityMetric/imagereward.py +++ b/diffsynth/extensions/ImageQualityMetric/imagereward.py @@ -52,11 +52,11 @@ class MLP(torch.nn.Module): return self.layers(input) class ImageReward(torch.nn.Module): - def __init__(self, med_config, device='cpu'): + def __init__(self, med_config, device='cpu', bert_model_path=""): super().__init__() self.device = device - self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config) + self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config, bert_model_path=bert_model_path) self.preprocess = _transform(224) self.mlp = MLP(768) @@ -88,7 +88,7 @@ class ImageReward(torch.nn.Module): rewards = (rewards - self.mean) / self.std return rewards - def score(self, prompt: str, images: Union[str, List[str], Image.Image, List[Image.Image]]) -> List[float]: + def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str = "") -> List[float]: """Score the images based on the prompt. Args: @@ -187,21 +187,18 @@ class ImageReward(torch.nn.Module): return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist() -class ImageRewardScore: +class ImageRewardScore(torch.nn.Module): def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS): - """Initialize the Selector with a processor and model. - - Args: - device (Union[str, torch.device]): The device to load the model on. 
- """ + super().__init__() self.device = device if isinstance(device, torch.device) else torch.device(device) model_path = path.get("imagereward") med_config = path.get("med_config") state_dict = load_file(model_path) - self.model = ImageReward(device=self.device, med_config=med_config).to(self.device) + self.model = ImageReward(device=self.device, med_config=med_config, bert_model_path=path.get("bert_model_path")).to(self.device) self.model.load_state_dict(state_dict, strict=False) self.model.eval() + @torch.no_grad() def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]: """Score the images based on the prompt. @@ -212,4 +209,4 @@ class ImageRewardScore: Returns: List[float]: List of scores for the images. """ - return self.model.score(prompt, images) + return self.model.score(images, prompt) diff --git a/diffsynth/extensions/QualityMetric/mps.py b/diffsynth/extensions/ImageQualityMetric/mps.py similarity index 73% rename from diffsynth/extensions/QualityMetric/mps.py rename to diffsynth/extensions/ImageQualityMetric/mps.py index e5c7360..d15aad4 100644 --- a/diffsynth/extensions/QualityMetric/mps.py +++ b/diffsynth/extensions/ImageQualityMetric/mps.py @@ -24,8 +24,9 @@ import gc import json from .config import MODEL_PATHS -class MPScore: +class MPScore(torch.nn.Module): def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS, condition: str = 'overall'): + super().__init__() """Initialize the MPSModel with a processor, tokenizer, and model. 
Args: @@ -35,7 +36,7 @@ class MPScore: processor_name_or_path = path.get("clip") self.image_processor = CLIPImageProcessor.from_pretrained(processor_name_or_path) self.tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True) - self.model = clip_model.CLIPModel(processor_name_or_path) + self.model = clip_model.CLIPModel(processor_name_or_path, config_file=True) state_dict = load_file(path.get("mps")) self.model.load_state_dict(state_dict, strict=False) self.model.to(device) @@ -94,37 +95,35 @@ class MPScore: return image_score[0].cpu().numpy().item() - def score(self, img_path: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]: + @torch.no_grad() + def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]: """Score the images based on the prompt. Args: - img_path (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s). + images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s). prompt (str): The prompt text. Returns: List[float]: List of reward scores for the images. 
""" - try: - if isinstance(img_path, (str, Image.Image)): - # Single image - if isinstance(img_path, str): - image = self.image_processor(Image.open(img_path), return_tensors="pt")["pixel_values"].to(self.device) - else: - image = self.image_processor(img_path, return_tensors="pt")["pixel_values"].to(self.device) - return [self._calculate_score(image, prompt)] - elif isinstance(img_path, list): - # Multiple images - scores = [] - for one_img_path in img_path: - if isinstance(one_img_path, str): - image = self.image_processor(Image.open(one_img_path), return_tensors="pt")["pixel_values"].to(self.device) - elif isinstance(one_img_path, Image.Image): - image = self.image_processor(one_img_path, return_tensors="pt")["pixel_values"].to(self.device) - else: - raise TypeError("The type of parameter img_path is illegal.") - scores.append(self._calculate_score(image, prompt)) - return scores + if isinstance(images, (str, Image.Image)): + # Single image + if isinstance(images, str): + image = self.image_processor(Image.open(images), return_tensors="pt")["pixel_values"].to(self.device) else: - raise TypeError("The type of parameter img_path is illegal.") - except Exception as e: - raise RuntimeError(f"Error in scoring images: {e}") + image = self.image_processor(images, return_tensors="pt")["pixel_values"].to(self.device) + return [self._calculate_score(image, prompt)] + elif isinstance(images, list): + # Multiple images + scores = [] + for one_images in images: + if isinstance(one_images, str): + image = self.image_processor(Image.open(one_images), return_tensors="pt")["pixel_values"].to(self.device) + elif isinstance(one_images, Image.Image): + image = self.image_processor(one_images, return_tensors="pt")["pixel_values"].to(self.device) + else: + raise TypeError("The type of parameter images is illegal.") + scores.append(self._calculate_score(image, prompt)) + return scores + else: + raise TypeError("The type of parameter images is illegal.") diff --git 
a/diffsynth/extensions/QualityMetric/open_clip/__init__.py b/diffsynth/extensions/ImageQualityMetric/open_clip/__init__.py similarity index 94% rename from diffsynth/extensions/QualityMetric/open_clip/__init__.py rename to diffsynth/extensions/ImageQualityMetric/open_clip/__init__.py index c328ed2..1560db0 100644 --- a/diffsynth/extensions/QualityMetric/open_clip/__init__.py +++ b/diffsynth/extensions/ImageQualityMetric/open_clip/__init__.py @@ -9,6 +9,6 @@ from .openai import load_openai_model, list_openai_models from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \ get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub -from .tokenizer import SimpleTokenizer, tokenize, decode +from .tokenizer import SimpleTokenizer from .transform import image_transform, AugmentationCfg from .utils import freeze_batch_norm_2d diff --git a/diffsynth/extensions/QualityMetric/open_clip/coca_model.py b/diffsynth/extensions/ImageQualityMetric/open_clip/coca_model.py similarity index 100% rename from diffsynth/extensions/QualityMetric/open_clip/coca_model.py rename to diffsynth/extensions/ImageQualityMetric/open_clip/coca_model.py diff --git a/diffsynth/extensions/QualityMetric/open_clip/constants.py b/diffsynth/extensions/ImageQualityMetric/open_clip/constants.py similarity index 100% rename from diffsynth/extensions/QualityMetric/open_clip/constants.py rename to diffsynth/extensions/ImageQualityMetric/open_clip/constants.py diff --git a/diffsynth/extensions/QualityMetric/open_clip/factory.py b/diffsynth/extensions/ImageQualityMetric/open_clip/factory.py similarity index 98% rename from diffsynth/extensions/QualityMetric/open_clip/factory.py rename to diffsynth/extensions/ImageQualityMetric/open_clip/factory.py index 00f0bb4..c353530 100644 --- a/diffsynth/extensions/QualityMetric/open_clip/factory.py +++ 
b/diffsynth/extensions/ImageQualityMetric/open_clip/factory.py @@ -18,7 +18,7 @@ from .loss import ClipLoss, DistillClipLoss, CoCaLoss from .openai import load_openai_model from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf from .transform import image_transform, AugmentationCfg -from .tokenizer import HFTokenizer, tokenize +from .tokenizer import HFTokenizer, SimpleTokenizer HF_HUB_PREFIX = 'hf-hub:' @@ -74,13 +74,13 @@ def get_model_config(model_name): return None -def get_tokenizer(model_name): +def get_tokenizer(model_name, open_clip_bpe_path=None): if model_name.startswith(HF_HUB_PREFIX): tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):]) else: config = get_model_config(model_name) tokenizer = HFTokenizer( - config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize + config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else SimpleTokenizer(open_clip_bpe_path) return tokenizer diff --git a/diffsynth/extensions/QualityMetric/open_clip/generation_utils.py b/diffsynth/extensions/ImageQualityMetric/open_clip/generation_utils.py similarity index 100% rename from diffsynth/extensions/QualityMetric/open_clip/generation_utils.py rename to diffsynth/extensions/ImageQualityMetric/open_clip/generation_utils.py diff --git a/diffsynth/extensions/QualityMetric/open_clip/hf_configs.py b/diffsynth/extensions/ImageQualityMetric/open_clip/hf_configs.py similarity index 100% rename from diffsynth/extensions/QualityMetric/open_clip/hf_configs.py rename to diffsynth/extensions/ImageQualityMetric/open_clip/hf_configs.py diff --git a/diffsynth/extensions/QualityMetric/open_clip/hf_model.py b/diffsynth/extensions/ImageQualityMetric/open_clip/hf_model.py similarity index 100% rename from diffsynth/extensions/QualityMetric/open_clip/hf_model.py rename to diffsynth/extensions/ImageQualityMetric/open_clip/hf_model.py 
diff --git a/diffsynth/extensions/QualityMetric/open_clip/loss.py b/diffsynth/extensions/ImageQualityMetric/open_clip/loss.py similarity index 100% rename from diffsynth/extensions/QualityMetric/open_clip/loss.py rename to diffsynth/extensions/ImageQualityMetric/open_clip/loss.py diff --git a/diffsynth/extensions/QualityMetric/open_clip/model.py b/diffsynth/extensions/ImageQualityMetric/open_clip/model.py similarity index 100% rename from diffsynth/extensions/QualityMetric/open_clip/model.py rename to diffsynth/extensions/ImageQualityMetric/open_clip/model.py diff --git a/diffsynth/extensions/QualityMetric/open_clip/model_configs/ViT-H-14.json b/diffsynth/extensions/ImageQualityMetric/open_clip/model_configs/ViT-H-14.json similarity index 100% rename from diffsynth/extensions/QualityMetric/open_clip/model_configs/ViT-H-14.json rename to diffsynth/extensions/ImageQualityMetric/open_clip/model_configs/ViT-H-14.json diff --git a/diffsynth/extensions/QualityMetric/open_clip/modified_resnet.py b/diffsynth/extensions/ImageQualityMetric/open_clip/modified_resnet.py similarity index 100% rename from diffsynth/extensions/QualityMetric/open_clip/modified_resnet.py rename to diffsynth/extensions/ImageQualityMetric/open_clip/modified_resnet.py diff --git a/diffsynth/extensions/QualityMetric/open_clip/openai.py b/diffsynth/extensions/ImageQualityMetric/open_clip/openai.py similarity index 100% rename from diffsynth/extensions/QualityMetric/open_clip/openai.py rename to diffsynth/extensions/ImageQualityMetric/open_clip/openai.py diff --git a/diffsynth/extensions/QualityMetric/open_clip/pretrained.py b/diffsynth/extensions/ImageQualityMetric/open_clip/pretrained.py similarity index 100% rename from diffsynth/extensions/QualityMetric/open_clip/pretrained.py rename to diffsynth/extensions/ImageQualityMetric/open_clip/pretrained.py diff --git a/diffsynth/extensions/QualityMetric/open_clip/push_to_hf_hub.py b/diffsynth/extensions/ImageQualityMetric/open_clip/push_to_hf_hub.py 
similarity index 100% rename from diffsynth/extensions/QualityMetric/open_clip/push_to_hf_hub.py rename to diffsynth/extensions/ImageQualityMetric/open_clip/push_to_hf_hub.py diff --git a/diffsynth/extensions/QualityMetric/open_clip/timm_model.py b/diffsynth/extensions/ImageQualityMetric/open_clip/timm_model.py similarity index 100% rename from diffsynth/extensions/QualityMetric/open_clip/timm_model.py rename to diffsynth/extensions/ImageQualityMetric/open_clip/timm_model.py diff --git a/diffsynth/extensions/QualityMetric/open_clip/tokenizer.py b/diffsynth/extensions/ImageQualityMetric/open_clip/tokenizer.py similarity index 83% rename from diffsynth/extensions/QualityMetric/open_clip/tokenizer.py rename to diffsynth/extensions/ImageQualityMetric/open_clip/tokenizer.py index ba00817..22ec488 100644 --- a/diffsynth/extensions/QualityMetric/open_clip/tokenizer.py +++ b/diffsynth/extensions/ImageQualityMetric/open_clip/tokenizer.py @@ -151,44 +151,38 @@ class SimpleTokenizer(object): text = ''.join([self.decoder[token] for token in tokens]) text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') return text + + def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: + """ + Returns the tokenized representation of given input string(s) + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all CLIP models use 77 as the context length -_tokenizer = SimpleTokenizer() + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] -def decode(output_ids: torch.Tensor): - output_ids = output_ids.cpu().numpy() - return _tokenizer.decode(output_ids) + sot_token = self.encoder["<start_of_text>"] + eot_token = self.encoder["<end_of_text>"] + all_tokens = [[sot_token] + self.encode(text) + 
[eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) -def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: - """ - Returns the tokenized representation of given input string(s) + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + tokens = tokens[:context_length] # Truncate + tokens[-1] = eot_token + result[i, :len(tokens)] = torch.tensor(tokens) - Parameters - ---------- - texts : Union[str, List[str]] - An input string or a list of input strings to tokenize - context_length : int - The context length to use; all CLIP models use 77 as the context length + return result - Returns - ------- - A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] - """ - if isinstance(texts, str): - texts = [texts] - - sot_token = _tokenizer.encoder[""] - eot_token = _tokenizer.encoder[""] - all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] - result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) - - for i, tokens in enumerate(all_tokens): - if len(tokens) > context_length: - tokens = tokens[:context_length] # Truncate - tokens[-1] = eot_token - result[i, :len(tokens)] = torch.tensor(tokens) - - return result class HFTokenizer: diff --git a/diffsynth/extensions/QualityMetric/open_clip/transform.py b/diffsynth/extensions/ImageQualityMetric/open_clip/transform.py similarity index 100% rename from diffsynth/extensions/QualityMetric/open_clip/transform.py rename to diffsynth/extensions/ImageQualityMetric/open_clip/transform.py diff --git a/diffsynth/extensions/QualityMetric/open_clip/transformer.py b/diffsynth/extensions/ImageQualityMetric/open_clip/transformer.py similarity index 100% rename from diffsynth/extensions/QualityMetric/open_clip/transformer.py rename to diffsynth/extensions/ImageQualityMetric/open_clip/transformer.py diff --git 
a/diffsynth/extensions/QualityMetric/open_clip/utils.py b/diffsynth/extensions/ImageQualityMetric/open_clip/utils.py similarity index 100% rename from diffsynth/extensions/QualityMetric/open_clip/utils.py rename to diffsynth/extensions/ImageQualityMetric/open_clip/utils.py diff --git a/diffsynth/extensions/QualityMetric/open_clip/version.py b/diffsynth/extensions/ImageQualityMetric/open_clip/version.py similarity index 100% rename from diffsynth/extensions/QualityMetric/open_clip/version.py rename to diffsynth/extensions/ImageQualityMetric/open_clip/version.py diff --git a/diffsynth/extensions/QualityMetric/pickscore.py b/diffsynth/extensions/ImageQualityMetric/pickscore.py similarity index 98% rename from diffsynth/extensions/QualityMetric/pickscore.py rename to diffsynth/extensions/ImageQualityMetric/pickscore.py index b289c57..7370e09 100644 --- a/diffsynth/extensions/QualityMetric/pickscore.py +++ b/diffsynth/extensions/ImageQualityMetric/pickscore.py @@ -5,8 +5,9 @@ from typing import List, Union import os from .config import MODEL_PATHS -class PickScore: +class PickScore(torch.nn.Module): def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS): + super().__init__() """Initialize the Selector with a processor and model. Args: @@ -53,6 +54,7 @@ class PickScore: return score.cpu().item() + @torch.no_grad() def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str, softmax: bool = False) -> List[float]: """Score the images based on the prompt. 
diff --git a/diffsynth/extensions/QualityMetric/trainer/__init__.py b/diffsynth/extensions/ImageQualityMetric/trainer/__init__.py similarity index 100% rename from diffsynth/extensions/QualityMetric/trainer/__init__.py rename to diffsynth/extensions/ImageQualityMetric/trainer/__init__.py diff --git a/diffsynth/extensions/QualityMetric/trainer/models/__init__.py b/diffsynth/extensions/ImageQualityMetric/trainer/models/__init__.py similarity index 100% rename from diffsynth/extensions/QualityMetric/trainer/models/__init__.py rename to diffsynth/extensions/ImageQualityMetric/trainer/models/__init__.py diff --git a/diffsynth/extensions/QualityMetric/trainer/models/base_model.py b/diffsynth/extensions/ImageQualityMetric/trainer/models/base_model.py similarity index 100% rename from diffsynth/extensions/QualityMetric/trainer/models/base_model.py rename to diffsynth/extensions/ImageQualityMetric/trainer/models/base_model.py diff --git a/diffsynth/extensions/QualityMetric/trainer/models/clip_model.py b/diffsynth/extensions/ImageQualityMetric/trainer/models/clip_model.py similarity index 92% rename from diffsynth/extensions/QualityMetric/trainer/models/clip_model.py rename to diffsynth/extensions/ImageQualityMetric/trainer/models/clip_model.py index 4b61a4d..0a1b370 100644 --- a/diffsynth/extensions/QualityMetric/trainer/models/clip_model.py +++ b/diffsynth/extensions/ImageQualityMetric/trainer/models/clip_model.py @@ -12,7 +12,7 @@ import torch from .cross_modeling import Cross_model -import gc +import json, os class XCLIPModel(HFCLIPModel): def __init__(self, config: CLIPConfig): @@ -96,9 +96,15 @@ class ClipModelConfig(BaseModelConfig): class CLIPModel(nn.Module): - def __init__(self, ckpt): + def __init__(self, ckpt, config_file=False): super().__init__() - self.model = XCLIPModel.from_pretrained(ckpt) + if config_file is None: + self.model = XCLIPModel.from_pretrained(ckpt) + else: + with open(os.path.join(ckpt, "config.json"), "r", encoding="utf-8") as f: + config = 
json.load(f) + config = CLIPConfig(**config) + self.model = XCLIPModel._from_config(config) self.cross_model = Cross_model(dim=1024, layer_num=4, heads=16) def get_text_features(self, *args, **kwargs): diff --git a/diffsynth/extensions/QualityMetric/trainer/models/cross_modeling.py b/diffsynth/extensions/ImageQualityMetric/trainer/models/cross_modeling.py similarity index 96% rename from diffsynth/extensions/QualityMetric/trainer/models/cross_modeling.py rename to diffsynth/extensions/ImageQualityMetric/trainer/models/cross_modeling.py index d9f1fd8..938f1b7 100644 --- a/diffsynth/extensions/QualityMetric/trainer/models/cross_modeling.py +++ b/diffsynth/extensions/ImageQualityMetric/trainer/models/cross_modeling.py @@ -1,292 +1,292 @@ -import torch -from torch import einsum, nn -import torch.nn.functional as F -from einops import rearrange, repeat - -# helper functions - -def exists(val): - return val is not None - -def default(val, d): - return val if exists(val) else d - -# normalization -# they use layernorm without bias, something that pytorch does not offer - - -class LayerNorm(nn.Module): - def __init__(self, dim): - super().__init__() - self.weight = nn.Parameter(torch.ones(dim)) - self.register_buffer("bias", torch.zeros(dim)) - - def forward(self, x): - return F.layer_norm(x, x.shape[-1:], self.weight, self.bias) - -# residual - - -class Residual(nn.Module): - def __init__(self, fn): - super().__init__() - self.fn = fn - - def forward(self, x, *args, **kwargs): - return self.fn(x, *args, **kwargs) + x - - -# rotary positional embedding -# https://arxiv.org/abs/2104.09864 - - -class RotaryEmbedding(nn.Module): - def __init__(self, dim): - super().__init__() - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) - self.register_buffer("inv_freq", inv_freq) - - def forward(self, max_seq_len, *, device): - seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = einsum("i , j -> i j", seq, self.inv_freq) - return 
torch.cat((freqs, freqs), dim=-1) - - -def rotate_half(x): - x = rearrange(x, "... (j d) -> ... j d", j=2) - x1, x2 = x.unbind(dim=-2) - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(pos, t): - return (t * pos.cos()) + (rotate_half(t) * pos.sin()) - - -# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward -# https://arxiv.org/abs/2002.05202 - - -class SwiGLU(nn.Module): - def forward(self, x): - x, gate = x.chunk(2, dim=-1) - return F.silu(gate) * x - - -# parallel attention and feedforward with residual -# discovered by Wang et al + EleutherAI from GPT-J fame - -class ParallelTransformerBlock(nn.Module): - def __init__(self, dim, dim_head=64, heads=8, ff_mult=4): - super().__init__() - self.norm = LayerNorm(dim) - - attn_inner_dim = dim_head * heads - ff_inner_dim = dim * ff_mult - self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2)) - - self.heads = heads - self.scale = dim_head**-0.5 - self.rotary_emb = RotaryEmbedding(dim_head) - - self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False) - self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False) - - self.ff_out = nn.Sequential( - SwiGLU(), - nn.Linear(ff_inner_dim, dim, bias=False) - ) - - self.register_buffer("pos_emb", None, persistent=False) - - - def get_rotary_embedding(self, n, device): - if self.pos_emb is not None and self.pos_emb.shape[-2] >= n: - return self.pos_emb[:n] - - pos_emb = self.rotary_emb(n, device=device) - self.register_buffer("pos_emb", pos_emb, persistent=False) - return pos_emb - - def forward(self, x, attn_mask=None): - """ - einstein notation - b - batch - h - heads - n, i, j - sequence length (base sequence length, source, target) - d - feature dimension - """ - - n, device, h = x.shape[1], x.device, self.heads - - # pre layernorm - - x = self.norm(x) - - # attention queries, keys, values, and feedforward inner - - q, k, v, ff = 
self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1) - - # split heads - # they use multi-query single-key-value attention, yet another Noam Shazeer paper - # they found no performance loss past a certain scale, and more efficient decoding obviously - # https://arxiv.org/abs/1911.02150 - - q = rearrange(q, "b n (h d) -> b h n d", h=h) - - # rotary embeddings - - positions = self.get_rotary_embedding(n, device) - q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k)) - - # scale - - q = q * self.scale - - # similarity - - sim = einsum("b h i d, b j d -> b h i j", q, k) - - - # extra attention mask - for masking out attention from text CLS token to padding - - if exists(attn_mask): - attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j') - sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max) - - # attention - - sim = sim - sim.amax(dim=-1, keepdim=True).detach() - attn = sim.softmax(dim=-1) - - # aggregate values - - out = einsum("b h i j, b j d -> b h i d", attn, v) - - # merge heads - - out = rearrange(out, "b h n d -> b n (h d)") - return self.attn_out(out) + self.ff_out(ff) - -# cross attention - using multi-query + one-headed key / values as in PaLM w/ optional parallel feedforward - -class CrossAttention(nn.Module): - def __init__( - self, - dim, - *, - context_dim=None, - dim_head=64, - heads=12, - parallel_ff=False, - ff_mult=4, - norm_context=False - ): - super().__init__() - self.heads = heads - self.scale = dim_head ** -0.5 - inner_dim = heads * dim_head - context_dim = default(context_dim, dim) - - self.norm = LayerNorm(dim) - self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity() - - self.to_q = nn.Linear(dim, inner_dim, bias=False) - self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False) - self.to_out = nn.Linear(inner_dim, dim, bias=False) - - # whether to have parallel feedforward - - ff_inner_dim = ff_mult * dim - - self.ff = nn.Sequential( - nn.Linear(dim, ff_inner_dim * 2, bias=False), - 
SwiGLU(), - nn.Linear(ff_inner_dim, dim, bias=False) - ) if parallel_ff else None - - def forward(self, x, context, mask): - """ - einstein notation - b - batch - h - heads - n, i, j - sequence length (base sequence length, source, target) - d - feature dimension - """ - - # pre-layernorm, for queries and context - - x = self.norm(x) - context = self.context_norm(context) - - # get queries - - q = self.to_q(x) - q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads) - - # scale - - q = q * self.scale - - # get key / values - - k, v = self.to_kv(context).chunk(2, dim=-1) - - # query / key similarity - - sim = einsum('b h i d, b j d -> b h i j', q, k) - - # attention - mask = mask.unsqueeze(1).repeat(1,self.heads,1,1) - sim = sim + mask # context mask - sim = sim - sim.amax(dim=-1, keepdim=True) - attn = sim.softmax(dim=-1) - - # aggregate - - out = einsum('b h i j, b j d -> b h i d', attn, v) - - # merge and combine heads - - out = rearrange(out, 'b h n d -> b n (h d)') - out = self.to_out(out) - - # add parallel feedforward (for multimodal layers) - - if exists(self.ff): - out = out + self.ff(x) - - return out - - -class Cross_model(nn.Module): - def __init__( - self, - dim=512, - layer_num=4, - dim_head=64, - heads=8, - ff_mult=4 - ): - super().__init__() - - self.layers = nn.ModuleList([]) - - - for ind in range(layer_num): - self.layers.append(nn.ModuleList([ - Residual(CrossAttention(dim=dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult)), - Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult)) - ])) - - def forward( - self, - query_tokens, - context_tokens, - mask - ): - - for cross_attn, self_attn_ff in self.layers: - query_tokens = cross_attn(query_tokens, context_tokens,mask) - query_tokens = self_attn_ff(query_tokens) - - return query_tokens +import torch +from torch import einsum, nn +import torch.nn.functional as F +from einops import rearrange, repeat + +# helper functions + +def exists(val): 
+ return val is not None + +def default(val, d): + return val if exists(val) else d + +# normalization +# they use layernorm without bias, something that pytorch does not offer + + +class LayerNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.weight = nn.Parameter(torch.ones(dim)) + self.register_buffer("bias", torch.zeros(dim)) + + def forward(self, x): + return F.layer_norm(x, x.shape[-1:], self.weight, self.bias) + +# residual + + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x, *args, **kwargs): + return self.fn(x, *args, **kwargs) + x + + +# rotary positional embedding +# https://arxiv.org/abs/2104.09864 + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, max_seq_len, *, device): + seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = einsum("i , j -> i j", seq, self.inv_freq) + return torch.cat((freqs, freqs), dim=-1) + + +def rotate_half(x): + x = rearrange(x, "... (j d) -> ... 
j d", j=2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(pos, t): + return (t * pos.cos()) + (rotate_half(t) * pos.sin()) + + +# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward +# https://arxiv.org/abs/2002.05202 + + +class SwiGLU(nn.Module): + def forward(self, x): + x, gate = x.chunk(2, dim=-1) + return F.silu(gate) * x + + +# parallel attention and feedforward with residual +# discovered by Wang et al + EleutherAI from GPT-J fame + +class ParallelTransformerBlock(nn.Module): + def __init__(self, dim, dim_head=64, heads=8, ff_mult=4): + super().__init__() + self.norm = LayerNorm(dim) + + attn_inner_dim = dim_head * heads + ff_inner_dim = dim * ff_mult + self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2)) + + self.heads = heads + self.scale = dim_head**-0.5 + self.rotary_emb = RotaryEmbedding(dim_head) + + self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False) + self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False) + + self.ff_out = nn.Sequential( + SwiGLU(), + nn.Linear(ff_inner_dim, dim, bias=False) + ) + + self.register_buffer("pos_emb", None, persistent=False) + + + def get_rotary_embedding(self, n, device): + if self.pos_emb is not None and self.pos_emb.shape[-2] >= n: + return self.pos_emb[:n] + + pos_emb = self.rotary_emb(n, device=device) + self.register_buffer("pos_emb", pos_emb, persistent=False) + return pos_emb + + def forward(self, x, attn_mask=None): + """ + einstein notation + b - batch + h - heads + n, i, j - sequence length (base sequence length, source, target) + d - feature dimension + """ + + n, device, h = x.shape[1], x.device, self.heads + + # pre layernorm + + x = self.norm(x) + + # attention queries, keys, values, and feedforward inner + + q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1) + + # split heads + # they use multi-query single-key-value 
attention, yet another Noam Shazeer paper + # they found no performance loss past a certain scale, and more efficient decoding obviously + # https://arxiv.org/abs/1911.02150 + + q = rearrange(q, "b n (h d) -> b h n d", h=h) + + # rotary embeddings + + positions = self.get_rotary_embedding(n, device) + q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k)) + + # scale + + q = q * self.scale + + # similarity + + sim = einsum("b h i d, b j d -> b h i j", q, k) + + + # extra attention mask - for masking out attention from text CLS token to padding + + if exists(attn_mask): + attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j') + sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max) + + # attention + + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + # aggregate values + + out = einsum("b h i j, b j d -> b h i d", attn, v) + + # merge heads + + out = rearrange(out, "b h n d -> b n (h d)") + return self.attn_out(out) + self.ff_out(ff) + +# cross attention - using multi-query + one-headed key / values as in PaLM w/ optional parallel feedforward + +class CrossAttention(nn.Module): + def __init__( + self, + dim, + *, + context_dim=None, + dim_head=64, + heads=12, + parallel_ff=False, + ff_mult=4, + norm_context=False + ): + super().__init__() + self.heads = heads + self.scale = dim_head ** -0.5 + inner_dim = heads * dim_head + context_dim = default(context_dim, dim) + + self.norm = LayerNorm(dim) + self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity() + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + # whether to have parallel feedforward + + ff_inner_dim = ff_mult * dim + + self.ff = nn.Sequential( + nn.Linear(dim, ff_inner_dim * 2, bias=False), + SwiGLU(), + nn.Linear(ff_inner_dim, dim, bias=False) + ) if parallel_ff else None + + def forward(self, x, context, mask): + 
""" + einstein notation + b - batch + h - heads + n, i, j - sequence length (base sequence length, source, target) + d - feature dimension + """ + + # pre-layernorm, for queries and context + + x = self.norm(x) + context = self.context_norm(context) + + # get queries + + q = self.to_q(x) + q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads) + + # scale + + q = q * self.scale + + # get key / values + + k, v = self.to_kv(context).chunk(2, dim=-1) + + # query / key similarity + + sim = einsum('b h i d, b j d -> b h i j', q, k) + + # attention + mask = mask.unsqueeze(1).repeat(1,self.heads,1,1) + sim = sim + mask # context mask + sim = sim - sim.amax(dim=-1, keepdim=True) + attn = sim.softmax(dim=-1) + + # aggregate + + out = einsum('b h i j, b j d -> b h i d', attn, v) + + # merge and combine heads + + out = rearrange(out, 'b h n d -> b n (h d)') + out = self.to_out(out) + + # add parallel feedforward (for multimodal layers) + + if exists(self.ff): + out = out + self.ff(x) + + return out + + +class Cross_model(nn.Module): + def __init__( + self, + dim=512, + layer_num=4, + dim_head=64, + heads=8, + ff_mult=4 + ): + super().__init__() + + self.layers = nn.ModuleList([]) + + + for ind in range(layer_num): + self.layers.append(nn.ModuleList([ + Residual(CrossAttention(dim=dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult)), + Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult)) + ])) + + def forward( + self, + query_tokens, + context_tokens, + mask + ): + + for cross_attn, self_attn_ff in self.layers: + query_tokens = cross_attn(query_tokens, context_tokens,mask) + query_tokens = self_attn_ff(query_tokens) + + return query_tokens diff --git a/diffsynth/extensions/QualityMetric/__init__.py b/diffsynth/extensions/QualityMetric/__init__.py deleted file mode 100644 index fe97f47..0000000 --- a/diffsynth/extensions/QualityMetric/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .aesthetic import * -from .clip 
import * -from .config import * -from .hps import * -from .imagereward import * -from .mps import * -from .pickscore import * \ No newline at end of file diff --git a/examples/QualityMetric/README.md b/examples/QualityMetric/README.md deleted file mode 100644 index 293e921..0000000 --- a/examples/QualityMetric/README.md +++ /dev/null @@ -1,49 +0,0 @@ -# Image Quality Metric - -The image quality assessment functionality has now been integrated into Diffsynth. - -## Usage - -### Step 1: Download pretrained reward models - -``` -modelscope download --model 'DiffSynth-Studio/QualityMetric_reward_pretrained' -``` - -The file directory is shown below. - -``` -DiffSynth-Studio/ -└── models/ - └── QualityMetric/ - ├── HPS_v2/ - │ ├── HPS_v2_compressed.safetensors - │ ├── HPS_v2.1_compressed.safetensors - └── ... -``` - -### Step 2: Test image quality metric - -Prompt: "a painting of an ocean with clouds and birds, day time, low depth field effect" - -|1.webp|2.webp|3.webp|4.webp| -|-|-|-|-| -|![0](images/1.webp)|![1](images/2.webp)|![2](images/3.webp)|![3](images/4.webp)| - - - -``` -CUDA_VISIBLE_DEVICES=0 python testreward.py -``` - -### Output: - -``` -ImageReward: [0.5811904668807983, 0.2745198607444763, -1.4158903360366821, -2.032487154006958] -Aesthetic [5.900862693786621, 5.776571273803711, 5.799864292144775, 5.05204963684082] -PickScore: [0.20737126469612122, 0.20443597435951233, 0.20660750567913055, 0.19426065683364868] -CLIPScore: [0.3894640803337097, 0.3544551134109497, 0.33861416578292847, 0.32878392934799194] -HPScorev2: [0.2672519087791443, 0.25495243072509766, 0.24888549745082855, 0.24302822351455688] -HPScorev21: [0.2321144938468933, 0.20233657956123352, 0.1978294551372528, 0.19230154156684875] -MPS_score: [10.921875, 10.71875, 10.578125, 9.25] -``` diff --git a/examples/QualityMetric/images/1.webp b/examples/QualityMetric/images/1.webp deleted file mode 100644 index dc4eb5b..0000000 Binary files a/examples/QualityMetric/images/1.webp and /dev/null differ 
diff --git a/examples/QualityMetric/images/2.webp b/examples/QualityMetric/images/2.webp deleted file mode 100644 index ce63393..0000000 Binary files a/examples/QualityMetric/images/2.webp and /dev/null differ diff --git a/examples/QualityMetric/images/3.webp b/examples/QualityMetric/images/3.webp deleted file mode 100644 index eb2c966..0000000 Binary files a/examples/QualityMetric/images/3.webp and /dev/null differ diff --git a/examples/QualityMetric/images/4.webp b/examples/QualityMetric/images/4.webp deleted file mode 100644 index 73bda58..0000000 Binary files a/examples/QualityMetric/images/4.webp and /dev/null differ diff --git a/examples/QualityMetric/testreward.py b/examples/QualityMetric/testreward.py deleted file mode 100644 index 23c84f4..0000000 --- a/examples/QualityMetric/testreward.py +++ /dev/null @@ -1,80 +0,0 @@ -import os -import torch -from PIL import Image -from diffsynth.extensions.QualityMetric.imagereward import ImageRewardScore -from diffsynth.extensions.QualityMetric.pickscore import PickScore -from diffsynth.extensions.QualityMetric.aesthetic import AestheticScore -from diffsynth.extensions.QualityMetric.clip import CLIPScore -from diffsynth.extensions.QualityMetric.hps import HPScore_v2 -from diffsynth.extensions.QualityMetric.mps import MPScore - -# download model from modelscope -from modelscope.hub.snapshot_download import snapshot_download - -current_dir = os.path.dirname(os.path.abspath(__file__)) -project_root = os.path.abspath(os.path.join(current_dir, '../../')) -model_folder = os.path.join(project_root, 'models', 'QualityMetric') - -# download HPS_v2 to your folder -# model_id = "DiffSynth-Studio/QualityMetric_reward_pretrained" -# downloaded_path = snapshot_download( -# model_id, -# cache_dir=os.path.join(model_folder, 'HPS_v2'), -# allow_patterns=["HPS_v2/*"], -# ) - -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - -def get_model_path(model_folder, model_name): - return os.path.join(model_folder, 
model_name) - -# your model path -model_path = { - "aesthetic_predictor": get_model_path(model_folder, "aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors"), - "open_clip": get_model_path(model_folder, "CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin"), - "hpsv2": get_model_path(model_folder, "HPS_v2/HPS_v2_compressed.safetensors"), - "hpsv2.1": get_model_path(model_folder, "HPS_v2/HPS_v2.1_compressed.safetensors"), - "imagereward": get_model_path(model_folder, "ImageReward/ImageReward.safetensors"), - "med_config": get_model_path(model_folder, "ImageReward/med_config.json"), - "clip": get_model_path(model_folder, "CLIP-ViT-H-14-laion2B-s32B-b79K"), - "clip-large": get_model_path(model_folder, "clip-vit-large-patch14"), - "mps": get_model_path(model_folder, "MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors"), - "pickscore": get_model_path(model_folder, "PickScore_v1") -} - -# load reward models -mps_score = MPScore(device,path = model_path) -image_reward = ImageRewardScore(device, path = model_path) -aesthetic_score = AestheticScore(device, path = model_path) -pick_score = PickScore(device, path = model_path) -clip_score = CLIPScore(device, path = model_path) -hps_score = HPScore_v2(device, path = model_path, model_version = 'v2') -hps2_score = HPScore_v2(device, path = model_path, model_version = 'v21') - -prompt = "a painting of an ocean with clouds and birds, day time, low depth field effect" -img_prefix = "images" -generations = [f"{pic_id}.webp" for pic_id in range(1, 5)] - -img_list = [Image.open(os.path.join(img_prefix, img)) for img in generations] -#img_list = [os.path.join(img_prefix, img) for img in generations] - -imre_scores = image_reward.score(img_list, prompt) -print("ImageReward:", imre_scores) - -aes_scores = aesthetic_score.score(img_list) -print("Aesthetic", aes_scores) - -p_scores = pick_score.score(img_list, prompt) -print("PickScore:", p_scores) - -c_scores = clip_score.score(img_list, prompt) 
-print("CLIPScore:", c_scores) - -h_scores = hps_score.score(img_list,prompt) -print("HPScorev2:", h_scores) - -h2_scores = hps2_score.score(img_list,prompt) -print("HPScorev21:", h2_scores) - -m_scores = mps_score.score(img_list, prompt) -print("MPS_score:", m_scores) \ No newline at end of file diff --git a/examples/image_quality_metric/README.md b/examples/image_quality_metric/README.md new file mode 100644 index 0000000..6d5bf8b --- /dev/null +++ b/examples/image_quality_metric/README.md @@ -0,0 +1,15 @@ +# Image Quality Metric + +The image quality assessment functionality has been integrated into Diffsynth. We support the following models: + +* [ImageReward](https://github.com/THUDM/ImageReward) +* [Aesthetic](https://github.com/christophschuhmann/improved-aesthetic-predictor) +* [PickScore](https://github.com/yuvalkirstain/pickscore) +* [CLIP](https://github.com/openai/CLIP) +* [HPSv2](https://github.com/tgxs002/HPSv2) +* [HPSv2.1](https://github.com/tgxs002/HPSv2) +* [MPS](https://github.com/Kwai-Kolors/MPS) + +## Usage + +See [`./image_quality_evaluation.py`](./image_quality_evaluation.py) for more details. 
diff --git a/examples/image_quality_metric/image_quality_evaluation.py b/examples/image_quality_metric/image_quality_evaluation.py new file mode 100644 index 0000000..910fa2c --- /dev/null +++ b/examples/image_quality_metric/image_quality_evaluation.py @@ -0,0 +1,23 @@ +from diffsynth.extensions.ImageQualityMetric import download_preference_model, load_preference_model +from modelscope import dataset_snapshot_download +from PIL import Image + + +# Download example image +dataset_snapshot_download( +    dataset_id="DiffSynth-Studio/examples_in_diffsynth", +    allow_file_pattern="data/examples/ImageQualityMetric/image.jpg", +    local_dir="./" +) + +# Parameters +prompt = "an orange cat" +image = Image.open("data/examples/ImageQualityMetric/image.jpg") +device = "cuda" +cache_dir = "./models" + +# Run preference models +for model_name in ["ImageReward", "Aesthetic", "PickScore", "CLIP", "HPSv2", "HPSv2.1", "MPS"]: +    path = download_preference_model(model_name, cache_dir=cache_dir) +    preference_model = load_preference_model(model_name, device=device, path=path) +    print(model_name, preference_model.score(image, prompt)) diff --git a/models/QualityMetric/Put pretrained reward checkpoints here.txt b/models/QualityMetric/Put pretrained reward checkpoints here.txt deleted file mode 100644 index 8b13789..0000000 --- a/models/QualityMetric/Put pretrained reward checkpoints here.txt +++ /dev/null @@ -1 +0,0 @@ -