Compare commits

..

55 Commits

Author SHA1 Message Date
Zhongjie Duan
a66203a391 Update setup.py 2025-03-04 10:08:16 +08:00
Zhongjie Duan
fab61f614b Merge pull request #394 from modelscope/wan-train-update
fix swanlab after test
2025-03-03 19:00:48 +08:00
Artiprocher
6b67a11ad6 fix swanlab after test 2025-03-03 18:59:34 +08:00
Zhongjie Duan
91f77d268c Merge pull request #393 from modelscope/wan-train-update
support resume training
2025-03-03 18:45:17 +08:00
Artiprocher
eb4d5187d8 support resume training 2025-03-03 18:31:31 +08:00
Zhongjie Duan
ee4b02247c Merge pull request #392 from modelscope/sage_attention
Sage attention
2025-03-03 14:28:36 +08:00
Artiprocher
da8e1fe7e4 support sage attention 2025-03-03 14:19:16 +08:00
Zhongjie Duan
3db824c281 Merge pull request #390 from YunhongLu-ZJU/main
revised image quality metric
2025-03-03 13:36:34 +08:00
YunhongLu-ZJU
df2ecafd3f revised 2025-03-03 12:30:26 +08:00
Zhongjie Duan
217652d28e Merge pull request #389 from modelscope/requirements
Requirements
2025-03-03 11:25:31 +08:00
Artiprocher
f64c766dcd update install guide in README 2025-03-03 11:24:48 +08:00
Artiprocher
076fd85556 update install guide in README 2025-03-03 11:10:51 +08:00
Zhongjie Duan
c7912ed827 Merge pull request #388 from modelscope/preference_model
Preference model
2025-03-02 19:56:00 +08:00
Artiprocher
e63f9d6993 update preference models 2025-03-02 19:52:27 +08:00
Raffaele Mancuso
d80ef3a677 Sentencepiece requires cmake 2025-03-02 10:58:42 +01:00
philipy1219
852c3d831f support sageattn 2025-03-02 15:09:21 +08:00
Zhongjie Duan
ceb92ee7aa Merge pull request #378 from modelscope/wan-video-params
update wan input params
2025-02-28 19:52:20 +08:00
Artiprocher
3a75026176 update wan input params 2025-02-28 19:43:18 +08:00
Zhongjie Duan
6a92b08244 Merge pull request #375 from modelscope/swanlab-dev
del swanlab because of bad cases
2025-02-28 16:16:56 +08:00
Zhongjie Duan
38bc785ea9 Merge branch 'main' into swanlab-dev 2025-02-28 16:16:15 +08:00
Artiprocher
a466fdca8f del swanlab 2025-02-28 16:13:06 +08:00
Zhongjie Duan
f9f49e3c78 Merge pull request #374 from modelscope/wan-tokenizer-bugfix
align wan tokenizer to official
2025-02-28 16:05:36 +08:00
Artiprocher
61a30673c2 align wan tokenizer to official 2025-02-28 15:50:07 +08:00
Yingda Chen
a48822ec00 Merge pull request #372 from Zeyi-Lin/main
fix: text-to-image swanlab_logger
2025-02-28 14:38:36 +08:00
ZeYi Lin
b6c3d2b74a fix: logger 2025-02-28 12:51:58 +08:00
Zhongjie Duan
5006c2176c Merge pull request #371 from modelscope/wan-video-readme
Update README.md
2025-02-28 10:10:03 +08:00
Zhongjie Duan
d3d3556ff6 Update README.md 2025-02-28 10:09:48 +08:00
Zhongjie Duan
6fa8dbe077 Merge pull request #366 from modelscope/swanlab
Swanlab
2025-02-27 19:32:23 +08:00
Artiprocher
a57749ef60 update swanlab log 2025-02-27 19:30:53 +08:00
Artiprocher
b5c1d33e58 update swanlab log 2025-02-27 19:21:51 +08:00
Zhongjie Duan
34a9f82865 Merge pull request #365 from modelscope/wan-train-dev
update wanx lora examples
2025-02-27 19:07:10 +08:00
Artiprocher
18dc6cb962 update wanx lora examples 2025-02-27 19:06:24 +08:00
Zhongjie Duan
c760208614 Merge pull request #360 from modelscope/wan-train-dev
support wan image training
2025-02-27 12:58:32 +08:00
Artiprocher
fad7aea58a support wan image training 2025-02-27 12:56:55 +08:00
Zhongjie Duan
b42eb1444c Merge pull request #357 from modelscope/bugfix
bugfix
2025-02-27 11:06:24 +08:00
Zhongjie Duan
25a247dd3f bugfix 2025-02-27 11:06:10 +08:00
Zhongjie Duan
7792017a02 Update README.md 2025-02-27 10:52:47 +08:00
Zhongjie Duan
0219e8d2f3 Update README.md 2025-02-26 22:53:07 +08:00
Zhongjie Duan
1d309a14a3 Merge pull request #352 from modelscope/bugfix
Fix Wan VAE device
2025-02-26 20:03:53 +08:00
Zhongjie Duan
7df73ceaaf Fix Wan VAE device 2025-02-26 20:03:26 +08:00
ZeYi Lin
1419bec53d feat: add swanlab logger 2025-02-26 17:12:54 +08:00
Zhongjie Duan
cf12723c89 Merge pull request #347 from co63oc/fix1
Fix typos
2025-02-26 15:50:36 +08:00
co63oc
4268f5466b Fix 2025-02-26 14:18:36 +08:00
Zhongjie Duan
b9f5a00d98 Merge pull request #345 from ghunkins/dev/ghunkins/allow-for-py39
🐍 Remove Python 3.10 Type Hint
2025-02-26 11:42:19 +08:00
Zhongjie Duan
7d44dc99fb support wan full training
support wan full train
2025-02-26 11:38:51 +08:00
Artiprocher
b20de1b44d support wan full train 2025-02-26 11:34:04 +08:00
Gregory D. Hunkins
366ee0f542 remove py310 type hint 2025-02-25 22:29:53 -05:00
Artiprocher
bed770248b update examples 2025-02-26 10:25:36 +08:00
Kohaku-Blueleaf
020560d2b5 Fix num_frames in i2v (#339)
* Fix num_frames in i2v

* Remove print in flash_attention
2025-02-26 10:05:51 +08:00
Zhongjie Duan
af7d305f00 Wan video (#338) 2025-02-25 19:00:43 +08:00
YunhongLu-ZJU
4449faaa01 Merge branch 'modelscope:main' into main 2025-02-17 14:45:13 +08:00
YunhongLu-ZJU
991ba162bd add new quality metric 2025-02-17 14:42:20 +08:00
YunhongLu-ZJU
77d0f4d297 add image quality metric 2025-02-14 14:02:17 +08:00
YunhongLu-ZJU
a834371d50 add quality metric 2025-02-14 13:59:56 +08:00
hongzhang.hz
acda7d891a image quality metric 2025-02-14 12:39:06 +08:00
72 changed files with 9984 additions and 118 deletions

View File

@@ -17,6 +17,7 @@ DiffSynth Studio is a Diffusion engine. We have restructured architectures inclu
Until now, DiffSynth Studio has supported the following models:
* [Wan-Video](https://github.com/Wan-Video/Wan2.1)
* [StepVideo](https://github.com/stepfun-ai/Step-Video-T2V)
* [HunyuanVideo](https://github.com/Tencent/HunyuanVideo)
* [CogVideoX](https://huggingface.co/THUDM/CogVideoX-5b)
@@ -36,7 +37,9 @@ Until now, DiffSynth Studio has supported the following models:
## News
- **February 17, 2024** We support [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary)! State-of-the-art video synthesis model! See [./examples/stepvideo](./examples/stepvideo/).
- **February 25, 2025** We support Wan-Video, a collection of SOTA video synthesis models open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).
- **February 17, 2025** We support [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary)! State-of-the-art video synthesis model! See [./examples/stepvideo](./examples/stepvideo/).
- **December 31, 2024** We propose EliGen, a novel framework for precise entity-level controlled text-to-image generation, complemented by an inpainting fusion pipeline to extend its capabilities to image inpainting tasks. EliGen seamlessly integrates with existing community models, such as IP-Adapter and In-Context LoRA, enhancing its versatility. For more details, see [./examples/EntityControl](./examples/EntityControl/).
- Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
@@ -118,12 +121,19 @@ cd DiffSynth-Studio
pip install -e .
```
Or install from pypi:
Or install from pypi (There is a delay in the update. If you want to experience the latest features, please do not use this installation method.):
```
pip install diffsynth
```
If you encounter issues during installation, it may be caused by the packages we depend on. Please refer to the documentation of the package that caused the problem.
* [torch](https://pytorch.org/get-started/locally/)
* [sentencepiece](https://github.com/google/sentencepiece)
* [cmake](https://cmake.org)
* [cupy](https://docs.cupy.dev/en/stable/install.html)
## Usage (in Python code)
The Python examples are in [`examples`](./examples/). We provide an overview here.

View File

@@ -1,7 +1,7 @@
# Set web page format
import streamlit as st
st.set_page_config(layout="wide")
# Diasble virtual VRAM on windows system
# Disable virtual VRAM on windows system
import torch
torch.cuda.set_per_process_memory_fraction(0.999, 0)

View File

@@ -54,7 +54,11 @@ from ..models.hunyuan_video_dit import HunyuanVideoDiT
from ..models.stepvideo_vae import StepVideoVAE
from ..models.stepvideo_dit import StepVideoModel
from ..models.wanx_vae import WanXVideoVAE
from ..models.wan_video_dit import WanModel
from ..models.wan_video_text_encoder import WanTextEncoder
from ..models.wan_video_image_encoder import WanImageEncoder
from ..models.wan_video_vae import WanVideoVAE
model_loader_configs = [
# These configs are provided for detecting model type automatically.
@@ -109,7 +113,13 @@ model_loader_configs = [
(None, "84ef4bd4757f60e906b54aa6a7815dc6", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
(None, "68beaf8429b7c11aa8ca05b1bd0058bd", ["stepvideo_vae"], [StepVideoVAE], "civitai"),
(None, "5c0216a2132b082c10cb7a0e0377e681", ["stepvideo_dit"], [StepVideoModel], "civitai"),
(None, "1378ea763357eea97acdef78e65d6d96", ["wanxvideo_vae"], [WanXVideoVAE], "civitai")
(None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
(None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
(None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
(None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
]
huggingface_model_loader_configs = [
# These configs are provided for detecting model type automatically.

View File

@@ -0,0 +1 @@
from .blip_pretrain import *

View File

@@ -0,0 +1,77 @@
'''
* Adapted from BLIP (https://github.com/salesforce/BLIP)
'''
import warnings
warnings.filterwarnings("ignore")
import torch
import os
from urllib.parse import urlparse
from timm.models.hub import download_cached_file
from transformers import BertTokenizer
from .vit import VisionTransformer, interpolate_pos_embed
def default_bert():
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, '../../../../'))
model_path = os.path.join(project_root, 'models', 'QualityMetric')
return os.path.join(model_path, "bert-base-uncased")
def init_tokenizer(bert_model_path):
tokenizer = BertTokenizer.from_pretrained(bert_model_path)
tokenizer.add_special_tokens({'bos_token':'[DEC]'})
tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
return tokenizer
def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
assert vit in ['base', 'large'], "vit parameter must be base or large"
if vit=='base':
vision_width = 768
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
drop_path_rate=0 or drop_path_rate
)
elif vit=='large':
vision_width = 1024
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
drop_path_rate=0.1 or drop_path_rate
)
return visual_encoder, vision_width
def is_url(url_or_filename):
parsed = urlparse(url_or_filename)
return parsed.scheme in ("http", "https")
def load_checkpoint(model,url_or_filename):
if is_url(url_or_filename):
cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
checkpoint = torch.load(cached_file, map_location='cpu')
elif os.path.isfile(url_or_filename):
checkpoint = torch.load(url_or_filename, map_location='cpu')
else:
raise RuntimeError('checkpoint url or path is invalid')
state_dict = checkpoint['model']
state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
model.visual_encoder_m)
for key in model.state_dict().keys():
if key in state_dict.keys():
if state_dict[key].shape!=model.state_dict()[key].shape:
print(key, ": ", state_dict[key].shape, ', ', model.state_dict()[key].shape)
del state_dict[key]
msg = model.load_state_dict(state_dict,strict=False)
print('load checkpoint from %s'%url_or_filename)
return model,msg

View File

@@ -0,0 +1,44 @@
'''
* Adapted from BLIP (https://github.com/salesforce/BLIP)
'''
import transformers
transformers.logging.set_verbosity_error()
from torch import nn
import os
from .med import BertConfig, BertModel
from .blip import create_vit, init_tokenizer
class BLIP_Pretrain(nn.Module):
def __init__(self,
med_config = "med_config.json",
image_size = 224,
vit = 'base',
vit_grad_ckpt = False,
vit_ckpt_layer = 0,
embed_dim = 256,
queue_size = 57600,
momentum = 0.995,
bert_model_path = ""
):
"""
Args:
med_config (str): path for the mixture of encoder-decoder model's configuration file
image_size (int): input image size
vit (str): model size of vision transformer
"""
super().__init__()
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0)
self.tokenizer = init_tokenizer(bert_model_path)
encoder_config = BertConfig.from_json_file(med_config)
encoder_config.encoder_width = vision_width
self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)
text_width = self.text_encoder.config.hidden_size
self.vision_proj = nn.Linear(vision_width, embed_dim)
self.text_proj = nn.Linear(text_width, embed_dim)

View File

@@ -0,0 +1,947 @@
'''
* Adapted from BLIP (https://github.com/salesforce/BLIP)
* Based on huggingface code base
* https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
'''
import math
from typing import Tuple
import torch
from torch import Tensor, device, nn
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.activations import ACT2FN
from transformers.file_utils import (
ModelOutput,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
MaskedLMOutput,
MultipleChoiceModelOutput,
NextSentencePredictorOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from transformers.modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from transformers.utils import logging
from transformers.models.bert.configuration_bert import BertConfig
logger = logging.get_logger(__name__)
class BertEmbeddings(nn.Module):
"""Construct the embeddings from word and position embeddings."""
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.config = config
def forward(
self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
):
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1]
if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
embeddings = inputs_embeds
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class BertSelfAttention(nn.Module):
def __init__(self, config, is_cross_attention):
super().__init__()
self.config = config
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
if is_cross_attention:
self.key = nn.Linear(config.encoder_width, self.all_head_size)
self.value = nn.Linear(config.encoder_width, self.all_head_size)
else:
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
self.save_attention = False
def save_attn_gradients(self, attn_gradients):
self.attn_gradients = attn_gradients
def get_attn_gradients(self):
return self.attn_gradients
def save_attention_map(self, attention_map):
self.attention_map = attention_map
def get_attention_map(self):
return self.attention_map
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
):
mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
past_key_value = (key_layer, value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
seq_length = hidden_states.size()[1]
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.Softmax(dim=-1)(attention_scores)
if is_cross_attention and self.save_attention:
self.save_attention_map(attention_probs)
attention_probs.register_hook(self.save_attn_gradients)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs_dropped = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs_dropped = attention_probs_dropped * head_mask
context_layer = torch.matmul(attention_probs_dropped, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
outputs = outputs + (past_key_value,)
return outputs
class BertSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertAttention(nn.Module):
def __init__(self, config, is_cross_attention=False):
super().__init__()
self.self = BertSelfAttention(config, is_cross_attention)
self.output = BertSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
)
# Prune linear layers
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update hyper params and store pruned heads
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
):
self_outputs = self.self(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
class BertIntermediate(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class BertOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertLayer(nn.Module):
def __init__(self, config, layer_num):
super().__init__()
self.config = config
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = BertAttention(config)
self.layer_num = layer_num
if self.config.add_cross_attention:
self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
mode=None,
):
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
if mode=='multimodal':
assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
output_attentions=output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
outputs = (layer_output,) + outputs
outputs = outputs + (present_key_value,)
return outputs
def feed_forward_chunk(self, attention_output):
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
class BertEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
mode='multimodal',
):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
next_decoder_cache = () if use_cache else None
for i in range(self.config.num_hidden_layers):
layer_module = self.layer[i]
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warn(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
mode=mode,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
mode=mode,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
class BertPooler(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class BertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
if isinstance(config.hidden_act, str):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class BertLMPredictionHead(nn.Module):
def __init__(self, config):
super().__init__()
self.transform = BertPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states
class BertOnlyMLMHead(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = BertLMPredictionHead(config)
def forward(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
return prediction_scores
class BertPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = BertConfig
base_model_prefix = "bert"
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
class BertModel(BertPreTrainedModel):
"""
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
input to the forward pass.
"""
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
self.embeddings = BertEmbeddings(config)
self.encoder = BertEncoder(config)
self.pooler = BertPooler(config) if add_pooling_layer else None
self.init_weights()
def get_input_embeddings(self):
return self.embeddings.word_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
"""
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
Arguments:
attention_mask (:obj:`torch.Tensor`):
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
input_shape (:obj:`Tuple[int]`):
The shape of the input to the model.
device: (:obj:`torch.device`):
The device of the input to the model.
Returns:
:obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
"""
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
if attention_mask.dim() == 3:
extended_attention_mask = attention_mask[:, None, :, :]
elif attention_mask.dim() == 2:
# Provided a padding mask of dimensions [batch_size, seq_length]
# - if the model is a decoder, apply a causal mask in addition to the padding mask
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
if is_decoder:
batch_size, seq_length = input_shape
seq_ids = torch.arange(seq_length, device=device)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
# causal and attention masks must have same type with pytorch version < 1.3
causal_mask = causal_mask.to(attention_mask.dtype)
if causal_mask.shape[1] < attention_mask.shape[1]:
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
causal_mask = torch.cat(
[
torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
causal_mask,
],
axis=-1,
)
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
else:
extended_attention_mask = attention_mask[:, None, None, :]
else:
raise ValueError(
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
input_shape, attention_mask.shape
)
)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
return extended_attention_mask
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
is_decoder=False,
mode='multimodal',
):
r"""
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
else:
use_cache = False
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
batch_size, seq_length = input_shape
device = input_ids.device
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size, seq_length = input_shape
device = inputs_embeds.device
elif encoder_embeds is not None:
input_shape = encoder_embeds.size()[:-1]
batch_size, seq_length = input_shape
device = encoder_embeds.device
else:
raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
device, is_decoder)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if encoder_hidden_states is not None:
if type(encoder_hidden_states) == list:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
else:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if type(encoder_attention_mask) == list:
encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
elif encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
if encoder_embeds is None:
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
else:
embedding_output = encoder_embeds
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
mode=mode,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)
class BertLMHeadModel(BertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
self.bert = BertModel(config, add_pooling_layer=False)
self.cls = BertOnlyMLMHead(config)
self.init_weights()
def get_output_embeddings(self):
return self.cls.predictions.decoder
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
return_logits=False,
is_decoder=True,
reduction='mean',
mode='multimodal',
):
r"""
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
Returns:
Example::
>>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
>>> import torch
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
>>> config = BertConfig.from_pretrained("bert-base-cased")
>>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> prediction_logits = outputs.logits
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None:
use_cache = False
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
is_decoder=is_decoder,
mode=mode,
)
sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)
if return_logits:
return prediction_scores[:, :-1, :].contiguous()
lm_loss = None
if labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
if reduction=='none':
lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
if not return_dict:
output = (prediction_scores,) + outputs[2:]
return ((lm_loss,) + output) if lm_loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=lm_loss,
logits=prediction_scores,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
# cut decoder_input_ids if past is used
if past is not None:
input_ids = input_ids[:, -1:]
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past,
"encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
"encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
"is_decoder": True,
}
def _reorder_cache(self, past, beam_idx):
reordered_past = ()
for layer_past in past:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past

View File

@@ -0,0 +1,301 @@
'''
* Adapted from BLIP (https://github.com/salesforce/BLIP)
* Based on timm code base
* https://github.com/rwightman/pytorch-image-models/tree/master/timm
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from timm.models.vision_transformer import _cfg, PatchEmbed
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_, DropPath
from timm.models.helpers import named_apply, adapt_input_conv
# from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
class Mlp(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.attn_gradients = None
self.attention_map = None
def save_attn_gradients(self, attn_gradients):
self.attn_gradients = attn_gradients
def get_attn_gradients(self):
return self.attn_gradients
def save_attention_map(self, attention_map):
self.attention_map = attention_map
def get_attention_map(self):
return self.attention_map
def forward(self, x, register_hook=False):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
if register_hook:
self.save_attention_map(attn)
attn.register_hook(self.save_attn_gradients)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
# if use_grad_checkpointing:
# self.attn = checkpoint_wrapper(self.attn)
# self.mlp = checkpoint_wrapper(self.mlp)
def forward(self, x, register_hook=False):
x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class VisionTransformer(nn.Module):
""" Vision Transformer
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
https://arxiv.org/abs/2010.11929
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
use_grad_checkpointing=False, ckpt_layer=0):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_chans (int): number of input channels
num_classes (int): number of classes for classification head
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
qk_scale (float): override default qk scale of head_dim ** -0.5 if set
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
drop_rate (float): dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
norm_layer: (nn.Module): normalization layer
"""
super().__init__()
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
)
for i in range(depth)])
self.norm = norm_layer(embed_dim)
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def forward(self, x, register_blk=-1):
B = x.shape[0]
x = self.patch_embed(x)
cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed[:,:x.size(1),:]
x = self.pos_drop(x)
for i,blk in enumerate(self.blocks):
x = blk(x, register_blk==i)
x = self.norm(x)
return x
@torch.jit.ignore()
def load_pretrained(self, checkpoint_path, prefix=''):
_load_weights(self, checkpoint_path, prefix)
@torch.no_grad()
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
""" Load weights from .npz checkpoints for official Google Brain Flax implementation
"""
import numpy as np
def _n2p(w, t=True):
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
w = w.flatten()
if t:
if w.ndim == 4:
w = w.transpose([3, 2, 0, 1])
elif w.ndim == 3:
w = w.transpose([2, 0, 1])
elif w.ndim == 2:
w = w.transpose([1, 0])
return torch.from_numpy(w)
w = np.load(checkpoint_path)
if not prefix and 'opt/target/embedding/kernel' in w:
prefix = 'opt/target/'
if hasattr(model.patch_embed, 'backbone'):
# hybrid
backbone = model.patch_embed.backbone
stem_only = not hasattr(backbone, 'stem')
stem = backbone if stem_only else backbone.stem
stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
if not stem_only:
for i, stage in enumerate(backbone.stages):
for j, block in enumerate(stage.blocks):
bp = f'{prefix}block{i + 1}/unit{j + 1}/'
for r in range(3):
getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
if block.downsample is not None:
block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
else:
embed_conv_w = adapt_input_conv(
model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
model.patch_embed.proj.weight.copy_(embed_conv_w)
model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
if pos_embed_w.shape != model.pos_embed.shape:
pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
model.pos_embed.copy_(pos_embed_w)
model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
# if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
# model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
# model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
for i, block in enumerate(model.blocks.children()):
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
block.attn.qkv.weight.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
block.attn.qkv.bias.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
for r in range(2):
getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
# interpolate position embedding
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = visual_encoder.patch_embed.num_patches
num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
new_size = int(num_patches ** 0.5)
if orig_size!=new_size:
# class_token and dist_token are kept unchanged
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
return new_pos_embed
else:
return pos_embed_checkpoint

View File

@@ -0,0 +1,148 @@
from modelscope import snapshot_download
from typing_extensions import Literal, TypeAlias
import os
from diffsynth.extensions.ImageQualityMetric.aesthetic import AestheticScore
from diffsynth.extensions.ImageQualityMetric.imagereward import ImageRewardScore
from diffsynth.extensions.ImageQualityMetric.pickscore import PickScore
from diffsynth.extensions.ImageQualityMetric.clip import CLIPScore
from diffsynth.extensions.ImageQualityMetric.hps import HPScore_v2
from diffsynth.extensions.ImageQualityMetric.mps import MPScore
preference_model_id: TypeAlias = Literal[
"ImageReward",
"Aesthetic",
"PickScore",
"CLIP",
"HPSv2",
"HPSv2.1",
"MPS",
]
model_dict = {
"ImageReward": {
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
"allow_file_pattern": [
"ImageReward/ImageReward.safetensors",
"ImageReward/med_config.json",
"bert-base-uncased/config.json",
"bert-base-uncased/model.safetensors",
"bert-base-uncased/tokenizer.json",
"bert-base-uncased/tokenizer_config.json",
"bert-base-uncased/vocab.txt",
],
"load_path": {
"imagereward": "ImageReward/ImageReward.safetensors",
"med_config": "ImageReward/med_config.json",
"bert_model_path": "bert-base-uncased",
},
"model_class": ImageRewardScore
},
"Aesthetic": {
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
"allow_file_pattern": [
"aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors",
"clip-vit-large-patch14/config.json",
"clip-vit-large-patch14/merges.txt",
"clip-vit-large-patch14/model.safetensors",
"clip-vit-large-patch14/preprocessor_config.json",
"clip-vit-large-patch14/special_tokens_map.json",
"clip-vit-large-patch14/tokenizer.json",
"clip-vit-large-patch14/tokenizer_config.json",
"clip-vit-large-patch14/vocab.json",
],
"load_path": {
"aesthetic_predictor": "aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors",
"clip-large": "clip-vit-large-patch14",
},
"model_class": AestheticScore
},
"PickScore": {
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
"allow_file_pattern": [
"PickScore_v1/*",
"CLIP-ViT-H-14-laion2B-s32B-b79K/config.json",
"CLIP-ViT-H-14-laion2B-s32B-b79K/merges.txt",
"CLIP-ViT-H-14-laion2B-s32B-b79K/preprocessor_config.json",
"CLIP-ViT-H-14-laion2B-s32B-b79K/special_tokens_map.json",
"CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer.json",
"CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer_config.json",
"CLIP-ViT-H-14-laion2B-s32B-b79K/vocab.json",
],
"load_path": {
"pickscore": "PickScore_v1",
"clip": "CLIP-ViT-H-14-laion2B-s32B-b79K",
},
"model_class": PickScore
},
"CLIP": {
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
"allow_file_pattern": [
"CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin",
"bpe_simple_vocab_16e6.txt.gz",
],
"load_path": {
"open_clip": "CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin",
"open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
},
"model_class": CLIPScore
},
"HPSv2": {
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
"allow_file_pattern": [
"HPS_v2/HPS_v2_compressed.safetensors",
"bpe_simple_vocab_16e6.txt.gz",
],
"load_path": {
"hpsv2": "HPS_v2/HPS_v2_compressed.safetensors",
"open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
},
"model_class": HPScore_v2,
"extra_kwargs": {"model_version": "v2"}
},
"HPSv2.1": {
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
"allow_file_pattern": [
"HPS_v2/HPS_v2.1_compressed.safetensors",
"bpe_simple_vocab_16e6.txt.gz",
],
"load_path": {
"hpsv2.1": "HPS_v2/HPS_v2.1_compressed.safetensors",
"open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
},
"model_class": HPScore_v2,
"extra_kwargs": {"model_version": "v21"}
},
"MPS": {
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
"allow_file_pattern": [
"MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors",
"CLIP-ViT-H-14-laion2B-s32B-b79K/config.json",
"CLIP-ViT-H-14-laion2B-s32B-b79K/merges.txt",
"CLIP-ViT-H-14-laion2B-s32B-b79K/preprocessor_config.json",
"CLIP-ViT-H-14-laion2B-s32B-b79K/special_tokens_map.json",
"CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer.json",
"CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer_config.json",
"CLIP-ViT-H-14-laion2B-s32B-b79K/vocab.json",
],
"load_path": {
"mps": "MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors",
"clip": "CLIP-ViT-H-14-laion2B-s32B-b79K",
},
"model_class": MPScore
},
}
def download_preference_model(model_name: preference_model_id, cache_dir="models"):
metadata = model_dict[model_name]
snapshot_download(model_id=metadata["model_id"], allow_file_pattern=metadata["allow_file_pattern"], cache_dir=cache_dir)
load_path = metadata["load_path"]
load_path = {key: os.path.join(cache_dir, metadata["model_id"], path) for key, path in load_path.items()}
return load_path
def load_preference_model(model_name: preference_model_id, device = "cuda", path = None):
model_class = model_dict[model_name]["model_class"]
extra_kwargs = model_dict[model_name].get("extra_kwargs", {})
preference_model = model_class(device=device, path=path, **extra_kwargs)
return preference_model

View File

@@ -0,0 +1,148 @@
from typing import List, Optional
from PIL import Image
import torch
from transformers import AutoProcessor, AutoModel
from safetensors.torch import load_file
import os
from typing import Union, List
from .config import MODEL_PATHS
class MLP(torch.nn.Module):
def __init__(self, input_size: int, xcol: str = "emb", ycol: str = "avg_rating"):
super().__init__()
self.input_size = input_size
self.xcol = xcol
self.ycol = ycol
self.layers = torch.nn.Sequential(
torch.nn.Linear(self.input_size, 1024),
#torch.nn.ReLU(),
torch.nn.Dropout(0.2),
torch.nn.Linear(1024, 128),
#torch.nn.ReLU(),
torch.nn.Dropout(0.2),
torch.nn.Linear(128, 64),
#torch.nn.ReLU(),
torch.nn.Dropout(0.1),
torch.nn.Linear(64, 16),
#torch.nn.ReLU(),
torch.nn.Linear(16, 1),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.layers(x)
def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
x = batch[self.xcol]
y = batch[self.ycol].reshape(-1, 1)
x_hat = self.layers(x)
loss = torch.nn.functional.mse_loss(x_hat, y)
return loss
def validation_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
x = batch[self.xcol]
y = batch[self.ycol].reshape(-1, 1)
x_hat = self.layers(x)
loss = torch.nn.functional.mse_loss(x_hat, y)
return loss
def configure_optimizers(self) -> torch.optim.Optimizer:
return torch.optim.Adam(self.parameters(), lr=1e-3)
class AestheticScore(torch.nn.Module):
def __init__(self, device: torch.device, path: str = MODEL_PATHS):
super().__init__()
self.device = device
self.aes_model_path = path.get("aesthetic_predictor")
# Load the MLP model
self.model = MLP(768)
try:
if self.aes_model_path.endswith(".safetensors"):
state_dict = load_file(self.aes_model_path)
else:
state_dict = torch.load(self.aes_model_path)
self.model.load_state_dict(state_dict)
except Exception as e:
raise ValueError(f"Error loading model weights from {self.aes_model_path}: {e}")
self.model.to(device)
self.model.eval()
# Load the CLIP model and processor
clip_model_name = path.get('clip-large')
self.model2 = AutoModel.from_pretrained(clip_model_name).eval().to(device)
self.processor = AutoProcessor.from_pretrained(clip_model_name)
def _calculate_score(self, image: torch.Tensor) -> float:
"""Calculate the aesthetic score for a single image.
Args:
image (torch.Tensor): The processed image tensor.
Returns:
float: The aesthetic score.
"""
with torch.no_grad():
# Get image embeddings
image_embs = self.model2.get_image_features(image)
image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
# Compute score
score = self.model(image_embs).cpu().flatten().item()
return score
@torch.no_grad()
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str = "") -> List[float]:
"""Score the images based on their aesthetic quality.
Args:
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
Returns:
List[float]: List of scores for the images.
"""
try:
if isinstance(images, (str, Image.Image)):
# Single image
if isinstance(images, str):
pil_image = Image.open(images)
else:
pil_image = images
# Prepare image inputs
image_inputs = self.processor(
images=pil_image,
padding=True,
truncation=True,
max_length=77,
return_tensors="pt",
).to(self.device)
return [self._calculate_score(image_inputs["pixel_values"])]
elif isinstance(images, list):
# Multiple images
scores = []
for one_image in images:
if isinstance(one_image, str):
pil_image = Image.open(one_image)
elif isinstance(one_image, Image.Image):
pil_image = one_image
else:
raise TypeError("The type of parameter images is illegal.")
# Prepare image inputs
image_inputs = self.processor(
images=pil_image,
padding=True,
truncation=True,
max_length=77,
return_tensors="pt",
).to(self.device)
scores.append(self._calculate_score(image_inputs["pixel_values"]))
return scores
else:
raise TypeError("The type of parameter images is illegal.")
except Exception as e:
raise RuntimeError(f"Error in scoring images: {e}")

View File

@@ -0,0 +1,97 @@
from typing import List, Union
from PIL import Image
import torch
from .open_clip import create_model_and_transforms, get_tokenizer
from .config import MODEL_PATHS
class CLIPScore(torch.nn.Module):
def __init__(self, device: torch.device, path: str = MODEL_PATHS):
super().__init__()
"""Initialize the CLIPScore with a model and tokenizer.
Args:
device (torch.device): The device to load the model on.
"""
self.device = device
# Create model and transforms
self.model, _, self.preprocess_val = create_model_and_transforms(
"ViT-H-14",
# "laion2B-s32B-b79K",
pretrained=path.get("open_clip"),
precision="amp",
device=device,
jit=False,
force_quick_gelu=False,
force_custom_text=False,
force_patch_dropout=False,
force_image_size=None,
pretrained_image=False,
image_mean=None,
image_std=None,
light_augmentation=True,
aug_cfg={},
output_dict=True,
with_score_predictor=False,
with_region_predictor=False,
)
# Initialize tokenizer
self.tokenizer = get_tokenizer("ViT-H-14", path["open_clip_bpe"])
self.model = self.model.to(device)
self.model.eval()
def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
"""Calculate the CLIP score for a single image and prompt.
Args:
image (torch.Tensor): The processed image tensor.
prompt (str): The prompt text.
Returns:
float: The CLIP score.
"""
with torch.no_grad():
# Process the prompt
text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
# Calculate the CLIP score
outputs = self.model(image, text)
image_features, text_features = outputs["image_features"], outputs["text_features"]
logits_per_image = image_features @ text_features.T
clip_score = torch.diagonal(logits_per_image).cpu().numpy()
return clip_score[0].item()
@torch.no_grad()
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
"""Score the images based on the prompt.
Args:
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
prompt (str): The prompt text.
Returns:
List[float]: List of CLIP scores for the images.
"""
if isinstance(images, (str, Image.Image)):
# Single image
if isinstance(images, str):
image = self.preprocess_val(Image.open(images)).unsqueeze(0).to(device=self.device, non_blocking=True)
else:
image = self.preprocess_val(images).unsqueeze(0).to(device=self.device, non_blocking=True)
return [self._calculate_score(image, prompt)]
elif isinstance(images, list):
# Multiple images
scores = []
for one_images in images:
if isinstance(one_images, str):
image = self.preprocess_val(Image.open(one_images)).unsqueeze(0).to(device=self.device, non_blocking=True)
elif isinstance(one_images, Image.Image):
image = self.preprocess_val(one_images).unsqueeze(0).to(device=self.device, non_blocking=True)
else:
raise TypeError("The type of parameter images is illegal.")
scores.append(self._calculate_score(image, prompt))
return scores
else:
raise TypeError("The type of parameter images is illegal.")

View File

@@ -0,0 +1,23 @@
import os
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, '../../../'))
model_path = os.path.join(project_root, 'models', 'QualityMetric')
def get_model_path(model_name):
return os.path.join(model_path, model_name)
MODEL_PATHS = {
"aesthetic_predictor": get_model_path("aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors"),
"open_clip": get_model_path("CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin"),
"hpsv2": get_model_path("HPS_v2/HPS_v2_compressed.safetensors"),
"hpsv2.1": get_model_path("HPS_v2/HPS_v2.1_compressed.safetensors"),
"imagereward": get_model_path("ImageReward/ImageReward.safetensors"),
"med_config": get_model_path("ImageReward/med_config.json"),
"clip": get_model_path("CLIP-ViT-H-14-laion2B-s32B-b79K"),
"clip-large": get_model_path("clip-vit-large-patch14"),
"mps": get_model_path("MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors"),
"pickscore": get_model_path("PickScore_v1")
}

View File

@@ -0,0 +1,118 @@
from typing import List, Union
from PIL import Image
import torch
from .open_clip import create_model_and_transforms, get_tokenizer
from safetensors.torch import load_file
import os
from .config import MODEL_PATHS
class HPScore_v2(torch.nn.Module):
def __init__(self, device: torch.device, path: str = MODEL_PATHS, model_version: str = "v2"):
super().__init__()
"""Initialize the Selector with a model and tokenizer.
Args:
device (torch.device): The device to load the model on.
model_version (str): The version of the model to load. Supports "v2" or "v21". Default is "v2".
"""
self.device = device
if model_version == "v2":
safetensors_path = path.get("hpsv2")
elif model_version == "v21":
safetensors_path = path.get("hpsv2.1")
else:
raise ValueError(f"Unsupported model version: {model_version}. Choose 'v2' or 'v21'.")
# Create model and transforms
model, _, self.preprocess_val = create_model_and_transforms(
"ViT-H-14",
# "laion2B-s32B-b79K",
pretrained=path.get("open_clip"),
precision="amp",
device=device,
jit=False,
force_quick_gelu=False,
force_custom_text=False,
force_patch_dropout=False,
force_image_size=None,
pretrained_image=False,
image_mean=None,
image_std=None,
light_augmentation=True,
aug_cfg={},
output_dict=True,
with_score_predictor=False,
with_region_predictor=False,
)
# Load model weights
try:
state_dict = load_file(safetensors_path)
model.load_state_dict(state_dict)
except Exception as e:
raise ValueError(f"Error loading model weights from {safetensors_path}: {e}")
# Initialize tokenizer and model
self.tokenizer = get_tokenizer("ViT-H-14", path["open_clip_bpe"])
model = model.to(device)
model.eval()
self.model = model
def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
"""Calculate the HPS score for a single image and prompt.
Args:
image (torch.Tensor): The processed image tensor.
prompt (str): The prompt text.
Returns:
float: The HPS score.
"""
with torch.no_grad():
# Process the prompt
text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
# Calculate the HPS score
outputs = self.model(image, text)
image_features, text_features = outputs["image_features"], outputs["text_features"]
logits_per_image = image_features @ text_features.T
hps_score = torch.diagonal(logits_per_image).cpu().numpy()
return hps_score[0].item()
@torch.no_grad()
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
"""Score the images based on the prompt.
Args:
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
prompt (str): The prompt text.
Returns:
List[float]: List of HPS scores for the images.
"""
try:
if isinstance(images, (str, Image.Image)):
# Single image
if isinstance(images, str):
image = self.preprocess_val(Image.open(images)).unsqueeze(0).to(device=self.device, non_blocking=True)
else:
image = self.preprocess_val(images).unsqueeze(0).to(device=self.device, non_blocking=True)
return [self._calculate_score(image, prompt)]
elif isinstance(images, list):
# Multiple images
scores = []
for one_images in images:
if isinstance(one_images, str):
image = self.preprocess_val(Image.open(one_images)).unsqueeze(0).to(device=self.device, non_blocking=True)
elif isinstance(one_images, Image.Image):
image = self.preprocess_val(one_images).unsqueeze(0).to(device=self.device, non_blocking=True)
else:
raise TypeError("The type of parameter images is illegal.")
scores.append(self._calculate_score(image, prompt))
return scores
else:
raise TypeError("The type of parameter images is illegal.")
except Exception as e:
raise RuntimeError(f"Error in scoring images: {e}")

View File

@@ -0,0 +1,212 @@
import os
import torch
from PIL import Image
from typing import List, Union
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from .BLIP.blip_pretrain import BLIP_Pretrain
from torchvision.transforms import InterpolationMode
from safetensors.torch import load_file
from .config import MODEL_PATHS
BICUBIC = InterpolationMode.BICUBIC
def _convert_image_to_rgb(image):
return image.convert("RGB")
def _transform(n_px):
return Compose([
Resize(n_px, interpolation=BICUBIC),
CenterCrop(n_px),
_convert_image_to_rgb,
ToTensor(),
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])
class MLP(torch.nn.Module):
def __init__(self, input_size):
super().__init__()
self.input_size = input_size
self.layers = torch.nn.Sequential(
torch.nn.Linear(self.input_size, 1024),
#nn.ReLU(),
torch.nn.Dropout(0.2),
torch.nn.Linear(1024, 128),
#nn.ReLU(),
torch.nn.Dropout(0.2),
torch.nn.Linear(128, 64),
#nn.ReLU(),
torch.nn.Dropout(0.1),
torch.nn.Linear(64, 16),
#nn.ReLU(),
torch.nn.Linear(16, 1)
)
# initial MLP param
for name, param in self.layers.named_parameters():
if 'weight' in name:
torch.nn.init.normal_(param, mean=0.0, std=1.0/(self.input_size+1))
if 'bias' in name:
torch.nn.init.constant_(param, val=0)
def forward(self, input):
return self.layers(input)
class ImageReward(torch.nn.Module):
def __init__(self, med_config, device='cpu', bert_model_path=""):
super().__init__()
self.device = device
self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config, bert_model_path=bert_model_path)
self.preprocess = _transform(224)
self.mlp = MLP(768)
self.mean = 0.16717362830052426
self.std = 1.0333394966054072
def score_grad(self, prompt_ids, prompt_attention_mask, image):
"""Calculate the score with gradient for a single image and prompt.
Args:
prompt_ids (torch.Tensor): Tokenized prompt IDs.
prompt_attention_mask (torch.Tensor): Attention mask for the prompt.
image (torch.Tensor): The processed image tensor.
Returns:
torch.Tensor: The reward score.
"""
image_embeds = self.blip.visual_encoder(image)
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
text_output = self.blip.text_encoder(
prompt_ids,
attention_mask=prompt_attention_mask,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
return_dict=True,
)
txt_features = text_output.last_hidden_state[:, 0, :]
rewards = self.mlp(txt_features)
rewards = (rewards - self.mean) / self.std
return rewards
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str = "") -> List[float]:
"""Score the images based on the prompt.
Args:
prompt (str): The prompt text.
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
Returns:
List[float]: List of scores for the images.
"""
if isinstance(images, (str, Image.Image)):
# Single image
if isinstance(images, str):
pil_image = Image.open(images)
else:
pil_image = images
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
return [self._calculate_score(prompt, image).item()]
elif isinstance(images, list):
# Multiple images
scores = []
for one_image in images:
if isinstance(one_image, str):
pil_image = Image.open(one_image)
elif isinstance(one_image, Image.Image):
pil_image = one_image
else:
raise TypeError("The type of parameter images is illegal.")
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
scores.append(self._calculate_score(prompt, image).item())
return scores
else:
raise TypeError("The type of parameter images is illegal.")
def _calculate_score(self, prompt: str, image: torch.Tensor) -> torch.Tensor:
"""Calculate the score for a single image and prompt.
Args:
prompt (str): The prompt text.
image (torch.Tensor): The processed image tensor.
Returns:
torch.Tensor: The reward score.
"""
text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
image_embeds = self.blip.visual_encoder(image)
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
text_output = self.blip.text_encoder(
text_input.input_ids,
attention_mask=text_input.attention_mask,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
return_dict=True,
)
txt_features = text_output.last_hidden_state[:, 0, :].float()
rewards = self.mlp(txt_features)
rewards = (rewards - self.mean) / self.std
return rewards
def inference_rank(self, prompt: str, generations_list: List[Union[str, Image.Image]]) -> tuple:
"""Rank the images based on the prompt.
Args:
prompt (str): The prompt text.
generations_list (List[Union[str, Image.Image]]): List of image paths or PIL images.
Returns:
tuple: (indices, rewards) where indices are the ranks and rewards are the scores.
"""
text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
txt_set = []
for generation in generations_list:
if isinstance(generation, str):
pil_image = Image.open(generation)
elif isinstance(generation, Image.Image):
pil_image = generation
else:
raise TypeError("The type of parameter generations_list is illegal.")
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
image_embeds = self.blip.visual_encoder(image)
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
text_output = self.blip.text_encoder(
text_input.input_ids,
attention_mask=text_input.attention_mask,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
return_dict=True,
)
txt_set.append(text_output.last_hidden_state[:, 0, :])
txt_features = torch.cat(txt_set, 0).float()
rewards = self.mlp(txt_features)
rewards = (rewards - self.mean) / self.std
rewards = torch.squeeze(rewards)
_, rank = torch.sort(rewards, dim=0, descending=True)
_, indices = torch.sort(rank, dim=0)
indices = indices + 1
return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
class ImageRewardScore(torch.nn.Module):
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS):
super().__init__()
self.device = device if isinstance(device, torch.device) else torch.device(device)
model_path = path.get("imagereward")
med_config = path.get("med_config")
state_dict = load_file(model_path)
self.model = ImageReward(device=self.device, med_config=med_config, bert_model_path=path.get("bert_model_path")).to(self.device)
self.model.load_state_dict(state_dict, strict=False)
self.model.eval()
@torch.no_grad()
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
"""Score the images based on the prompt.
Args:
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
prompt (str): The prompt text.
Returns:
List[float]: List of scores for the images.
"""
return self.model.score(images, prompt)

View File

@@ -0,0 +1,129 @@
import numpy as np
import torch
from PIL import Image
from io import BytesIO
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPImageProcessor
from transformers import CLIPConfig
from dataclasses import dataclass
from transformers import CLIPModel as HFCLIPModel
from safetensors.torch import load_file
from torch import nn, einsum
from .trainer.models.base_model import BaseModelConfig
from transformers import CLIPConfig
from transformers import AutoProcessor, AutoModel, AutoTokenizer
from typing import Any, Optional, Tuple, Union, List
import torch
from .trainer.models.cross_modeling import Cross_model
from .trainer.models import clip_model
import torch.nn.functional as F
import gc
import json
from .config import MODEL_PATHS
class MPScore(torch.nn.Module):
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS, condition: str = 'overall'):
super().__init__()
"""Initialize the MPSModel with a processor, tokenizer, and model.
Args:
device (Union[str, torch.device]): The device to load the model on.
"""
self.device = device
processor_name_or_path = path.get("clip")
self.image_processor = CLIPImageProcessor.from_pretrained(processor_name_or_path)
self.tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True)
self.model = clip_model.CLIPModel(processor_name_or_path, config_file=True)
state_dict = load_file(path.get("mps"))
self.model.load_state_dict(state_dict, strict=False)
self.model.to(device)
self.condition = condition
def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
"""Calculate the reward score for a single image and prompt.
Args:
image (torch.Tensor): The processed image tensor.
prompt (str): The prompt text.
Returns:
float: The reward score.
"""
def _tokenize(caption):
input_ids = self.tokenizer(
caption,
max_length=self.tokenizer.model_max_length,
padding="max_length",
truncation=True,
return_tensors="pt"
).input_ids
return input_ids
text_input = _tokenize(prompt).to(self.device)
if self.condition == 'overall':
condition_prompt = 'light, color, clarity, tone, style, ambiance, artistry, shape, face, hair, hands, limbs, structure, instance, texture, quantity, attributes, position, number, location, word, things'
elif self.condition == 'aesthetics':
condition_prompt = 'light, color, clarity, tone, style, ambiance, artistry'
elif self.condition == 'quality':
condition_prompt = 'shape, face, hair, hands, limbs, structure, instance, texture'
elif self.condition == 'semantic':
condition_prompt = 'quantity, attributes, position, number, location'
else:
raise ValueError(
f"Unsupported condition: {self.condition}. Choose 'overall', 'aesthetics', 'quality', or 'semantic'.")
condition_batch = _tokenize(condition_prompt).repeat(text_input.shape[0], 1).to(self.device)
with torch.no_grad():
text_f, text_features = self.model.model.get_text_features(text_input)
image_f = self.model.model.get_image_features(image.half())
condition_f, _ = self.model.model.get_text_features(condition_batch)
sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f)
sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]
sim_text_condition = sim_text_condition / sim_text_condition.max()
mask = torch.where(sim_text_condition > 0.3, 0, float('-inf'))
mask = mask.repeat(1, image_f.shape[1], 1)
image_features = self.model.cross_model(image_f, text_f, mask.half())[:, 0, :]
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
image_score = self.model.logit_scale.exp() * text_features @ image_features.T
return image_score[0].cpu().numpy().item()
@torch.no_grad()
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
"""Score the images based on the prompt.
Args:
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
prompt (str): The prompt text.
Returns:
List[float]: List of reward scores for the images.
"""
if isinstance(images, (str, Image.Image)):
# Single image
if isinstance(images, str):
image = self.image_processor(Image.open(images), return_tensors="pt")["pixel_values"].to(self.device)
else:
image = self.image_processor(images, return_tensors="pt")["pixel_values"].to(self.device)
return [self._calculate_score(image, prompt)]
elif isinstance(images, list):
# Multiple images
scores = []
for one_images in images:
if isinstance(one_images, str):
image = self.image_processor(Image.open(one_images), return_tensors="pt")["pixel_values"].to(self.device)
elif isinstance(one_images, Image.Image):
image = self.image_processor(one_images, return_tensors="pt")["pixel_values"].to(self.device)
else:
raise TypeError("The type of parameter images is illegal.")
scores.append(self._calculate_score(image, prompt))
return scores
else:
raise TypeError("The type of parameter images is illegal.")

View File

@@ -0,0 +1,14 @@
from .coca_model import CoCa
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
from .factory import list_models, add_model_config, get_model_config, load_checkpoint
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
from .openai import load_openai_model, list_openai_models
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
from .tokenizer import SimpleTokenizer
from .transform import image_transform, AugmentationCfg
from .utils import freeze_batch_norm_2d

View File

@@ -0,0 +1,458 @@
from typing import Optional
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from dataclasses import dataclass
from .transformer import (
LayerNormFp32,
LayerNorm,
QuickGELU,
MultimodalTransformer,
)
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
try:
from transformers import (
BeamSearchScorer,
LogitsProcessorList,
TopPLogitsWarper,
TopKLogitsWarper,
RepetitionPenaltyLogitsProcessor,
MinLengthLogitsProcessor,
MaxLengthCriteria,
StoppingCriteriaList
)
GENERATION_TYPES = {
"top_k": TopKLogitsWarper,
"top_p": TopPLogitsWarper,
"beam_search": "beam_search"
}
_has_transformers = True
except ImportError as e:
GENERATION_TYPES = {
"top_k": None,
"top_p": None,
"beam_search": "beam_search"
}
_has_transformers = False
@dataclass
class MultimodalCfg(CLIPTextCfg):
mlp_ratio: int = 4
dim_head: int = 64
heads: int = 8
n_queries: int = 256
attn_pooler_heads: int = 8
def _build_text_decoder_tower(
embed_dim,
multimodal_cfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
):
multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
act_layer = QuickGELU if quick_gelu else nn.GELU
norm_layer = (
LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
)
decoder = MultimodalTransformer(
context_length=multimodal_cfg.context_length,
width=multimodal_cfg.width,
heads=multimodal_cfg.heads,
layers=multimodal_cfg.layers,
ls_init_value=multimodal_cfg.ls_init_value,
output_dim=embed_dim,
act_layer=act_layer,
norm_layer=norm_layer,
)
return decoder
class CoCa(nn.Module):
def __init__(
self,
embed_dim,
multimodal_cfg: MultimodalCfg,
text_cfg: CLIPTextCfg,
vision_cfg: CLIPVisionCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
pad_id: int = 0,
):
super().__init__()
multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg
self.text = _build_text_tower(
embed_dim=embed_dim,
text_cfg=text_cfg,
quick_gelu=quick_gelu,
cast_dtype=cast_dtype,
)
vocab_size = (
text_cfg.vocab_size # for hf models
if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
else text_cfg.vocab_size
)
self.visual = _build_vision_tower(
embed_dim=embed_dim,
vision_cfg=vision_cfg,
quick_gelu=quick_gelu,
cast_dtype=cast_dtype,
)
self.text_decoder = _build_text_decoder_tower(
vocab_size,
multimodal_cfg=multimodal_cfg,
quick_gelu=quick_gelu,
cast_dtype=cast_dtype,
)
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.pad_id = pad_id
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.visual.set_grad_checkpointing(enable)
self.text.set_grad_checkpointing(enable)
self.text_decoder.set_grad_checkpointing(enable)
def _encode_image(self, images, normalize=True):
image_latent, tokens_embs = self.visual(images)
image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
return image_latent, tokens_embs
def _encode_text(self, text, normalize=True, embed_cls=True):
text = text[:, :-1] if embed_cls else text # make space for CLS token
text_latent, token_emb = self.text(text)
text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
return text_latent, token_emb
def encode_image(self, images, normalize=True):
image_latent, _ = self._encode_image(images, normalize=normalize)
return image_latent
def encode_text(self, text, normalize=True, embed_cls=True):
text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls)
return text_latent
def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None):
text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)
if image_latent is None or image_embs is None:
image_latent, image_embs = self._encode_image(image)
# TODO: add assertion to avoid bugs?
labels = text[:, -token_embs.shape[1]:]
logits = self.text_decoder(image_embs, token_embs)
return {
"image_features": image_latent,
"text_features": text_latent,
"logits": logits,
"labels": labels,
"logit_scale": self.logit_scale.exp()
}
def generate(
self,
image,
text=None,
seq_len=30,
max_seq_len=77,
temperature=1.,
generation_type="beam_search",
top_p=0.1, # keep tokens in the 1 - top_p quantile
top_k=1, # keeps the top_k most probable tokens
pad_token_id=None,
eos_token_id=None,
sot_token_id=None,
num_beams=6,
num_beam_groups=3,
min_seq_len=5,
stopping_criteria=None,
repetition_penalty=1.0,
fixed_output_length=False # if True output.shape == (batch_size, seq_len)
):
# taking many ideas and components from HuggingFace GenerationMixin
# https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
with torch.no_grad():
sot_token_id = 49406 if sot_token_id is None else sot_token_id
eos_token_id = 49407 if eos_token_id is None else eos_token_id
pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
logit_processor = LogitsProcessorList(
[
MinLengthLogitsProcessor(min_seq_len, eos_token_id),
RepetitionPenaltyLogitsProcessor(repetition_penalty),
]
)
if stopping_criteria is None:
stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
stopping_criteria = StoppingCriteriaList(
stopping_criteria
)
device = image.device
if generation_type == "beam_search":
output = self._generate_beamsearch(
image_inputs = image,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
sot_token_id=sot_token_id,
num_beams=num_beams,
num_beam_groups=num_beam_groups,
min_seq_len=min_seq_len,
stopping_criteria=stopping_criteria,
logit_processor=logit_processor,
)
if fixed_output_length and output.shape[1] < seq_len:
return torch.cat(
(output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id),
dim=1
)
return output
elif generation_type == "top_p":
logit_warper = GENERATION_TYPES[generation_type](top_p)
elif generation_type == "top_k":
logit_warper = GENERATION_TYPES[generation_type](top_k)
else:
raise ValueError(
f"generation_type has to be one of "
f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
)
image_latent, image_embs = self._encode_image(image)
if text is None:
text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id
was_training = self.training
num_dims = len(text.shape)
if num_dims == 1:
text = text[None, :]
cur_len = text.shape[1]
self.eval()
out = text
while True:
x = out[:, -max_seq_len:]
cur_len = x.shape[1]
logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1]
mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id
if mask.all():
if not fixed_output_length:
break
else:
logits = logits[~mask, :]
filtered_logits = logit_processor(x[~mask, :], logits)
filtered_logits = logit_warper(x[~mask, :], filtered_logits)
probs = F.softmax(filtered_logits / temperature, dim=-1)
if (cur_len + 1 == seq_len):
sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
else:
sample[~mask, :] = torch.multinomial(probs, 1)
out = torch.cat((out, sample), dim=-1)
cur_len += 1
if stopping_criteria(out, None):
break
if num_dims == 1:
out = out.squeeze(0)
self.train(was_training)
return out
def _generate_beamsearch(
self,
image_inputs,
pad_token_id=None,
eos_token_id=None,
sot_token_id=None,
num_beams=6,
num_beam_groups=3,
min_seq_len=5,
stopping_criteria=None,
logit_processor=None,
logit_warper=None,
):
device = image_inputs.device
batch_size = image_inputs.shape[0]
image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
image_latent, image_embs = self._encode_image(image_inputs)
input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
input_ids = input_ids * sot_token_id
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
device=device,
num_beam_groups=num_beam_groups,
)
# instantiate logits processors
logits_processor = (
LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
if logit_processor is None
else logit_processor
)
batch_size = len(beam_scorer._beam_hyps)
num_beams = beam_scorer.num_beams
num_beam_groups = beam_scorer.num_beam_groups
num_sub_beams = num_beams // num_beam_groups
batch_beam_size, cur_len = input_ids.shape
beam_indices = None
if num_beams * batch_size != batch_beam_size:
raise ValueError(
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
)
beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
# initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in
# the same group don't produce same tokens everytime.
beam_scores[:, ::num_sub_beams] = 0
beam_scores = beam_scores.view((batch_size * num_beams,))
while True:
# predicted tokens in cur_len step
current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
# indices which will form the beams in the next time step
reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)
# do one decoder step on all beams of all sentences in batch
model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
outputs = self(
model_inputs['images'],
model_inputs['text'],
embed_cls=False,
image_latent=image_latent,
image_embs=image_embs
)
for beam_group_idx in range(num_beam_groups):
group_start_idx = beam_group_idx * num_sub_beams
group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
group_size = group_end_idx - group_start_idx
# indices of beams of current group among all sentences in batch
batch_group_indices = []
for batch_idx in range(batch_size):
batch_group_indices.extend(
[batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
)
group_input_ids = input_ids[batch_group_indices]
# select outputs of beams of currentg group only
next_token_logits = outputs['logits'][batch_group_indices, -1, :]
vocab_size = next_token_logits.shape[-1]
next_token_scores_processed = logits_processor(
group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
)
next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
# reshape for beam search
next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
next_token_scores, next_tokens = torch.topk(
next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
)
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
next_tokens = next_tokens % vocab_size
# stateless
process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
beam_outputs = beam_scorer.process(
group_input_ids,
next_token_scores,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
beam_indices=process_beam_indices,
)
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
beam_next_tokens = beam_outputs["next_beam_tokens"]
beam_idx = beam_outputs["next_beam_indices"]
input_ids[batch_group_indices] = group_input_ids[beam_idx]
group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
current_tokens[batch_group_indices] = group_input_ids[:, -1]
# (beam_idx // group_size) -> batch_idx
# (beam_idx % group_size) -> offset of idx inside the group
reordering_indices[batch_group_indices] = (
num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
)
input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
# increase cur_len
cur_len = cur_len + 1
if beam_scorer.is_done or stopping_criteria(input_ids, None):
break
final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
sequence_outputs = beam_scorer.finalize(
input_ids,
beam_scores,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
max_length=stopping_criteria.max_length,
beam_indices=final_beam_indices,
)
return sequence_outputs['sequences']
def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
else:
position_ids = None
return {
"text": input_ids,
"images": image_inputs,
"past_key_values": past,
"position_ids": position_ids,
"attention_mask": attention_mask,
}

View File

@@ -0,0 +1,2 @@
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)

View File

@@ -0,0 +1,433 @@
import json
import logging
import os
import pathlib
import re
from copy import deepcopy
from pathlib import Path
from turtle import forward
from typing import Any, Dict, Optional, Tuple, Union
import torch
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
resize_pos_embed, get_cast_dtype
from .coca_model import CoCa
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
from .openai import load_openai_model
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf
from .transform import image_transform, AugmentationCfg
from .tokenizer import HFTokenizer, SimpleTokenizer
HF_HUB_PREFIX = 'hf-hub:'
_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
def _natural_key(string_):
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
def _rescan_model_configs():
global _MODEL_CONFIGS
config_ext = ('.json',)
config_files = []
for config_path in _MODEL_CONFIG_PATHS:
if config_path.is_file() and config_path.suffix in config_ext:
config_files.append(config_path)
elif config_path.is_dir():
for ext in config_ext:
config_files.extend(config_path.glob(f'*{ext}'))
for cf in config_files:
with open(cf, 'r') as f:
model_cfg = json.load(f)
if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
_MODEL_CONFIGS[cf.stem] = model_cfg
_MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
_rescan_model_configs() # initial populate of model config registry
def list_models():
""" enumerate available model architectures based on config files """
return list(_MODEL_CONFIGS.keys())
def add_model_config(path):
""" add model config path or file and update registry """
if not isinstance(path, Path):
path = Path(path)
_MODEL_CONFIG_PATHS.append(path)
_rescan_model_configs()
def get_model_config(model_name):
if model_name in _MODEL_CONFIGS:
return deepcopy(_MODEL_CONFIGS[model_name])
else:
return None
def get_tokenizer(model_name, open_clip_bpe_path=None):
if model_name.startswith(HF_HUB_PREFIX):
tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):])
else:
config = get_model_config(model_name)
tokenizer = HFTokenizer(
config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else SimpleTokenizer(open_clip_bpe_path)
return tokenizer
def load_state_dict(checkpoint_path: str, map_location='cpu'):
checkpoint = torch.load(checkpoint_path, map_location=map_location)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
if next(iter(state_dict.items()))[0].startswith('module'):
state_dict = {k[7:]: v for k, v in state_dict.items()}
return state_dict
def load_checkpoint(model, checkpoint_path, strict=True):
state_dict = load_state_dict(checkpoint_path)
# detect old format and make compatible with new format
if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
state_dict = convert_to_custom_text_state_dict(state_dict)
resize_pos_embed(state_dict, model)
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
return incompatible_keys
def create_model(
model_name: str,
pretrained: Optional[str] = None,
precision: str = 'fp32',
device: Union[str, torch.device] = 'cpu',
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_text: bool = False,
force_patch_dropout: Optional[float] = None,
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
pretrained_image: bool = False,
pretrained_hf: bool = True,
cache_dir: Optional[str] = None,
output_dict: Optional[bool] = None,
require_pretrained: bool = False,
):
has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
if has_hf_hub_prefix:
model_id = model_name[len(HF_HUB_PREFIX):]
checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)
with open(config_path, 'r', encoding='utf-8') as f:
config = json.load(f)
pretrained_cfg = config['preprocess_cfg']
model_cfg = config['model_cfg']
else:
model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
checkpoint_path = None
pretrained_cfg = {}
model_cfg = None
if isinstance(device, str):
device = torch.device(device)
if pretrained and pretrained.lower() == 'openai':
logging.info(f'Loading pretrained {model_name} from OpenAI.')
model = load_openai_model(
model_name,
precision=precision,
device=device,
jit=jit,
cache_dir=cache_dir,
)
# to always output dict even if it is clip
if output_dict and hasattr(model, "output_dict"):
model.output_dict = True
else:
model_cfg = model_cfg or get_model_config(model_name)
if model_cfg is not None:
logging.info(f'Loaded {model_name} model config.')
else:
logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
raise RuntimeError(f'Model config for {model_name} not found.')
if force_quick_gelu:
# override for use of QuickGELU on non-OpenAI transformer models
model_cfg["quick_gelu"] = True
if force_patch_dropout is not None:
# override the default patch dropout value
model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
if force_image_size is not None:
# override model config's image size
model_cfg["vision_cfg"]["image_size"] = force_image_size
if pretrained_image:
if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
# pretrained weight loading for timm models set via vision_cfg
model_cfg['vision_cfg']['timm_model_pretrained'] = True
else:
assert False, 'pretrained image towers currently only supported for timm models'
cast_dtype = get_cast_dtype(precision)
is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model
if custom_text:
if is_hf_model:
model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
if "coca" in model_name:
model = CoCa(**model_cfg, cast_dtype=cast_dtype)
else:
model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
else:
model = CLIP(**model_cfg, cast_dtype=cast_dtype)
pretrained_loaded = False
if pretrained:
checkpoint_path = ''
pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
if pretrained_cfg:
checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
elif os.path.exists(pretrained):
checkpoint_path = pretrained
if checkpoint_path:
logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
load_checkpoint(model, checkpoint_path)
else:
error_str = (
f'Pretrained weights ({pretrained}) not found for model {model_name}.'
f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
logging.warning(error_str)
raise RuntimeError(error_str)
pretrained_loaded = True
elif has_hf_hub_prefix:
logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
load_checkpoint(model, checkpoint_path)
pretrained_loaded = True
if require_pretrained and not pretrained_loaded:
# callers of create_model_from_pretrained always expect pretrained weights
raise RuntimeError(
f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')
model.to(device=device)
if precision in ("fp16", "bf16"):
convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)
# set image / mean metadata from pretrained_cfg if available, or use default
model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
# to always output dict even if it is clip
if output_dict and hasattr(model, "output_dict"):
model.output_dict = True
if jit:
model = torch.jit.script(model)
return model
def create_loss(args):
if args.distill:
return DistillClipLoss(
local_loss=args.local_loss,
gather_with_grad=args.gather_with_grad,
cache_labels=True,
rank=args.rank,
world_size=args.world_size,
use_horovod=args.horovod,
)
elif "coca" in args.model.lower():
return CoCaLoss(
caption_loss_weight=args.coca_caption_loss_weight,
clip_loss_weight=args.coca_contrastive_loss_weight,
local_loss=args.local_loss,
gather_with_grad=args.gather_with_grad,
cache_labels=True,
rank=args.rank,
world_size=args.world_size,
use_horovod=args.horovod,
)
return ClipLoss(
local_loss=args.local_loss,
gather_with_grad=args.gather_with_grad,
cache_labels=True,
rank=args.rank,
world_size=args.world_size,
use_horovod=args.horovod,
)
class MLP(torch.nn.Module):
def __init__(self, input_size):
super().__init__()
self.input_size = input_size
self.layers = torch.nn.Sequential(
torch.nn.Linear(self.input_size, 1024),
torch.nn.Dropout(0.2),
torch.nn.Linear(1024, 128),
torch.nn.Dropout(0.2),
torch.nn.Linear(128, 64),
torch.nn.Dropout(0.1),
torch.nn.Linear(64, 16),
torch.nn.Linear(16, 1)
)
def forward(self, x):
return self.layers(x)
# class semantic_head(torch.nn.Module):
# def __init__(self, input_size):
# super().__init__()
# self.input_size = input_size # for ViT-L-14 is 1024
# self.seg_head = torch.nn.Sequential(
# torch.nn.Linear(input_size, 128),
# torch.nn.Dropout(0.2),
# torch.nn.Linear(128, 64),
# torch.nn.Dropout(0.1),
# torch.nn.Linear(64, 16),
# torch.nn.Linear(16, 1),
# )
# self.sigmoid = torch.nn.Sigmoid()
# def forward(self, x):
# return self.sigmoid(self.seg_head(x))
def create_model_and_transforms(
model_name: str,
pretrained: Optional[str] = None,
precision: str = 'fp32',
device: Union[str, torch.device] = 'cpu',
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_text: bool = False,
force_patch_dropout: Optional[float] = None,
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
pretrained_image: bool = False,
pretrained_hf: bool = True,
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
cache_dir: Optional[str] = None,
light_augmentation = False,
output_dict: Optional[bool] = None,
with_score_predictor: bool = False,
with_region_predictor: bool = False
):
model = create_model(
model_name,
pretrained,
precision=precision,
device=device,
jit=jit,
force_quick_gelu=force_quick_gelu,
force_custom_text=force_custom_text,
force_patch_dropout=force_patch_dropout,
force_image_size=force_image_size,
pretrained_image=pretrained_image,
pretrained_hf=pretrained_hf,
cache_dir=cache_dir,
output_dict=output_dict,
)
image_mean = image_mean or getattr(model.visual, 'image_mean', None)
image_std = image_std or getattr(model.visual, 'image_std', None)
if with_score_predictor:
model.score_predictor = MLP(model.visual.proj.size(1)).to(device=device, dtype=model.visual.proj.dtype)
if with_region_predictor:
# model.region_predictor = semantic_head(model.visual.proj.size(1)).to(device=device, dtype=model.visual.proj.dtype)
model.region_predictor = torch.nn.Linear(model.visual.proj.size(0), 1).to(device=device, dtype=model.visual.proj.dtype)
# preprocess_train = image_transform_region(
# model.visual.image_size,
# is_train=True,
# mean=image_mean,
# std=image_std
# )
# preprocess_val = image_transform_region(
# model.visual.image_size,
# is_train=False,
# mean=image_mean,
# std=image_std
# )
if light_augmentation:
preprocess_val = image_transform(
model.visual.image_size,
is_train=False,
mean=image_mean,
std=image_std,
resize_longest_max=True,
)
preprocess_train = preprocess_val
else:
preprocess_train = image_transform(
model.visual.image_size,
is_train=True,
mean=image_mean,
std=image_std
)
preprocess_val = image_transform(
model.visual.image_size,
is_train=False,
mean=image_mean,
std=image_std
)
return model, preprocess_train, preprocess_val
def create_model_from_pretrained(
model_name: str,
pretrained: Optional[str] = None,
precision: str = 'fp32',
device: Union[str, torch.device] = 'cpu',
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_text: bool = False,
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
return_transform: bool = True,
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
cache_dir: Optional[str] = None,
):
model = create_model(
model_name,
pretrained,
precision=precision,
device=device,
jit=jit,
force_quick_gelu=force_quick_gelu,
force_custom_text=force_custom_text,
force_image_size=force_image_size,
cache_dir=cache_dir,
require_pretrained=True,
)
if not return_transform:
return model
image_mean = image_mean or getattr(model.visual, 'image_mean', None)
image_std = image_std or getattr(model.visual, 'image_std', None)
preprocess = image_transform(
model.visual.image_size,
is_train=False,
mean=image_mean,
std=image_std,
)
return model, preprocess

View File

@@ -0,0 +1,45 @@
# HF architecture dict:
arch_dict = {
# https://huggingface.co/docs/transformers/model_doc/roberta#roberta
"roberta": {
"config_names": {
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "hidden_size",
"heads": "num_attention_heads",
"layers": "num_hidden_layers",
"layer_attr": "layer",
"token_embeddings_attr": "embeddings"
},
"pooler": "mean_pooler",
},
# https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
"xlm-roberta": {
"config_names": {
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "hidden_size",
"heads": "num_attention_heads",
"layers": "num_hidden_layers",
"layer_attr": "layer",
"token_embeddings_attr": "embeddings"
},
"pooler": "mean_pooler",
},
# https://huggingface.co/docs/transformers/model_doc/mt5#mt5
"mt5": {
"config_names": {
# unlimited seqlen
# https://github.com/google-research/text-to-text-transfer-transformer/issues/273
# https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
"context_length": "",
"vocab_size": "vocab_size",
"width": "d_model",
"heads": "num_heads",
"layers": "num_layers",
"layer_attr": "block",
"token_embeddings_attr": "embed_tokens"
},
"pooler": "mean_pooler",
},
}

View File

@@ -0,0 +1,176 @@
""" huggingface model adapter
Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
"""
import re
import torch
import torch.nn as nn
from torch import TensorType
try:
import transformers
from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
BaseModelOutputWithPoolingAndCrossAttentions
except ImportError as e:
transformers = None
class BaseModelOutput:
pass
class PretrainedConfig:
pass
from .hf_configs import arch_dict
# utils
def _camel2snake(s):
return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
# TODO: ?last - for gpt-like models
_POOLERS = {}
def register_pooler(cls):
"""Decorator registering pooler class"""
_POOLERS[_camel2snake(cls.__name__)] = cls
return cls
@register_pooler
class MeanPooler(nn.Module):
"""Mean pooling"""
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
@register_pooler
class MaxPooler(nn.Module):
"""Max pooling"""
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1), -torch.inf)
return masked_output.max(1).values
@register_pooler
class ClsPooler(nn.Module):
"""CLS token pooling"""
def __init__(self, use_pooler_output=True):
super().__init__()
self.cls_token_position = 0
self.use_pooler_output = use_pooler_output
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
if (self.use_pooler_output and
isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
(x.pooler_output is not None)
):
return x.pooler_output
return x.last_hidden_state[:, self.cls_token_position, :]
class HFTextEncoder(nn.Module):
"""HuggingFace model adapter"""
output_tokens: torch.jit.Final[bool]
def __init__(
self,
model_name_or_path: str,
output_dim: int,
config: PretrainedConfig = None,
pooler_type: str = None,
proj: str = None,
pretrained: bool = True,
output_tokens: bool = False,
):
super().__init__()
self.output_tokens = output_tokens
self.output_dim = output_dim
# TODO: find better way to get this information
uses_transformer_pooler = (pooler_type == "cls_pooler")
if transformers is None:
raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
if config is None:
self.config = AutoConfig.from_pretrained(model_name_or_path)
create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
AutoModel.from_config, self.config)
# TODO: do all model configs have this attribute? PretrainedConfig does so yes??
if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
self.transformer = create_func(model_args)
self.transformer = self.transformer.encoder
else:
self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
else:
self.config = config
self.transformer = AutoModel.from_config(config)
if pooler_type is None: # get default arch pooler
pooler_type = (arch_dict[self.config.model_type]["pooler"])
self.pooler = _POOLERS[pooler_type]()
d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
if (d_model == output_dim) and (proj is None): # do we always need a proj?
self.proj = nn.Identity()
elif proj == 'linear':
self.proj = nn.Linear(d_model, output_dim, bias=False)
elif proj == 'mlp':
hidden_size = (d_model + output_dim) // 2
self.proj = nn.Sequential(
nn.Linear(d_model, hidden_size, bias=False),
nn.GELU(),
nn.Linear(hidden_size, output_dim, bias=False),
)
def forward(self, x: TensorType):
attn_mask = (x != self.config.pad_token_id).long()
out = self.transformer(input_ids=x, attention_mask=attn_mask)
pooled_out = self.pooler(out, attn_mask)
projected = self.proj(pooled_out)
seq_len = out.last_hidden_state.shape[1]
tokens = (
out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :]
if type(self.pooler) == ClsPooler
else out.last_hidden_state
)
if self.output_tokens:
return projected, tokens
return projected
def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
if not unlocked_layers: # full freezing
for n, p in self.transformer.named_parameters():
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
return
encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
embeddings = getattr(
self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
modules = [embeddings, *layer_list][:-unlocked_layers]
# freeze layers
for module in modules:
for n, p in module.named_parameters():
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.transformer.gradient_checkpointing_enable()
def init_parameters(self):
pass

View File

@@ -0,0 +1,270 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.utils.rnn import pad_sequence
try:
import torch.distributed.nn
from torch import distributed as dist
has_distributed = True
except ImportError:
has_distributed = False
try:
import horovod.torch as hvd
except ImportError:
hvd = None
def gather_features(
image_features,
text_features,
local_loss=False,
gather_with_grad=False,
rank=0,
world_size=1,
use_horovod=False
):
assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
if use_horovod:
assert hvd is not None, 'Please install horovod'
if gather_with_grad:
all_image_features = hvd.allgather(image_features)
all_text_features = hvd.allgather(text_features)
else:
with torch.no_grad():
all_image_features = hvd.allgather(image_features)
all_text_features = hvd.allgather(text_features)
if not local_loss:
# ensure grads for local rank when all_* features don't have a gradient
gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
gathered_image_features[rank] = image_features
gathered_text_features[rank] = text_features
all_image_features = torch.cat(gathered_image_features, dim=0)
all_text_features = torch.cat(gathered_text_features, dim=0)
else:
# We gather tensors from all gpus
if gather_with_grad:
all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
else:
gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
dist.all_gather(gathered_image_features, image_features)
dist.all_gather(gathered_text_features, text_features)
if not local_loss:
# ensure grads for local rank when all_* features don't have a gradient
gathered_image_features[rank] = image_features
gathered_text_features[rank] = text_features
all_image_features = torch.cat(gathered_image_features, dim=0)
all_text_features = torch.cat(gathered_text_features, dim=0)
return all_image_features, all_text_features
class ClipLoss(nn.Module):
def __init__(
self,
local_loss=False,
gather_with_grad=False,
cache_labels=False,
rank=0,
world_size=1,
use_horovod=False,
):
super().__init__()
self.local_loss = local_loss
self.gather_with_grad = gather_with_grad
self.cache_labels = cache_labels
self.rank = rank
self.world_size = world_size
self.use_horovod = use_horovod
# cache state
self.prev_num_logits = 0
self.labels = {}
def get_ground_truth(self, device, num_logits) -> torch.Tensor:
# calculated ground-truth and cache if enabled
if self.prev_num_logits != num_logits or device not in self.labels:
labels = torch.arange(num_logits, device=device, dtype=torch.long)
if self.world_size > 1 and self.local_loss:
labels = labels + num_logits * self.rank
if self.cache_labels:
self.labels[device] = labels
self.prev_num_logits = num_logits
else:
labels = self.labels[device]
return labels
def get_logits(self, image_features, text_features, logit_scale):
if self.world_size > 1:
all_image_features, all_text_features = gather_features(
image_features, text_features,
self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
if self.local_loss:
logits_per_image = logit_scale * image_features @ all_text_features.T
logits_per_text = logit_scale * text_features @ all_image_features.T
else:
logits_per_image = logit_scale * all_image_features @ all_text_features.T
logits_per_text = logits_per_image.T
else:
logits_per_image = logit_scale * image_features @ text_features.T
logits_per_text = logit_scale * text_features @ image_features.T
return logits_per_image, logits_per_text
def forward(self, image_features, text_features, logit_scale, output_dict=False):
device = image_features.device
logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
labels = self.get_ground_truth(device, logits_per_image.shape[0])
total_loss = (
F.cross_entropy(logits_per_image, labels) +
F.cross_entropy(logits_per_text, labels)
) / 2
return total_loss
class PreferenceLoss(nn.Module):
def forward(self, logits_per_image, num_images, labels):
paired_logits_list = [logit[:,i] for i, logit in enumerate(logits_per_image.split(num_images.tolist()))]
paired_logits = pad_sequence(paired_logits_list, batch_first=True, padding_value=-999)
ce_loss = F.cross_entropy(paired_logits, labels)
return ce_loss
class HPSLoss(nn.Module):
def forward(self, text_logits, labels):
device = text_logits.device
text_0_logits, text_1_logits = text_logits.chunk(2, dim=-1)
label_0, label_1 = labels.chunk(2, dim=-1)
index = torch.arange(text_0_logits.shape[0], device=device, dtype=torch.long)
text_0_logits = text_0_logits[index, index]
text_1_logits = text_1_logits[index, index]
text_logits = torch.stack([text_0_logits, text_1_logits], dim=-1)
text_0_labels = torch.zeros(text_logits.shape[0], device=device, dtype=torch.long)
text_1_labels = text_0_labels + 1
text_0_loss = torch.nn.functional.cross_entropy(text_logits, text_0_labels, reduction="none")
text_1_loss = torch.nn.functional.cross_entropy(text_logits, text_1_labels, reduction="none")
text_loss = label_0 * text_0_loss + label_1 * text_1_loss
# absolute_example_weight = 1 / num_per_prompt
# denominator = absolute_example_weight.sum()
# weight_per_example = absolute_example_weight / denominator
# text_loss *= weight_per_example
text_loss = text_loss.sum()
return text_loss
class RankingLoss(nn.Module):
def forward(self, logits_per_image, num_images, labels, margin = 1.0):
paired_logits_list = [logit[:,i] for i, logit in enumerate(logits_per_image.split(num_images.tolist()))]
label_list = [label for label in labels.split(num_images.tolist())]
# ranked_logits = [torch.index_select(paired_logits_list[i], 0, rank) for i, rank in enumerate(label_list)]
paired_logits = pad_sequence(paired_logits_list, batch_first=True, padding_value=-1)
padded_labels = pad_sequence(label_list, batch_first=True, padding_value=10)
# regulized_logits = torch.log(torch.sigmoid(paired_logits))
diff = paired_logits.unsqueeze(1) - paired_logits.unsqueeze(2)
# diff = paired_logits.unsqueeze(1) - paired_logits.unsqueeze(2)
# diff_label = torch.clamp(padded_labels.unsqueeze(1) - padded_labels.unsqueeze(2), min=-1, max=1)
diff_label = - (padded_labels.unsqueeze(1) - padded_labels.unsqueeze(2))
mask = torch.triu(torch.ones(diff.shape[1], diff.shape[1]), diagonal=1).bool().detach()
loss = torch.clamp(margin - torch.mul(diff[:, ~mask],diff_label[:,~mask]), min=0).mean()
return loss
class CoCaLoss(ClipLoss):
def __init__(
self,
caption_loss_weight,
clip_loss_weight,
pad_id=0, # pad_token for open_clip custom tokenizer
local_loss=False,
gather_with_grad=False,
cache_labels=False,
rank=0,
world_size=1,
use_horovod=False,
):
super().__init__(
local_loss=local_loss,
gather_with_grad=gather_with_grad,
cache_labels=cache_labels,
rank=rank,
world_size=world_size,
use_horovod=use_horovod
)
self.clip_loss_weight = clip_loss_weight
self.caption_loss_weight = caption_loss_weight
self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)
def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):
clip_loss = super().forward(image_features, text_features, logit_scale)
clip_loss = self.clip_loss_weight * clip_loss
caption_loss = self.caption_loss(
logits.permute(0, 2, 1),
labels,
)
caption_loss = caption_loss * self.caption_loss_weight
if output_dict:
return {"contrastive_loss": clip_loss, "caption_loss": caption_loss}
return clip_loss, caption_loss
class DistillClipLoss(ClipLoss):
def dist_loss(self, teacher_logits, student_logits):
return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0)
def forward(
self,
image_features,
text_features,
logit_scale,
dist_image_features,
dist_text_features,
dist_logit_scale,
output_dict=False,
):
logits_per_image, logits_per_text = \
self.get_logits(image_features, text_features, logit_scale)
dist_logits_per_image, dist_logits_per_text = \
self.get_logits(dist_image_features, dist_text_features, dist_logit_scale)
labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0])
contrastive_loss = (
F.cross_entropy(logits_per_image, labels) +
F.cross_entropy(logits_per_text, labels)
) / 2
distill_loss = (
self.dist_loss(dist_logits_per_image, logits_per_image) +
self.dist_loss(dist_logits_per_text, logits_per_text)
) / 2
if output_dict:
return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss}
return contrastive_loss, distill_loss

View File

@@ -0,0 +1,461 @@
""" CLIP Model
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
from dataclasses import dataclass
import logging
import math
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint
from .hf_model import HFTextEncoder
from .modified_resnet import ModifiedResNet
from .timm_model import TimmModel
from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
from .utils import to_2tuple
@dataclass
class CLIPVisionCfg:
layers: Union[Tuple[int, int, int, int], int] = 12
width: int = 768
head_width: int = 64
mlp_ratio: float = 4.0
patch_size: int = 16
image_size: Union[Tuple[int, int], int] = 224
ls_init_value: Optional[float] = None # layer scale initial value
patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design
global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer
n_queries: int = 256 # n_queries for attentional pooler
attn_pooler_heads: int = 8 # n heads for attentional_pooling
timm_model_name: str = None # a valid model name overrides layers, width, patch_size
timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
timm_proj_bias: bool = False # enable bias final projection
timm_drop: float = 0. # head dropout
timm_drop_path: Optional[float] = None # backbone stochastic depth
output_tokens: bool = False
@dataclass
class CLIPTextCfg:
context_length: int = 77
vocab_size: int = 49408
width: int = 512
heads: int = 8
layers: int = 12
ls_init_value: Optional[float] = None # layer scale initial value
hf_model_name: str = None
hf_tokenizer_name: str = None
hf_model_pretrained: bool = True
proj: str = 'mlp'
pooler_type: str = 'mean_pooler'
embed_cls: bool = False
pad_id: int = 0
output_tokens: bool = False
def get_cast_dtype(precision: str):
cast_dtype = None
if precision == 'bf16':
cast_dtype = torch.bfloat16
elif precision == 'fp16':
cast_dtype = torch.float16
return cast_dtype
def _build_vision_tower(
embed_dim: int,
vision_cfg: CLIPVisionCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None
):
if isinstance(vision_cfg, dict):
vision_cfg = CLIPVisionCfg(**vision_cfg)
# OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
# memory efficient in recent PyTorch releases (>= 1.10).
# NOTE: timm models always use native GELU regardless of quick_gelu flag.
act_layer = QuickGELU if quick_gelu else nn.GELU
if vision_cfg.timm_model_name:
visual = TimmModel(
vision_cfg.timm_model_name,
pretrained=vision_cfg.timm_model_pretrained,
pool=vision_cfg.timm_pool,
proj=vision_cfg.timm_proj,
proj_bias=vision_cfg.timm_proj_bias,
drop=vision_cfg.timm_drop,
drop_path=vision_cfg.timm_drop_path,
embed_dim=embed_dim,
image_size=vision_cfg.image_size,
)
act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
elif isinstance(vision_cfg.layers, (tuple, list)):
vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
visual = ModifiedResNet(
layers=vision_cfg.layers,
output_dim=embed_dim,
heads=vision_heads,
image_size=vision_cfg.image_size,
width=vision_cfg.width,
)
else:
vision_heads = vision_cfg.width // vision_cfg.head_width
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
visual = VisionTransformer(
image_size=vision_cfg.image_size,
patch_size=vision_cfg.patch_size,
width=vision_cfg.width,
layers=vision_cfg.layers,
heads=vision_heads,
mlp_ratio=vision_cfg.mlp_ratio,
ls_init_value=vision_cfg.ls_init_value,
patch_dropout=vision_cfg.patch_dropout,
input_patchnorm=vision_cfg.input_patchnorm,
global_average_pool=vision_cfg.global_average_pool,
attentional_pool=vision_cfg.attentional_pool,
n_queries=vision_cfg.n_queries,
attn_pooler_heads=vision_cfg.attn_pooler_heads,
output_tokens=vision_cfg.output_tokens,
output_dim=embed_dim,
act_layer=act_layer,
norm_layer=norm_layer,
)
return visual
def _build_text_tower(
embed_dim: int,
text_cfg: CLIPTextCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
):
if isinstance(text_cfg, dict):
text_cfg = CLIPTextCfg(**text_cfg)
if text_cfg.hf_model_name:
text = HFTextEncoder(
text_cfg.hf_model_name,
output_dim=embed_dim,
proj=text_cfg.proj,
pooler_type=text_cfg.pooler_type,
pretrained=text_cfg.hf_model_pretrained,
output_tokens=text_cfg.output_tokens,
)
else:
act_layer = QuickGELU if quick_gelu else nn.GELU
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
text = TextTransformer(
context_length=text_cfg.context_length,
vocab_size=text_cfg.vocab_size,
width=text_cfg.width,
heads=text_cfg.heads,
layers=text_cfg.layers,
ls_init_value=text_cfg.ls_init_value,
output_dim=embed_dim,
embed_cls=text_cfg.embed_cls,
output_tokens=text_cfg.output_tokens,
pad_id=text_cfg.pad_id,
act_layer=act_layer,
norm_layer=norm_layer,
)
return text
class CLIP(nn.Module):
output_dict: torch.jit.Final[bool]
def __init__(
self,
embed_dim: int,
vision_cfg: CLIPVisionCfg,
text_cfg: CLIPTextCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
output_dict: bool = False,
):
super().__init__()
self.output_dict = output_dict
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
self.transformer = text.transformer
self.vocab_size = text.vocab_size
self.token_embedding = text.token_embedding
self.positional_embedding = text.positional_embedding
self.ln_final = text.ln_final
self.text_projection = text.text_projection
self.register_buffer('attn_mask', text.attn_mask, persistent=False)
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
locked_layers = []
locked_layers.append(self.token_embedding)
self.positional_embedding.requires_grad = False
if unlocked_layers > 0:
locked_layers.append(self.transformer.resblocks[:-unlocked_layers])
else:
locked_layers.append(self.transformer)
locked_layers.append(self.ln_final)
self.text_projection.requires_grad = False
# freeze layers
for module in locked_layers:
for n, p in module.named_parameters():
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.visual.set_grad_checkpointing(enable)
self.transformer.grad_checkpointing = enable
def encode_image(self, image, normalize: bool = False):
features = self.visual(image)
return F.normalize(features, dim=-1) if normalize else features
def encode_text(self, text, normalize: bool = False):
cast_dtype = self.transformer.get_cast_dtype()
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding.to(cast_dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x, attn_mask=self.attn_mask)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
return F.normalize(x, dim=-1) if normalize else x
def forward(self, image, text):
image_features = self.encode_image(image, normalize=True)
text_features = self.encode_text(text, normalize=True)
if self.output_dict:
return {
"image_features": image_features,
"text_features": text_features,
"logit_scale": self.logit_scale.exp()
}
return image_features, text_features, self.logit_scale.exp()
class CustomTextCLIP(nn.Module):
output_dict: torch.jit.Final[bool]
def __init__(
self,
embed_dim: int,
vision_cfg: CLIPVisionCfg,
text_cfg: CLIPTextCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
output_dict: bool = False,
):
super().__init__()
self.output_dict = output_dict
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
self.text.lock(unlocked_layers, freeze_layer_norm)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.visual.set_grad_checkpointing(enable)
self.text.set_grad_checkpointing(enable)
def encode_image(self, image, normalize: bool = False):
features = self.visual(image)
return F.normalize(features, dim=-1) if normalize else features
def encode_text(self, text, normalize: bool = False):
features = self.text(text)
return F.normalize(features, dim=-1) if normalize else features
def forward(self, image, text):
image_features = self.encode_image(image, normalize=True)
text_features = self.encode_text(text, normalize=True)
if self.output_dict:
return {
"image_features": image_features,
"text_features": text_features,
"logit_scale": self.logit_scale.exp()
}
return image_features, text_features, self.logit_scale.exp()
def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
"""Convert applicable model parameters to low-precision (bf16 or fp16)"""
def _convert_weights(l):
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
l.weight.data = l.weight.data.to(dtype)
if l.bias is not None:
l.bias.data = l.bias.data.to(dtype)
if isinstance(l, (nn.MultiheadAttention, Attention)):
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
tensor = getattr(l, attr)
if tensor is not None:
tensor.data = tensor.data.to(dtype)
for name in ["text_projection", "proj"]:
if hasattr(l, name):
attr = getattr(l, name)
if attr is not None:
attr.data = attr.data.to(dtype)
model.apply(_convert_weights)
convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
# used to maintain checkpoint compatibility
def convert_to_custom_text_state_dict(state_dict: dict):
if 'text_projection' in state_dict:
# old format state_dict, move text tower -> .text
new_state_dict = {}
for k, v in state_dict.items():
if any(k.startswith(p) for p in (
'text_projection',
'positional_embedding',
'token_embedding',
'transformer',
'ln_final',
)):
k = 'text.' + k
new_state_dict[k] = v
return new_state_dict
return state_dict
def build_model_from_openai_state_dict(
state_dict: dict,
quick_gelu=True,
cast_dtype=torch.float16,
):
vit = "visual.proj" in state_dict
if vit:
vision_width = state_dict["visual.conv1.weight"].shape[0]
vision_layers = len(
[k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
image_size = vision_patch_size * grid_size
else:
counts: list = [
len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
vision_layers = tuple(counts)
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
vision_patch_size = None
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
image_size = output_width * 32
embed_dim = state_dict["text_projection"].shape[1]
context_length = state_dict["positional_embedding"].shape[0]
vocab_size = state_dict["token_embedding.weight"].shape[0]
transformer_width = state_dict["ln_final.weight"].shape[0]
transformer_heads = transformer_width // 64
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
vision_cfg = CLIPVisionCfg(
layers=vision_layers,
width=vision_width,
patch_size=vision_patch_size,
image_size=image_size,
)
text_cfg = CLIPTextCfg(
context_length=context_length,
vocab_size=vocab_size,
width=transformer_width,
heads=transformer_heads,
layers=transformer_layers,
)
model = CLIP(
embed_dim,
vision_cfg=vision_cfg,
text_cfg=text_cfg,
quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
cast_dtype=cast_dtype,
)
for key in ["input_resolution", "context_length", "vocab_size"]:
state_dict.pop(key, None)
convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
model.load_state_dict(state_dict)
return model.eval()
def trace_model(model, batch_size=256, device=torch.device('cpu')):
model.eval()
image_size = model.visual.image_size
example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
model = torch.jit.trace_module(
model,
inputs=dict(
forward=(example_images, example_text),
encode_text=(example_text,),
encode_image=(example_images,)
))
model.visual.image_size = image_size
return model
def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
# Rescale the grid of position embeddings when loading from state_dict
old_pos_embed = state_dict.get('visual.positional_embedding', None)
if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
return
grid_size = to_2tuple(model.visual.grid_size)
extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
if new_seq_len == old_pos_embed.shape[0]:
return
if extra_tokens:
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
else:
pos_emb_tok, pos_emb_img = None, old_pos_embed
old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
pos_emb_img = F.interpolate(
pos_emb_img,
size=grid_size,
mode=interpolation,
antialias=antialias,
align_corners=False,
)
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
if pos_emb_tok is not None:
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
else:
new_pos_embed = pos_emb_img
state_dict['visual.positional_embedding'] = new_pos_embed

View File

@@ -0,0 +1,17 @@
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 32,
"width": 1280,
"head_width": 80,
"patch_size": 14
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24
}
}

View File

@@ -0,0 +1,181 @@
from collections import OrderedDict
import torch
from torch import nn
from torch.nn import functional as F
from .utils import freeze_batch_norm_2d
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1):
super().__init__()
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.act1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.act2 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.act3 = nn.ReLU(inplace=True)
self.downsample = None
self.stride = stride
if stride > 1 or inplanes != planes * Bottleneck.expansion:
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
self.downsample = nn.Sequential(OrderedDict([
("-1", nn.AvgPool2d(stride)),
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
("1", nn.BatchNorm2d(planes * self.expansion))
]))
def forward(self, x: torch.Tensor):
identity = x
out = self.act1(self.bn1(self.conv1(x)))
out = self.act2(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.act3(out)
return out
class AttentionPool2d(nn.Module):
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
super().__init__()
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x):
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
x, _ = F.multi_head_attention_forward(
query=x, key=x, value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0.,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False
)
return x[0]
class ModifiedResNet(nn.Module):
"""
A ResNet class that is similar to torchvision's but contains the following changes:
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
"""
def __init__(self, layers, output_dim, heads, image_size=224, width=64):
super().__init__()
self.output_dim = output_dim
self.image_size = image_size
# the 3-layer stem
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(width // 2)
self.act1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(width // 2)
self.act2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(width)
self.act3 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(2)
# residual layers
self._inplanes = width # this is a *mutable* variable used during construction
self.layer1 = self._make_layer(width, layers[0])
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
embed_dim = width * 32 # the ResNet feature dimension
self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
self.init_parameters()
def _make_layer(self, planes, blocks, stride=1):
layers = [Bottleneck(self._inplanes, planes, stride)]
self._inplanes = planes * Bottleneck.expansion
for _ in range(1, blocks):
layers.append(Bottleneck(self._inplanes, planes))
return nn.Sequential(*layers)
def init_parameters(self):
if self.attnpool is not None:
std = self.attnpool.c_proj.in_features ** -0.5
nn.init.normal_(self.attnpool.q_proj.weight, std=std)
nn.init.normal_(self.attnpool.k_proj.weight, std=std)
nn.init.normal_(self.attnpool.v_proj.weight, std=std)
nn.init.normal_(self.attnpool.c_proj.weight, std=std)
for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
for name, param in resnet_block.named_parameters():
if name.endswith("bn3.weight"):
nn.init.zeros_(param)
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
assert unlocked_groups == 0, 'partial locking not currently supported for this model'
for param in self.parameters():
param.requires_grad = False
if freeze_bn_stats:
freeze_batch_norm_2d(self)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
# FIXME support for non-transformer
pass
def stem(self, x):
x = self.act1(self.bn1(self.conv1(x)))
x = self.act2(self.bn2(self.conv2(x)))
x = self.act3(self.bn3(self.conv3(x)))
x = self.avgpool(x)
return x
def forward(self, x):
x = self.stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.attnpool(x)
return x

View File

@@ -0,0 +1,144 @@
""" OpenAI pretrained model functions
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
import os
import warnings
from typing import List, Optional, Union
import torch
from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
__all__ = ["list_openai_models", "load_openai_model"]
def list_openai_models() -> List[str]:
"""Returns the names of available CLIP models"""
return list_pretrained_models_by_tag('openai')
def load_openai_model(
name: str,
precision: Optional[str] = None,
device: Optional[Union[str, torch.device]] = None,
jit: bool = True,
cache_dir: Optional[str] = None,
):
"""Load a CLIP model
Parameters
----------
name : str
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
precision: str
Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
device : Union[str, torch.device]
The device to put the loaded model
jit : bool
Whether to load the optimized JIT model (default) or more hackable non-JIT model.
cache_dir : Optional[str]
The directory to cache the downloaded model weights
Returns
-------
model : torch.nn.Module
The CLIP model
preprocess : Callable[[PIL.Image], torch.Tensor]
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
"""
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
if precision is None:
precision = 'fp32' if device == 'cpu' else 'fp16'
if get_pretrained_url(name, 'openai'):
model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
elif os.path.isfile(name):
model_path = name
else:
raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
try:
# loading JIT archive
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
state_dict = None
except RuntimeError:
# loading saved state dict
if jit:
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
jit = False
state_dict = torch.load(model_path, map_location="cpu")
if not jit:
# Build a non-jit model from the OpenAI jitted model state dict
cast_dtype = get_cast_dtype(precision)
try:
model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
except KeyError:
sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
# model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
model = model.to(device)
if precision.startswith('amp') or precision == 'fp32':
model.float()
elif precision == 'bf16':
convert_weights_to_lp(model, dtype=torch.bfloat16)
return model
# patch the device names
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
def patch_device(module):
try:
graphs = [module.graph] if hasattr(module, "graph") else []
except RuntimeError:
graphs = []
if hasattr(module, "forward1"):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes("prim::Constant"):
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
node.copyAttributes(device_node)
model.apply(patch_device)
patch_device(model.encode_image)
patch_device(model.encode_text)
# patch dtype to float32 (typically for CPU)
if precision == 'fp32':
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
float_node = float_input.node()
def patch_float(module):
try:
graphs = [module.graph] if hasattr(module, "graph") else []
except RuntimeError:
graphs = []
if hasattr(module, "forward1"):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes("aten::to"):
inputs = list(node.inputs())
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
if inputs[i].node()["value"] == 5:
inputs[i].node().copyAttributes(float_node)
model.apply(patch_float)
patch_float(model.encode_image)
patch_float(model.encode_text)
model.float()
# ensure image_size attr available at consistent location for both jit and non-jit
model.visual.image_size = model.input_resolution.item()
return model

View File

@@ -0,0 +1,376 @@
import hashlib
import os
import urllib
import warnings
from functools import partial
from typing import Dict, Union
from tqdm import tqdm
from .version import __version__
try:
from huggingface_hub import hf_hub_download
hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__)
_has_hf_hub = True
except ImportError:
hf_hub_download = None
_has_hf_hub = False
def _pcfg(url='', hf_hub='', mean=None, std=None):
return dict(
url=url,
hf_hub=hf_hub,
mean=mean,
std=std,
)
_RN50 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
yfcc15m=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
cc12m=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
)
_RN50_quickgelu = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
yfcc15m=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
cc12m=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
)
_RN101 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
yfcc15m=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
)
_RN101_quickgelu = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
yfcc15m=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
)
_RN50x4 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"),
)
_RN50x16 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"),
)
_RN50x64 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"),
)
_VITB32 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
laion2b_e16=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/')
)
_VITB32_quickgelu = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
)
_VITB16 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
# laion400m_32k=_pcfg(
# url="",
# mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
# laion400m_64k=_pcfg(
# url="",
# mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),
)
_VITB16_PLUS_240 = dict(
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"),
)
_VITL14 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
laion2b_s32b_b82k=_pcfg(
hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
)
_VITL14_336 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
)
_VITH14 = dict(
laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
)
_VITg14 = dict(
laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),
laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),
)
_VITbigG14 = dict(
laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),
)
_robertaViTB32 = dict(
laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'),
)
_xlmRobertaBaseViTB32 = dict(
laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'),
)
_xlmRobertaLargeFrozenViTH14 = dict(
frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'),
)
_convnext_base = dict(
laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'),
)
_convnext_base_w = dict(
laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'),
laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'),
laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'),
)
_convnext_base_w_320 = dict(
laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'),
laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'),
)
_convnext_large_d = dict(
laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'),
)
_convnext_large_d_320 = dict(
laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'),
laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'),
)
_convnext_xxlarge = dict(
laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'),
laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'),
laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'),
)
_coca_VITB32 = dict(
laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'),
mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/')
)
_coca_VITL14 = dict(
laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'),
mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/')
)
_PRETRAINED = {
"RN50": _RN50,
"RN50-quickgelu": _RN50_quickgelu,
"RN101": _RN101,
"RN101-quickgelu": _RN101_quickgelu,
"RN50x4": _RN50x4,
"RN50x16": _RN50x16,
"RN50x64": _RN50x64,
"ViT-B-32": _VITB32,
"ViT-B-32-quickgelu": _VITB32_quickgelu,
"ViT-B-16": _VITB16,
"ViT-B-16-plus-240": _VITB16_PLUS_240,
"ViT-L-14": _VITL14,
"ViT-L-14-336": _VITL14_336,
"ViT-H-14": _VITH14,
"ViT-g-14": _VITg14,
"ViT-bigG-14": _VITbigG14,
"roberta-ViT-B-32": _robertaViTB32,
"xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32,
"xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14,
"convnext_base": _convnext_base,
"convnext_base_w": _convnext_base_w,
"convnext_base_w_320": _convnext_base_w_320,
"convnext_large_d": _convnext_large_d,
"convnext_large_d_320": _convnext_large_d_320,
"convnext_xxlarge": _convnext_xxlarge,
"coca_ViT-B-32": _coca_VITB32,
"coca_ViT-L-14": _coca_VITL14,
}
def _clean_tag(tag: str):
# normalize pretrained tags
return tag.lower().replace('-', '_')
def list_pretrained(as_str: bool = False):
""" returns list of pretrained models
Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
"""
return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]
def list_pretrained_models_by_tag(tag: str):
""" return all models having the specified pretrain tag """
models = []
tag = _clean_tag(tag)
for k in _PRETRAINED.keys():
if tag in _PRETRAINED[k]:
models.append(k)
return models
def list_pretrained_tags_by_model(model: str):
""" return all pretrain tags for the specified model architecture """
tags = []
if model in _PRETRAINED:
tags.extend(_PRETRAINED[model].keys())
return tags
def is_pretrained_cfg(model: str, tag: str):
if model not in _PRETRAINED:
return False
return _clean_tag(tag) in _PRETRAINED[model]
def get_pretrained_cfg(model: str, tag: str):
if model not in _PRETRAINED:
return {}
model_pretrained = _PRETRAINED[model]
return model_pretrained.get(_clean_tag(tag), {})
def get_pretrained_url(model: str, tag: str):
cfg = get_pretrained_cfg(model, _clean_tag(tag))
return cfg.get('url', '')
def download_pretrained_from_url(
url: str,
cache_dir: Union[str, None] = None,
):
if not cache_dir:
cache_dir = os.path.expanduser("~/.cache/clip")
os.makedirs(cache_dir, exist_ok=True)
filename = os.path.basename(url)
if 'openaipublic' in url:
expected_sha256 = url.split("/")[-2]
elif 'mlfoundations' in url:
expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
else:
expected_sha256 = ''
download_target = os.path.join(cache_dir, filename)
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
if expected_sha256:
if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
return download_target
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
else:
return download_target
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
return download_target
def has_hf_hub(necessary=False):
if not _has_hf_hub and necessary:
# if no HF Hub module installed, and it is necessary to continue, raise error
raise RuntimeError(
'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
return _has_hf_hub
def download_pretrained_from_hf(
model_id: str,
filename: str = 'open_clip_pytorch_model.bin',
revision=None,
cache_dir: Union[str, None] = None,
):
has_hf_hub(True)
cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
return cached_file
def download_pretrained(
cfg: Dict,
force_hf_hub: bool = False,
cache_dir: Union[str, None] = None,
):
target = ''
if not cfg:
return target
download_url = cfg.get('url', '')
download_hf_hub = cfg.get('hf_hub', '')
if download_hf_hub and force_hf_hub:
# use HF hub even if url exists
download_url = ''
if download_url:
target = download_pretrained_from_url(download_url, cache_dir=cache_dir)
elif download_hf_hub:
has_hf_hub(True)
# we assume the hf_hub entries in pretrained config combine model_id + filename in
# 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
# use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
model_id, filename = os.path.split(download_hf_hub)
if filename:
target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
else:
target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
return target

View File

@@ -0,0 +1,243 @@
import argparse
import json
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional, Tuple
import torch
try:
from huggingface_hub import (
create_repo,
get_hf_file_metadata,
hf_hub_download,
hf_hub_url,
repo_type_and_id_from_hf_id,
upload_folder,
)
from huggingface_hub.utils import EntryNotFoundError
_has_hf_hub = True
except ImportError:
_has_hf_hub = False
from .factory import create_model_from_pretrained, get_model_config, get_tokenizer
from .tokenizer import HFTokenizer
def save_config_for_hf(
model,
config_path: str,
model_config: Optional[dict]
):
preprocess_cfg = {
'mean': model.visual.image_mean,
'std': model.visual.image_std,
}
hf_config = {
'model_cfg': model_config,
'preprocess_cfg': preprocess_cfg,
}
with config_path.open('w') as f:
json.dump(hf_config, f, indent=2)
def save_for_hf(
model,
tokenizer: HFTokenizer,
model_config: dict,
save_directory: str,
weights_filename='open_clip_pytorch_model.bin',
config_filename='open_clip_config.json',
):
save_directory = Path(save_directory)
save_directory.mkdir(exist_ok=True, parents=True)
weights_path = save_directory / weights_filename
torch.save(model.state_dict(), weights_path)
tokenizer.save_pretrained(save_directory)
config_path = save_directory / config_filename
save_config_for_hf(model, config_path, model_config=model_config)
def push_to_hf_hub(
model,
tokenizer,
model_config: Optional[dict],
repo_id: str,
commit_message: str = 'Add model',
token: Optional[str] = None,
revision: Optional[str] = None,
private: bool = False,
create_pr: bool = False,
model_card: Optional[dict] = None,
):
if not isinstance(tokenizer, HFTokenizer):
# default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14
tokenizer = HFTokenizer('openai/clip-vit-large-patch14')
# Create repo if it doesn't exist yet
repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
# Infer complete repo_id from repo_url
# Can be different from the input `repo_id` if repo_owner was implicit
_, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
repo_id = f"{repo_owner}/{repo_name}"
# Check if README file already exist in repo
try:
get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
has_readme = True
except EntryNotFoundError:
has_readme = False
# Dump model and push to Hub
with TemporaryDirectory() as tmpdir:
# Save model weights and config.
save_for_hf(
model,
tokenizer=tokenizer,
model_config=model_config,
save_directory=tmpdir,
)
# Add readme if it does not exist
if not has_readme:
model_card = model_card or {}
model_name = repo_id.split('/')[-1]
readme_path = Path(tmpdir) / "README.md"
readme_text = generate_readme(model_card, model_name)
readme_path.write_text(readme_text)
# Upload model and return
return upload_folder(
repo_id=repo_id,
folder_path=tmpdir,
revision=revision,
create_pr=create_pr,
commit_message=commit_message,
)
def push_pretrained_to_hf_hub(
model_name,
pretrained: str,
repo_id: str,
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
commit_message: str = 'Add model',
token: Optional[str] = None,
revision: Optional[str] = None,
private: bool = False,
create_pr: bool = False,
model_card: Optional[dict] = None,
):
model, preprocess_eval = create_model_from_pretrained(
model_name,
pretrained=pretrained,
image_mean=image_mean,
image_std=image_std,
)
model_config = get_model_config(model_name)
assert model_config
tokenizer = get_tokenizer(model_name)
push_to_hf_hub(
model=model,
tokenizer=tokenizer,
model_config=model_config,
repo_id=repo_id,
commit_message=commit_message,
token=token,
revision=revision,
private=private,
create_pr=create_pr,
model_card=model_card,
)
def generate_readme(model_card: dict, model_name: str):
readme_text = "---\n"
readme_text += "tags:\n- zero-shot-image-classification\n- clip\n"
readme_text += "library_tag: open_clip\n"
readme_text += f"license: {model_card.get('license', 'mit')}\n"
if 'details' in model_card and 'Dataset' in model_card['details']:
readme_text += 'datasets:\n'
readme_text += f"- {model_card['details']['Dataset'].lower()}\n"
readme_text += "---\n"
readme_text += f"# Model card for {model_name}\n"
if 'description' in model_card:
readme_text += f"\n{model_card['description']}\n"
if 'details' in model_card:
readme_text += f"\n## Model Details\n"
for k, v in model_card['details'].items():
if isinstance(v, (list, tuple)):
readme_text += f"- **{k}:**\n"
for vi in v:
readme_text += f" - {vi}\n"
elif isinstance(v, dict):
readme_text += f"- **{k}:**\n"
for ki, vi in v.items():
readme_text += f" - {ki}: {vi}\n"
else:
readme_text += f"- **{k}:** {v}\n"
if 'usage' in model_card:
readme_text += f"\n## Model Usage\n"
readme_text += model_card['usage']
readme_text += '\n'
if 'comparison' in model_card:
readme_text += f"\n## Model Comparison\n"
readme_text += model_card['comparison']
readme_text += '\n'
if 'citation' in model_card:
readme_text += f"\n## Citation\n"
if not isinstance(model_card['citation'], (list, tuple)):
citations = [model_card['citation']]
else:
citations = model_card['citation']
for c in citations:
readme_text += f"```bibtex\n{c}\n```\n"
return readme_text
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Push to Hugging Face Hub")
parser.add_argument(
"--model", type=str, help="Name of the model to use.",
)
parser.add_argument(
"--pretrained", type=str,
help="Use a pretrained CLIP model weights with the specified tag or file path.",
)
parser.add_argument(
"--repo-id", type=str,
help="Destination HF Hub repo-id ie 'organization/model_id'.",
)
parser.add_argument(
'--image-mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override default image mean value of dataset')
parser.add_argument(
'--image-std', type=float, nargs='+', default=None, metavar='STD',
help='Override default image std deviation of of dataset')
args = parser.parse_args()
print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}')
# FIXME add support to pass model_card json / template from file via cmd line
push_pretrained_to_hf_hub(
args.model,
args.pretrained,
args.repo_id,
image_mean=args.image_mean, # override image mean/std if trained w/ non defaults
image_std=args.image_std,
)
print(f'{args.model} saved.')

View File

@@ -0,0 +1,127 @@
""" timm model adapter
Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
"""
import logging
from collections import OrderedDict
import torch
import torch.nn as nn
try:
import timm
from timm.models.layers import Mlp, to_2tuple
try:
# old timm imports < 0.8.1
from timm.models.layers.attention_pool2d import RotAttentionPool2d
from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
except ImportError:
# new timm imports >= 0.8.1
from timm.layers import RotAttentionPool2d
from timm.layers import AttentionPool2d as AbsAttentionPool2d
except ImportError:
timm = None
from .utils import freeze_batch_norm_2d
class TimmModel(nn.Module):
""" timm model adapter
# FIXME this adapter is a work in progress, may change in ways that break weight compat
"""
def __init__(
self,
model_name,
embed_dim,
image_size=224,
pool='avg',
proj='linear',
proj_bias=False,
drop=0.,
drop_path=None,
pretrained=False,
):
super().__init__()
if timm is None:
raise RuntimeError("Please `pip install timm` to use timm models.")
self.image_size = to_2tuple(image_size)
timm_kwargs = {}
if drop_path is not None:
timm_kwargs['drop_path_rate'] = drop_path
self.trunk = timm.create_model(model_name, pretrained=pretrained, **timm_kwargs)
feat_size = self.trunk.default_cfg.get('pool_size', None)
feature_ndim = 1 if not feat_size else 2
if pool in ('abs_attn', 'rot_attn'):
assert feature_ndim == 2
# if attn pooling used, remove both classifier and default pool
self.trunk.reset_classifier(0, global_pool='')
else:
# reset global pool if pool config set, otherwise leave as network default
reset_kwargs = dict(global_pool=pool) if pool else {}
self.trunk.reset_classifier(0, **reset_kwargs)
prev_chs = self.trunk.num_features
head_layers = OrderedDict()
if pool == 'abs_attn':
head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
prev_chs = embed_dim
elif pool == 'rot_attn':
head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
prev_chs = embed_dim
else:
assert proj, 'projection layer needed if non-attention pooling is used.'
# NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
if proj == 'linear':
head_layers['drop'] = nn.Dropout(drop)
head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
elif proj == 'mlp':
head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias))
self.head = nn.Sequential(head_layers)
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
""" lock modules
Args:
unlocked_groups (int): leave last n layer groups unlocked (default: 0)
"""
if not unlocked_groups:
# lock full model
for param in self.trunk.parameters():
param.requires_grad = False
if freeze_bn_stats:
freeze_batch_norm_2d(self.trunk)
else:
# NOTE: partial freeze requires latest timm (master) branch and is subject to change
try:
# FIXME import here until API stable and in an official release
from timm.models.helpers import group_parameters, group_modules
except ImportError:
raise RuntimeError(
'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
matcher = self.trunk.group_matcher()
gparams = group_parameters(self.trunk, matcher)
max_layer_id = max(gparams.keys())
max_layer_id = max_layer_id - unlocked_groups
for group_idx in range(max_layer_id + 1):
group = gparams[group_idx]
for param in group:
self.trunk.get_parameter(param).requires_grad = False
if freeze_bn_stats:
gmodules = group_modules(self.trunk, matcher, reverse=True)
gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
freeze_batch_norm_2d(self.trunk, gmodules)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
try:
self.trunk.set_grad_checkpointing(enable)
except Exception as e:
logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')
def forward(self, x):
x = self.trunk(x)
x = self.head(x)
return x

View File

@@ -0,0 +1,211 @@
""" CLIP tokenizer
Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
import gzip
import html
import os
from functools import lru_cache
from typing import Union, List
import ftfy
import regex as re
import torch
# https://stackoverflow.com/q/62691279
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@lru_cache()
def default_bpe():
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, '../../../../'))
quality_metric_path = os.path.join(project_root, 'models', 'QualityMetric')
return os.path.join(quality_metric_path, "bpe_simple_vocab_16e6.txt.gz")
@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8+n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r'\s+', ' ', text)
text = text.strip()
return text
class SimpleTokenizer(object):
def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
merges = merges[1:49152-256-2+1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v+'</w>' for v in vocab]
for merge in merges:
vocab.append(''.join(merge))
if not special_tokens:
special_tokens = ['<start_of_text>', '<end_of_text>']
else:
special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
vocab.extend(special_tokens)
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {t:t for t in special_tokens}
special = "|".join(special_tokens)
self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
self.vocab_size = len(self.encoder)
self.all_special_ids = [self.encoder[t] for t in special_tokens]
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
pairs = get_pairs(word)
if not pairs:
return token+'</w>'
while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
text = whitespace_clean(basic_clean(text)).lower()
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
return text
def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
"""
Returns the tokenized representation of given input string(s)
Parameters
----------
texts : Union[str, List[str]]
An input string or a list of input strings to tokenize
context_length : int
The context length to use; all CLIP models use 77 as the context length
Returns
-------
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
"""
if isinstance(texts, str):
texts = [texts]
sot_token = self.encoder["<start_of_text>"]
eot_token = self.encoder["<end_of_text>"]
all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
for i, tokens in enumerate(all_tokens):
if len(tokens) > context_length:
tokens = tokens[:context_length] # Truncate
tokens[-1] = eot_token
result[i, :len(tokens)] = torch.tensor(tokens)
return result
class HFTokenizer:
"""HuggingFace tokenizer wrapper"""
def __init__(self, tokenizer_name: str):
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
def save_pretrained(self, dest):
self.tokenizer.save_pretrained(dest)
def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor:
# same cleaning as for default tokenizer, except lowercasing
# adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
if isinstance(texts, str):
texts = [texts]
texts = [whitespace_clean(basic_clean(text)) for text in texts]
input_ids = self.tokenizer(
texts,
return_tensors='pt',
max_length=context_length,
padding='max_length',
truncation=True,
).input_ids
return input_ids

View File

@@ -0,0 +1,216 @@
import warnings
from dataclasses import dataclass, asdict
from typing import Any, Dict, Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
import torchvision.transforms.functional as F
from functools import partial
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
CenterCrop
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
@dataclass
class AugmentationCfg:
scale: Tuple[float, float] = (0.9, 1.0)
ratio: Optional[Tuple[float, float]] = None
color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None
interpolation: Optional[str] = None
re_prob: Optional[float] = None
re_count: Optional[int] = None
use_timm: bool = False
class ResizeMaxSize(nn.Module):
def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
super().__init__()
if not isinstance(max_size, int):
raise TypeError(f"Size should be int. Got {type(max_size)}")
self.max_size = max_size
self.interpolation = interpolation
self.fn = min if fn == 'min' else min
self.fill = fill
def forward(self, img):
if isinstance(img, torch.Tensor):
height, width = img.shape[1:]
else:
width, height = img.size
scale = self.max_size / float(max(height, width))
if scale != 1.0:
new_size = tuple(round(dim * scale) for dim in (height, width))
img = F.resize(img, new_size, self.interpolation)
pad_h = self.max_size - new_size[0]
pad_w = self.max_size - new_size[1]
img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
return img
def _convert_to_rgb_or_rgba(image):
if image.mode == 'RGBA':
return image
else:
return image.convert('RGB')
# def transform_and_split(merged, transform_fn, normalize_fn):
# transformed = transform_fn(merged)
# crop_img, crop_label = torch.split(transformed, [3,1], dim=0)
# # crop_img = _convert_to_rgb(crop_img)
# crop_img = normalize_fn(ToTensor()(crop_img))
# return crop_img, crop_label
class MaskAwareNormalize(nn.Module):
def __init__(self, mean, std):
super().__init__()
self.normalize = Normalize(mean=mean, std=std)
def forward(self, tensor):
if tensor.shape[0] == 4:
return torch.cat([self.normalize(tensor[:3]), tensor[3:]], dim=0)
else:
return self.normalize(tensor)
def image_transform(
image_size: int,
is_train: bool,
mean: Optional[Tuple[float, ...]] = None,
std: Optional[Tuple[float, ...]] = None,
resize_longest_max: bool = False,
fill_color: int = 0,
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
):
mean = mean or OPENAI_DATASET_MEAN
if not isinstance(mean, (list, tuple)):
mean = (mean,) * 3
std = std or OPENAI_DATASET_STD
if not isinstance(std, (list, tuple)):
std = (std,) * 3
if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
# for square size, pass size as int so that Resize() uses aspect preserving shortest edge
image_size = image_size[0]
if isinstance(aug_cfg, dict):
aug_cfg = AugmentationCfg(**aug_cfg)
else:
aug_cfg = aug_cfg or AugmentationCfg()
normalize = MaskAwareNormalize(mean=mean, std=std)
if is_train:
aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
use_timm = aug_cfg_dict.pop('use_timm', False)
if use_timm:
assert False, "not tested for augmentation with mask"
from timm.data import create_transform # timm can still be optional
if isinstance(image_size, (tuple, list)):
assert len(image_size) >= 2
input_size = (3,) + image_size[-2:]
else:
input_size = (3, image_size, image_size)
# by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time
aug_cfg_dict.setdefault('interpolation', 'random')
aug_cfg_dict.setdefault('color_jitter', None) # disable by default
train_transform = create_transform(
input_size=input_size,
is_training=True,
hflip=0.,
mean=mean,
std=std,
re_mode='pixel',
**aug_cfg_dict,
)
else:
train_transform = Compose([
_convert_to_rgb_or_rgba,
ToTensor(),
RandomResizedCrop(
image_size,
scale=aug_cfg_dict.pop('scale'),
interpolation=InterpolationMode.BICUBIC,
),
normalize,
])
if aug_cfg_dict:
warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).')
return train_transform
else:
transforms = [
_convert_to_rgb_or_rgba,
ToTensor(),
]
if resize_longest_max:
transforms.extend([
ResizeMaxSize(image_size, fill=fill_color)
])
else:
transforms.extend([
Resize(image_size, interpolation=InterpolationMode.BICUBIC),
CenterCrop(image_size),
])
transforms.extend([
normalize,
])
return Compose(transforms)
# def image_transform_region(
# image_size: int,
# is_train: bool,
# mean: Optional[Tuple[float, ...]] = None,
# std: Optional[Tuple[float, ...]] = None,
# resize_longest_max: bool = False,
# fill_color: int = 0,
# aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
# ):
# mean = mean or OPENAI_DATASET_MEAN
# if not isinstance(mean, (list, tuple)):
# mean = (mean,) * 3
# std = std or OPENAI_DATASET_STD
# if not isinstance(std, (list, tuple)):
# std = (std,) * 3
# if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
# # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
# image_size = image_size[0]
# if isinstance(aug_cfg, dict):
# aug_cfg = AugmentationCfg(**aug_cfg)
# else:
# aug_cfg = aug_cfg or AugmentationCfg()
# normalize = Normalize(mean=mean, std=std)
# if is_train:
# aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
# transform = Compose([
# RandomResizedCrop(
# image_size,
# scale=aug_cfg_dict.pop('scale'),
# interpolation=InterpolationMode.BICUBIC,
# ),
# ])
# train_transform = Compose([
# partial(transform_and_split, transform_fn=transform,normalize_fn=normalize)
# ])
# return train_transform
# else:
# if resize_longest_max:
# transform = [
# ResizeMaxSize(image_size, fill=fill_color)
# ]
# val_transform = Compose([
# partial(transform_and_split, transform_fn=transform,normalize_fn=normalize),
# ])
# else:
# transform = [
# Resize(image_size, interpolation=InterpolationMode.BICUBIC),
# CenterCrop(image_size),
# ]
# val_transform = Compose([
# partial(transform_and_split, transform_fn=transform,normalize_fn=normalize),
# ])
# return val_transform

View File

@@ -0,0 +1,727 @@
from collections import OrderedDict
import math
from typing import Callable, Optional, Sequence, Tuple
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint
from .utils import to_2tuple
class LayerNormFp32(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps)
return x.to(orig_type)
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm (with cast back to input dtype)."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
return x.to(orig_type)
class QuickGELU(nn.Module):
# NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class LayerScale(nn.Module):
def __init__(self, dim, init_values=1e-5, inplace=False):
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x):
return x.mul_(self.gamma) if self.inplace else x * self.gamma
class PatchDropout(nn.Module):
"""
https://arxiv.org/abs/2212.00794
"""
def __init__(self, prob, exclude_first_token=True):
super().__init__()
assert 0 <= prob < 1.
self.prob = prob
self.exclude_first_token = exclude_first_token # exclude CLS token
def forward(self, x):
if not self.training or self.prob == 0.:
return x
if self.exclude_first_token:
cls_tokens, x = x[:, :1], x[:, 1:]
else:
cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
batch = x.size()[0]
num_tokens = x.size()[1]
batch_indices = torch.arange(batch)
batch_indices = batch_indices[..., None]
keep_prob = 1 - self.prob
num_patches_keep = max(1, int(num_tokens * keep_prob))
rand = torch.randn(batch, num_tokens)
patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
x = x[batch_indices, patch_indices_keep]
if self.exclude_first_token:
x = torch.cat((cls_tokens, x), dim=1)
return x
class Attention(nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=True,
scaled_cosine=False,
scale_heads=False,
logit_scale_max=math.log(1. / 0.01),
attn_drop=0.,
proj_drop=0.
):
super().__init__()
self.scaled_cosine = scaled_cosine
self.scale_heads = scale_heads
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.logit_scale_max = logit_scale_max
# keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
if qkv_bias:
self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
else:
self.in_proj_bias = None
if self.scaled_cosine:
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
else:
self.logit_scale = None
self.attn_drop = nn.Dropout(attn_drop)
if self.scale_heads:
self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
else:
self.head_scale = None
self.out_proj = nn.Linear(dim, dim)
self.out_drop = nn.Dropout(proj_drop)
def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
L, N, C = x.shape
q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
if self.logit_scale is not None:
attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
attn = attn.view(N, self.num_heads, L, L) * logit_scale
attn = attn.view(-1, L, L)
else:
q = q * self.scale
attn = torch.bmm(q, k.transpose(-1, -2))
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
new_attn_mask.masked_fill_(attn_mask, float("-inf"))
attn_mask = new_attn_mask
attn += attn_mask
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = torch.bmm(attn, v)
if self.head_scale is not None:
x = x.view(N, self.num_heads, L, C) * self.head_scale
x = x.view(-1, L, C)
x = x.transpose(0, 1).reshape(L, N, C)
x = self.out_proj(x)
x = self.out_drop(x)
return x
class AttentionalPooler(nn.Module):
def __init__(
self,
d_model: int,
context_dim: int,
n_head: int = 8,
n_queries: int = 256,
norm_layer: Callable = LayerNorm
):
super().__init__()
self.query = nn.Parameter(torch.randn(n_queries, d_model))
self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim)
self.ln_q = norm_layer(d_model)
self.ln_k = norm_layer(context_dim)
def forward(self, x: torch.Tensor):
x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND
N = x.shape[1]
q = self.ln_q(self.query)
out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0]
return out.permute(1, 0, 2) # LND -> NLD
def _repeat(self, query, N: int):
return query.unsqueeze(1).repeat(1, N, 1)
class ResidualAttentionBlock(nn.Module):
def __init__(
self,
d_model: int,
n_head: int,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
is_cross_attention: bool = False,
):
super().__init__()
self.ln_1 = norm_layer(d_model)
self.attn = nn.MultiheadAttention(d_model, n_head)
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
if is_cross_attention:
self.ln_1_kv = norm_layer(d_model)
self.ln_2 = norm_layer(d_model)
mlp_width = int(d_model * mlp_ratio)
self.mlp = nn.Sequential(OrderedDict([
("c_fc", nn.Linear(d_model, mlp_width)),
("gelu", act_layer()),
("c_proj", nn.Linear(mlp_width, d_model))
]))
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
def attention(
self,
q_x: torch.Tensor,
k_x: Optional[torch.Tensor] = None,
v_x: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
):
k_x = k_x if k_x is not None else q_x
v_x = v_x if v_x is not None else q_x
attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
return self.attn(
q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask
)[0]
def forward(
self,
q_x: torch.Tensor,
k_x: Optional[torch.Tensor] = None,
v_x: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
):
k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask))
x = x + self.ls_2(self.mlp(self.ln_2(x)))
return x
class CustomResidualAttentionBlock(nn.Module):
def __init__(
self,
d_model: int,
n_head: int,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
scale_cosine_attn: bool = False,
scale_heads: bool = False,
scale_attn: bool = False,
scale_fc: bool = False,
):
super().__init__()
self.ln_1 = norm_layer(d_model)
self.attn = Attention(
d_model, n_head,
scaled_cosine=scale_cosine_attn,
scale_heads=scale_heads,
)
self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
self.ln_2 = norm_layer(d_model)
mlp_width = int(d_model * mlp_ratio)
self.mlp = nn.Sequential(OrderedDict([
("c_fc", nn.Linear(d_model, mlp_width)),
('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),
("gelu", act_layer()),
("c_proj", nn.Linear(mlp_width, d_model))
]))
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask)))
x = x + self.ls_2(self.mlp(self.ln_2(x)))
return x
class Transformer(nn.Module):
def __init__(
self,
width: int,
layers: int,
heads: int,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
):
super().__init__()
self.width = width
self.layers = layers
self.grad_checkpointing = False
self.resblocks = nn.ModuleList([
ResidualAttentionBlock(
width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer)
for _ in range(layers)
])
def get_cast_dtype(self) -> torch.dtype:
return self.resblocks[0].mlp.c_fc.weight.dtype
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
for r in self.resblocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
# TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
x = checkpoint(r, x, None, None, attn_mask)
else:
x = r(x, attn_mask=attn_mask)
return x
class VisionTransformer(nn.Module):
output_tokens: torch.jit.Final[bool]
def __init__(
self,
image_size: int,
patch_size: int,
width: int,
layers: int,
heads: int,
mlp_ratio: float,
ls_init_value: float = None,
global_average_pool: bool = False,
attentional_pool: bool = False,
n_queries: int = 256,
attn_pooler_heads: int = 8,
output_dim: int = 512,
patch_dropout: float = 0.,
input_patchnorm: bool = False,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
output_tokens: bool = False
):
super().__init__()
self.output_tokens = output_tokens
image_height, image_width = self.image_size = to_2tuple(image_size)
patch_height, patch_width = self.patch_size = to_2tuple(patch_size)
self.grid_size = (image_height // patch_height, image_width // patch_width)
self.output_dim = output_dim
# whether to layernorm each patch, as done in dual patchnorm paper - https://arxiv.org/abs/2302.01327v1
self.input_patchnorm = input_patchnorm
if input_patchnorm:
patch_input_dim = patch_height * patch_width * 3
self.patchnorm_pre_ln = LayerNorm(patch_input_dim)
self.conv1 = nn.Linear(patch_input_dim, width)
else:
self.patchnorm_pre_ln = nn.Identity()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
# class embeddings and positional embeddings
scale = width ** -0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
# setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
self.ln_pre = norm_layer(width)
self.transformer = Transformer(
width,
layers,
heads,
mlp_ratio,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
)
self.global_average_pool = global_average_pool
if attentional_pool:
self.attn_pool = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries)
self.ln_post = norm_layer(output_dim)
self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim))
else:
self.attn_pool = None
self.ln_post = norm_layer(width)
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
self.init_parameters()
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
for param in self.parameters():
param.requires_grad = False
if unlocked_groups != 0:
groups = [
[
self.conv1,
self.class_embedding,
self.positional_embedding,
self.ln_pre,
],
*self.transformer.resblocks[:-1],
[
self.transformer.resblocks[-1],
self.ln_post,
],
self.proj,
]
def _unlock(x):
if isinstance(x, Sequence):
for g in x:
_unlock(g)
else:
if isinstance(x, torch.nn.Parameter):
x.requires_grad = True
else:
for p in x.parameters():
p.requires_grad = True
_unlock(groups[-unlocked_groups:])
def init_parameters(self):
# FIXME OpenAI CLIP did not define an init for the VisualTransformer
# TODO experiment if default PyTorch init, below, or alternate init is best.
# nn.init.normal_(self.class_embedding, std=self.scale)
# nn.init.normal_(self.positional_embedding, std=self.scale)
#
# proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
# attn_std = self.transformer.width ** -0.5
# fc_std = (2 * self.transformer.width) ** -0.5
# for block in self.transformer.resblocks:
# nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
# nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
# nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
# nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
#
# if self.text_projection is not None:
# nn.init.normal_(self.text_projection, std=self.scale)
pass
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.transformer.grad_checkpointing = enable
def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if self.global_average_pool:
return x.mean(dim=1), x
else:
return x[:, 0], x[:, 1:]
def forward(self, x: torch.Tensor, skip_pool: bool = False):
# to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
if self.input_patchnorm:
# einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
x = x.reshape(x.shape[0], x.shape[1], self.grid_size[0], self.patch_size[0], self.grid_size[1], self.patch_size[1])
x = x.permute(0, 2, 4, 1, 3, 5)
x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1)
x = self.patchnorm_pre_ln(x)
x = self.conv1(x)
else:
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
# class embeddings and positional embeddings
x = torch.cat(
[self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
x], dim=1) # shape = [*, grid ** 2 + 1, width]
x = x + self.positional_embedding.to(x.dtype)
# a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
x = self.patch_dropout(x)
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
if skip_pool:
return x
if self.attn_pool is not None:
x = self.attn_pool(x)
x = self.ln_post(x)
pooled, tokens = self._global_pool(x)
else:
pooled, tokens = self._global_pool(x)
pooled = self.ln_post(pooled)
if self.proj is not None:
pooled = pooled @ self.proj
if self.output_tokens:
return pooled, tokens
return pooled
class TextTransformer(nn.Module):
output_tokens: torch.jit.Final[bool]
def __init__(
self,
context_length: int = 77,
vocab_size: int = 49408,
width: int = 512,
heads: int = 8,
layers: int = 12,
ls_init_value: float = None,
output_dim: int = 512,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
embed_cls: bool = False,
pad_id: int = 0,
output_tokens: bool = False,
):
super().__init__()
self.output_tokens = output_tokens
self.num_pos = self.context_length = context_length
self.vocab_size = vocab_size
self.width = width
self.output_dim = output_dim
self.heads = heads
self.pad_id = pad_id
self.text_projection = nn.Parameter(torch.empty(width, output_dim))
if embed_cls:
self.cls_emb = nn.Parameter(torch.empty(width))
self.num_pos += 1
else:
self.cls_emb = None
self.token_embedding = nn.Embedding(vocab_size, width)
self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
self.transformer = Transformer(
width=width,
layers=layers,
heads=heads,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
)
self.ln_final = norm_layer(width)
self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
self.init_parameters()
def init_parameters(self):
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.positional_embedding, std=0.01)
if self.cls_emb is not None:
nn.init.normal_(self.cls_emb, std=0.01)
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
attn_std = self.transformer.width ** -0.5
fc_std = (2 * self.transformer.width) ** -0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
if self.text_projection is not None:
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.transformer.grad_checkpointing = enable
def build_attention_mask(self):
# lazily create causal attention mask, with full attention between the tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.num_pos, self.num_pos)
mask.fill_(float("-inf"))
mask.triu_(1) # zero out the lower diagonal
return mask
def build_cls_mask(self, text, cast_dtype: torch.dtype):
cls_mask = (text != self.pad_id).unsqueeze(1)
cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0)
additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
additive_mask.fill_(0)
additive_mask.masked_fill_(~cls_mask, float("-inf"))
additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
return additive_mask
def _repeat(self, t, N: int):
return t.reshape(1, 1, -1).repeat(N, 1, 1)
def forward(self, text):
cast_dtype = self.transformer.get_cast_dtype()
seq_len = text.shape[1]
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
attn_mask = self.attn_mask
if self.cls_emb is not None:
seq_len += 1
x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1)
cls_mask = self.build_cls_mask(text, cast_dtype)
attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len]
x = x + self.positional_embedding[:seq_len].to(cast_dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x, attn_mask=attn_mask)
x = x.permute(1, 0, 2) # LND -> NLD
# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
if self.cls_emb is not None:
pooled, tokens = x[:, -1], x[:, :-1]
pooled = self.ln_final(pooled)
else:
x = self.ln_final(x)
pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
if self.text_projection is not None:
pooled = pooled @ self.text_projection
if self.output_tokens:
return pooled, tokens
return pooled
class MultimodalTransformer(Transformer):
def __init__(
self,
width: int,
layers: int,
heads: int,
context_length: int = 77,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
output_dim: int = 512,
):
super().__init__(
width=width,
layers=layers,
heads=heads,
mlp_ratio=mlp_ratio,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
)
self.context_length = context_length
self.cross_attn = nn.ModuleList([
ResidualAttentionBlock(
width,
heads,
mlp_ratio,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
is_cross_attention=True,
)
for _ in range(layers)
])
self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
self.ln_final = norm_layer(width)
self.text_projection = nn.Parameter(torch.empty(width, output_dim))
def init_parameters(self):
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
attn_std = self.transformer.width ** -0.5
fc_std = (2 * self.transformer.width) ** -0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
for block in self.transformer.cross_attn:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
if self.text_projection is not None:
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
def build_attention_mask(self):
# lazily create causal attention mask, with full attention between the tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.context_length, self.context_length)
mask.fill_(float("-inf"))
mask.triu_(1) # zero out the lower diagonal
return mask
def forward(self, image_embs, text_embs):
text_embs = text_embs.permute(1, 0, 2) # NLD -> LNDsq
image_embs = image_embs.permute(1, 0, 2) # NLD -> LND
seq_len = text_embs.shape[0]
for resblock, cross_attn in zip(self.resblocks, self.cross_attn):
if self.grad_checkpointing and not torch.jit.is_scripting():
# TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len])
text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None)
else:
text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len])
text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs)
x = text_embs.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x)
if self.text_projection is not None:
x = x @ self.text_projection
return x
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable

View File

@@ -0,0 +1,60 @@
from itertools import repeat
import collections.abc
from torch import nn as nn
from torchvision.ops.misc import FrozenBatchNorm2d
def freeze_batch_norm_2d(module, module_match={}, name=''):
"""
Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
returned. Otherwise, the module is walked recursively and submodules are converted in place.
Args:
module (torch.nn.Module): Any PyTorch module.
module_match (dict): Dictionary of full module names to freeze (all if empty)
name (str): Full module name (prefix)
Returns:
torch.nn.Module: Resulting module
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
"""
res = module
is_match = True
if module_match:
is_match = name in module_match
if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
res = FrozenBatchNorm2d(module.num_features)
res.num_features = module.num_features
res.affine = module.affine
if module.affine:
res.weight.data = module.weight.data.clone().detach()
res.bias.data = module.bias.data.clone().detach()
res.running_mean.data = module.running_mean.data
res.running_var.data = module.running_var.data
res.eps = module.eps
else:
for child_name, child in module.named_children():
full_child_name = '.'.join([name, child_name]) if name else child_name
new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
if new_child is not child:
res.add_module(child_name, new_child)
return res
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = lambda n, x: _ntuple(n)(x)

View File

@@ -0,0 +1 @@
__version__ = '2.16.0'

View File

@@ -0,0 +1,112 @@
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModel
from typing import List, Union
import os
from .config import MODEL_PATHS
class PickScore(torch.nn.Module):
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS):
super().__init__()
"""Initialize the Selector with a processor and model.
Args:
device (Union[str, torch.device]): The device to load the model on.
"""
self.device = device if isinstance(device, torch.device) else torch.device(device)
processor_name_or_path = path.get("clip")
model_pretrained_name_or_path = path.get("pickscore")
self.processor = AutoProcessor.from_pretrained(processor_name_or_path)
self.model = AutoModel.from_pretrained(model_pretrained_name_or_path).eval().to(self.device)
def _calculate_score(self, image: torch.Tensor, prompt: str, softmax: bool = False) -> float:
"""Calculate the score for a single image and prompt.
Args:
image (torch.Tensor): The processed image tensor.
prompt (str): The prompt text.
softmax (bool): Whether to apply softmax to the scores.
Returns:
float: The score for the image.
"""
with torch.no_grad():
# Prepare text inputs
text_inputs = self.processor(
text=prompt,
padding=True,
truncation=True,
max_length=77,
return_tensors="pt",
).to(self.device)
# Embed images and text
image_embs = self.model.get_image_features(pixel_values=image)
image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
text_embs = self.model.get_text_features(**text_inputs)
text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)
# Compute score
score = (text_embs @ image_embs.T)[0]
if softmax:
# Apply logit scale and softmax
score = torch.softmax(self.model.logit_scale.exp() * score, dim=-1)
return score.cpu().item()
@torch.no_grad()
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str, softmax: bool = False) -> List[float]:
"""Score the images based on the prompt.
Args:
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
prompt (str): The prompt text.
softmax (bool): Whether to apply softmax to the scores.
Returns:
List[float]: List of scores for the images.
"""
try:
if isinstance(images, (str, Image.Image)):
# Single image
if isinstance(images, str):
pil_image = Image.open(images)
else:
pil_image = images
# Prepare image inputs
image_inputs = self.processor(
images=pil_image,
padding=True,
truncation=True,
max_length=77,
return_tensors="pt",
).to(self.device)
return [self._calculate_score(image_inputs["pixel_values"], prompt, softmax)]
elif isinstance(images, list):
# Multiple images
scores = []
for one_image in images:
if isinstance(one_image, str):
pil_image = Image.open(one_image)
elif isinstance(one_image, Image.Image):
pil_image = one_image
else:
raise TypeError("The type of parameter images is illegal.")
# Prepare image inputs
image_inputs = self.processor(
images=pil_image,
padding=True,
truncation=True,
max_length=77,
return_tensors="pt",
).to(self.device)
scores.append(self._calculate_score(image_inputs["pixel_values"], prompt, softmax))
return scores
else:
raise TypeError("The type of parameter images is illegal.")
except Exception as e:
raise RuntimeError(f"Error in scoring images: {e}")

View File

@@ -0,0 +1 @@
from .models import *

View File

@@ -0,0 +1,3 @@
from .base_model import *
from .clip_model import *
from .cross_modeling import *

View File

@@ -0,0 +1,7 @@
from dataclasses import dataclass
@dataclass
class BaseModelConfig:
pass

View File

@@ -0,0 +1,146 @@
from dataclasses import dataclass
from transformers import CLIPModel as HFCLIPModel
from transformers import AutoTokenizer
from torch import nn, einsum
from .base_model import BaseModelConfig
from transformers import CLIPConfig
from typing import Any, Optional, Tuple, Union
import torch
from .cross_modeling import Cross_model
import json, os
class XCLIPModel(HFCLIPModel):
def __init__(self, config: CLIPConfig):
super().__init__(config)
def get_text_features(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> torch.FloatTensor:
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
text_outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# pooled_output = text_outputs[1]
# text_features = self.text_projection(pooled_output)
last_hidden_state = text_outputs[0]
text_features = self.text_projection(last_hidden_state)
pooled_output = text_outputs[1]
text_features_EOS = self.text_projection(pooled_output)
# del last_hidden_state, text_outputs
# gc.collect()
return text_features, text_features_EOS
def get_image_features(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> torch.FloatTensor:
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# pooled_output = vision_outputs[1] # pooled_output
# image_features = self.visual_projection(pooled_output)
last_hidden_state = vision_outputs[0]
image_features = self.visual_projection(last_hidden_state)
return image_features
@dataclass
class ClipModelConfig(BaseModelConfig):
_target_: str = "diffsynth.extensions.QualityMetric.trainer.models.clip_model.CLIPModel"
pretrained_model_name_or_path: str ="checkpoints/clip-vit-base-patch32"
class CLIPModel(nn.Module):
def __init__(self, ckpt, config_file=False):
super().__init__()
if config_file is None:
self.model = XCLIPModel.from_pretrained(ckpt)
else:
with open(os.path.join(ckpt, "config.json"), "r", encoding="utf-8") as f:
config = json.load(f)
config = CLIPConfig(**config)
self.model = XCLIPModel._from_config(config)
self.cross_model = Cross_model(dim=1024, layer_num=4, heads=16)
def get_text_features(self, *args, **kwargs):
return self.model.get_text_features(*args, **kwargs)
def get_image_features(self, *args, **kwargs):
return self.model.get_image_features(*args, **kwargs)
def forward(self, text_inputs=None, image_inputs=None, condition_inputs=None):
outputs = ()
text_f, text_EOS = self.model.get_text_features(text_inputs) # B*77*1024
outputs += text_EOS,
image_f = self.model.get_image_features(image_inputs.half()) # 2B*257*1024
condition_f, _ = self.model.get_text_features(condition_inputs) # B*5*1024
sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f)
sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]
sim_text_condition = sim_text_condition / sim_text_condition.max()
mask = torch.where(sim_text_condition > 0.01, 0, float('-inf')) # B*1*77
mask = mask.repeat(1,image_f.shape[1],1) # B*257*77
bc = int(image_f.shape[0]/2)
sim0 = self.cross_model(image_f[:bc,:,:], text_f,mask.half())
sim1 = self.cross_model(image_f[bc:,:,:], text_f,mask.half())
outputs += sim0[:,0,:],
outputs += sim1[:,0,:],
return outputs
@property
def logit_scale(self):
return self.model.logit_scale
def save(self, path):
self.model.save_pretrained(path)

View File

@@ -0,0 +1,292 @@
import torch
from torch import einsum, nn
import torch.nn.functional as F
from einops import rearrange, repeat
# helper functions
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
# normalization
# they use layernorm without bias, something that pytorch does not offer
class LayerNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.weight = nn.Parameter(torch.ones(dim))
self.register_buffer("bias", torch.zeros(dim))
def forward(self, x):
return F.layer_norm(x, x.shape[-1:], self.weight, self.bias)
# residual
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, *args, **kwargs):
return self.fn(x, *args, **kwargs) + x
# rotary positional embedding
# https://arxiv.org/abs/2104.09864
class RotaryEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
def forward(self, max_seq_len, *, device):
seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
freqs = einsum("i , j -> i j", seq, self.inv_freq)
return torch.cat((freqs, freqs), dim=-1)
def rotate_half(x):
x = rearrange(x, "... (j d) -> ... j d", j=2)
x1, x2 = x.unbind(dim=-2)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(pos, t):
return (t * pos.cos()) + (rotate_half(t) * pos.sin())
# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward
# https://arxiv.org/abs/2002.05202
class SwiGLU(nn.Module):
def forward(self, x):
x, gate = x.chunk(2, dim=-1)
return F.silu(gate) * x
# parallel attention and feedforward with residual
# discovered by Wang et al + EleutherAI from GPT-J fame
class ParallelTransformerBlock(nn.Module):
def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
super().__init__()
self.norm = LayerNorm(dim)
attn_inner_dim = dim_head * heads
ff_inner_dim = dim * ff_mult
self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))
self.heads = heads
self.scale = dim_head**-0.5
self.rotary_emb = RotaryEmbedding(dim_head)
self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)
self.ff_out = nn.Sequential(
SwiGLU(),
nn.Linear(ff_inner_dim, dim, bias=False)
)
self.register_buffer("pos_emb", None, persistent=False)
def get_rotary_embedding(self, n, device):
if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
return self.pos_emb[:n]
pos_emb = self.rotary_emb(n, device=device)
self.register_buffer("pos_emb", pos_emb, persistent=False)
return pos_emb
def forward(self, x, attn_mask=None):
"""
einstein notation
b - batch
h - heads
n, i, j - sequence length (base sequence length, source, target)
d - feature dimension
"""
n, device, h = x.shape[1], x.device, self.heads
# pre layernorm
x = self.norm(x)
# attention queries, keys, values, and feedforward inner
q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)
# split heads
# they use multi-query single-key-value attention, yet another Noam Shazeer paper
# they found no performance loss past a certain scale, and more efficient decoding obviously
# https://arxiv.org/abs/1911.02150
q = rearrange(q, "b n (h d) -> b h n d", h=h)
# rotary embeddings
positions = self.get_rotary_embedding(n, device)
q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))
# scale
q = q * self.scale
# similarity
sim = einsum("b h i d, b j d -> b h i j", q, k)
# extra attention mask - for masking out attention from text CLS token to padding
if exists(attn_mask):
attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j')
sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)
# attention
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
attn = sim.softmax(dim=-1)
# aggregate values
out = einsum("b h i j, b j d -> b h i d", attn, v)
# merge heads
out = rearrange(out, "b h n d -> b n (h d)")
return self.attn_out(out) + self.ff_out(ff)
# cross attention - using multi-query + one-headed key / values as in PaLM w/ optional parallel feedforward
class CrossAttention(nn.Module):
def __init__(
self,
dim,
*,
context_dim=None,
dim_head=64,
heads=12,
parallel_ff=False,
ff_mult=4,
norm_context=False
):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
inner_dim = heads * dim_head
context_dim = default(context_dim, dim)
self.norm = LayerNorm(dim)
self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity()
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
# whether to have parallel feedforward
ff_inner_dim = ff_mult * dim
self.ff = nn.Sequential(
nn.Linear(dim, ff_inner_dim * 2, bias=False),
SwiGLU(),
nn.Linear(ff_inner_dim, dim, bias=False)
) if parallel_ff else None
def forward(self, x, context, mask):
"""
einstein notation
b - batch
h - heads
n, i, j - sequence length (base sequence length, source, target)
d - feature dimension
"""
# pre-layernorm, for queries and context
x = self.norm(x)
context = self.context_norm(context)
# get queries
q = self.to_q(x)
q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
# scale
q = q * self.scale
# get key / values
k, v = self.to_kv(context).chunk(2, dim=-1)
# query / key similarity
sim = einsum('b h i d, b j d -> b h i j', q, k)
# attention
mask = mask.unsqueeze(1).repeat(1,self.heads,1,1)
sim = sim + mask # context mask
sim = sim - sim.amax(dim=-1, keepdim=True)
attn = sim.softmax(dim=-1)
# aggregate
out = einsum('b h i j, b j d -> b h i d', attn, v)
# merge and combine heads
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out)
# add parallel feedforward (for multimodal layers)
if exists(self.ff):
out = out + self.ff(x)
return out
class Cross_model(nn.Module):
def __init__(
self,
dim=512,
layer_num=4,
dim_head=64,
heads=8,
ff_mult=4
):
super().__init__()
self.layers = nn.ModuleList([])
for ind in range(layer_num):
self.layers.append(nn.ModuleList([
Residual(CrossAttention(dim=dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult)),
Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult))
]))
def forward(
self,
query_tokens,
context_tokens,
mask
):
for cross_attn, self_attn_ff in self.layers:
query_tokens = cross_attn(query_tokens, context_tokens,mask)
query_tokens = self_attn_ff(query_tokens)
return query_tokens

View File

@@ -73,7 +73,6 @@ try:
)
except Exception as exception:
kernels = None
logger.warning("Failed to load cpm_kernels:" + str(exception))
class W8A16Linear(torch.autograd.Function):
@@ -981,7 +980,7 @@ class Embedding(torch.nn.Module):
# Embeddings.
words_embeddings = self.word_embeddings(input_ids)
embeddings = words_embeddings
# Data format change to avoid explicit tranposes : [b s h] --> [s b h].
# Data format change to avoid explicit transposes : [b s h] --> [s b h].
embeddings = embeddings.transpose(0, 1).contiguous()
# If the input flag for fp32 residual connection is set, convert for float.
if self.fp32_residual_connection:

View File

@@ -8,6 +8,7 @@ from .flux_dit import FluxDiT
from .hunyuan_dit import HunyuanDiT
from .cog_dit import CogDiT
from .hunyuan_video_dit import HunyuanVideoDiT
from .wan_video_dit import WanModel
@@ -197,7 +198,7 @@ class FluxLoRAFromCivitai(LoRAFromCivitai):
class GeneralLoRAFromPeft:
def __init__(self):
self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT]
self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT, WanModel]
def fetch_device_dtype_from_state_dict(self, state_dict):

View File

@@ -69,7 +69,9 @@ def load_model_from_single_file(state_dict, model_names, model_classes, model_re
model_state_dict, extra_kwargs = state_dict_results, {}
torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
with init_weights_on_device():
model= model_class(**extra_kwargs)
model = model_class(**extra_kwargs)
if hasattr(model, "eval"):
model = model.eval()
model.load_state_dict(model_state_dict, assign=True)
model = model.to(dtype=torch_dtype, device=device)
loaded_model_names.append(model_name)

View File

@@ -10,7 +10,7 @@
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# ==============================================================================
from typing import Dict, Optional, Tuple
from typing import Dict, Optional, Tuple, Union, List
import torch, math
from torch import nn
from einops import rearrange, repeat
@@ -398,7 +398,7 @@ class RoPE1D:
* tokens: batch_size x ntokens x nheads x dim
* positions: batch_size x ntokens (t position of each token)
output:
* tokens after appplying RoPE2D (batch_size x ntokens x nheads x dim)
* tokens after applying RoPE2D (batch_size x ntokens x nheads x dim)
"""
D = tokens.size(3)
assert positions.ndim == 2 # Batch, Seq
@@ -428,7 +428,7 @@ class RoPE3D(RoPE1D):
* tokens: batch_size x ntokens x nheads x dim
* rope_positions: list of (f, h, w)
output:
* tokens after appplying RoPE2D (batch_size x ntokens x nheads x dim)
* tokens after applying RoPE2D (batch_size x ntokens x nheads x dim)
"""
assert sum(ch_split) == tokens.size(-1);
@@ -757,7 +757,7 @@ class StepVideoModel(torch.nn.Module):
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
use_additional_conditions: Optional[bool] = False,
caption_channels: Optional[int]|list|tuple = [6144, 1024],
caption_channels: Optional[Union[int, List, Tuple]] = [6144, 1024],
attention_type: Optional[str] = "torch",
):
super().__init__()

View File

@@ -88,7 +88,7 @@ class LLaMaEmbedding(nn.Module):
embeddings = embeddings.to(self.params_dtype)
self.word_embeddings = self.word_embeddings.to(self.params_dtype)
# Data format change to avoid explicit tranposes : [b s h] --> [s b h].
# Data format change to avoid explicit transposes : [b s h] --> [s b h].
embeddings = embeddings.transpose(0, 1).contiguous()
# If the input flag for fp32 residual connection is set, convert for float.
@@ -326,7 +326,7 @@ class MultiQueryAttention(nn.Module):
dim=-1,
)
# gather on 1st dimention
# gather on 1st dimension
xq = xq.view(seqlen, bsz, self.n_local_heads, self.head_dim)
xkv = xkv.view(seqlen, bsz, self.n_local_groups, 2 * self.head_dim)
xk, xv = xkv.chunk(2, -1)
@@ -357,7 +357,7 @@ class MultiQueryAttention(nn.Module):
output = self.core_attention(xq, xk, xv,
cu_seqlens=cu_seqlens,
max_seq_len=max_seq_len)
# reduce-scatter only support first dimention now
# reduce-scatter only support first dimension now
output = rearrange(output, "b s h d -> s b (h d)").contiguous()
else:
xq, xk, xv = [

View File

@@ -55,7 +55,7 @@ class TileWorker:
def io_scale(self, model_output, tile_size):
# Determine the size modification happend in forward_fn
# Determine the size modification happened in forward_fn
# We only consider the same scale on height and width.
io_scale = model_output.shape[2] / tile_size
return io_scale

View File

@@ -0,0 +1,799 @@
import math
import torch
import torch.amp as amp
import torch.nn as nn
from tqdm import tqdm
from .utils import hash_state_dict_keys
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_3_AVAILABLE = False
try:
import flash_attn
FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_2_AVAILABLE = False
try:
from sageattention import sageattn
SAGE_ATTN_AVAILABLE = True
except ModuleNotFoundError:
SAGE_ATTN_AVAILABLE = False
import warnings
__all__ = ['WanModel']
def flash_attention(
q,
k,
v,
q_lens=None,
k_lens=None,
dropout_p=0.,
softmax_scale=None,
q_scale=None,
causal=False,
window_size=(-1, -1),
deterministic=False,
dtype=torch.bfloat16,
version=None,
):
"""
q: [B, Lq, Nq, C1].
k: [B, Lk, Nk, C1].
v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
q_lens: [B].
k_lens: [B].
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
causal: bool. Whether to apply causal attention mask.
window_size: (left right). If not (-1, -1), apply sliding window local attention.
deterministic: bool. If True, slightly slower and uses more memory.
dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
"""
half_dtypes = (torch.float16, torch.bfloat16)
assert dtype in half_dtypes
assert q.device.type == 'cuda' and q.size(-1) <= 256
# params
b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
def half(x):
return x if x.dtype in half_dtypes else x.to(dtype)
# preprocess query
if q_lens is None:
q = half(q.flatten(0, 1))
q_lens = torch.tensor(
[lq] * b, dtype=torch.int32).to(
device=q.device, non_blocking=True)
else:
q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
# preprocess key, value
if k_lens is None:
k = half(k.flatten(0, 1))
v = half(v.flatten(0, 1))
k_lens = torch.tensor(
[lk] * b, dtype=torch.int32).to(
device=k.device, non_blocking=True)
else:
k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
q = q.to(v.dtype)
k = k.to(v.dtype)
if q_scale is not None:
q = q * q_scale
if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
warnings.warn(
'Flash attention 3 is not available, use flash attention 2 instead.'
)
# apply attention
if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
# Note: dropout_p, window_size are not supported in FA3 now.
x = flash_attn_interface.flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
0, dtype=torch.int32).to(q.device, non_blocking=True),
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
0, dtype=torch.int32).to(q.device, non_blocking=True),
seqused_q=None,
seqused_k=None,
max_seqlen_q=lq,
max_seqlen_k=lk,
softmax_scale=softmax_scale,
causal=causal,
deterministic=deterministic)[0].unflatten(0, (b, lq))
elif FLASH_ATTN_2_AVAILABLE:
x = flash_attn.flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
0, dtype=torch.int32).to(q.device, non_blocking=True),
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
0, dtype=torch.int32).to(q.device, non_blocking=True),
max_seqlen_q=lq,
max_seqlen_k=lk,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
deterministic=deterministic).unflatten(0, (b, lq))
elif SAGE_ATTN_AVAILABLE:
q = q.unsqueeze(0).transpose(1, 2).to(dtype)
k = k.unsqueeze(0).transpose(1, 2).to(dtype)
v = v.unsqueeze(0).transpose(1, 2).to(dtype)
x = sageattn(q, k, v, dropout_p=dropout_p, is_causal=causal)
x = x.transpose(1, 2).contiguous()
else:
q = q.unsqueeze(0).transpose(1, 2).to(dtype)
k = k.unsqueeze(0).transpose(1, 2).to(dtype)
v = v.unsqueeze(0).transpose(1, 2).to(dtype)
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x = x.transpose(1, 2).contiguous()
# output
return x.type(out_dtype)
def create_sdpa_mask(q, k, q_lens, k_lens, causal=False):
b, lq, lk = q.size(0), q.size(1), k.size(1)
if q_lens is None:
q_lens = torch.tensor([lq] * b, dtype=torch.int32)
if k_lens is None:
k_lens = torch.tensor([lk] * b, dtype=torch.int32)
attn_mask = torch.zeros((b, lq, lk), dtype=torch.bool)
for i in range(b):
q_len, k_len = q_lens[i], k_lens[i]
attn_mask[i, q_len:, :] = True
attn_mask[i, :, k_len:] = True
if causal:
causal_mask = torch.triu(torch.ones((lq, lk), dtype=torch.bool), diagonal=1)
attn_mask[i, :, :] = torch.logical_or(attn_mask[i, :, :], causal_mask)
attn_mask = attn_mask.logical_not().to(q.device, non_blocking=True)
return attn_mask
def attention(
q,
k,
v,
q_lens=None,
k_lens=None,
dropout_p=0.,
softmax_scale=None,
q_scale=None,
causal=False,
window_size=(-1, -1),
deterministic=False,
dtype=torch.bfloat16,
fa_version=None,
):
if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
return flash_attention(
q=q,
k=k,
v=v,
q_lens=q_lens,
k_lens=k_lens,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
q_scale=q_scale,
causal=causal,
window_size=window_size,
deterministic=deterministic,
dtype=dtype,
version=fa_version,
)
else:
if q_lens is not None or k_lens is not None:
warnings.warn('Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.')
attn_mask = None
q = q.transpose(1, 2).to(dtype)
k = k.transpose(1, 2).to(dtype)
v = v.transpose(1, 2).to(dtype)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
out = out.transpose(1, 2).contiguous()
return out
def sinusoidal_embedding_1d(dim, position):
# preprocess
assert dim % 2 == 0
half = dim // 2
position = position.type(torch.float64)
# calculation
sinusoid = torch.outer(
position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
return x
@amp.autocast(enabled=False, device_type="cuda")
def rope_params(max_seq_len, dim, theta=10000):
assert dim % 2 == 0
freqs = torch.outer(
torch.arange(max_seq_len),
1.0 / torch.pow(theta,
torch.arange(0, dim, 2).to(torch.float64).div(dim)))
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
@amp.autocast(enabled=False, device_type="cuda")
def rope_apply(x, grid_sizes, freqs):
n, c = x.size(2), x.size(3) // 2
# split freqs
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
# loop over samples
output = []
for i, (f, h, w) in enumerate(grid_sizes.tolist()):
seq_len = f * h * w
# precompute multipliers
x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
seq_len, n, -1, 2))
freqs_i = torch.cat([
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
],
dim=-1).reshape(seq_len, 1, -1)
# apply rotary embedding
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
x_i = torch.cat([x_i, x[i, seq_len:]])
# append to collection
output.append(x_i)
return torch.stack(output).float()
class WanRMSNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x):
return self._norm(x.float()).type_as(x) * self.weight
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
class WanLayerNorm(nn.LayerNorm):
def __init__(self, dim, eps=1e-6, elementwise_affine=False):
super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
def forward(self, x):
return super().forward(x.float()).type_as(x)
class WanSelfAttention(nn.Module):
def __init__(self,
dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
eps=1e-6):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.eps = eps
# layers
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
def forward(self, x, seq_lens, grid_sizes, freqs):
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
# query, key, value function
def qkv_fn(x):
q = self.norm_q(self.q(x)).view(b, s, n, d)
k = self.norm_k(self.k(x)).view(b, s, n, d)
v = self.v(x).view(b, s, n, d)
return q, k, v
q, k, v = qkv_fn(x)
x = flash_attention(
q=rope_apply(q, grid_sizes, freqs),
k=rope_apply(k, grid_sizes, freqs),
v=v,
k_lens=seq_lens,
window_size=self.window_size)
# output
x = x.flatten(2)
x = self.o(x)
return x
class WanT2VCrossAttention(WanSelfAttention):
def forward(self, x, context, context_lens):
"""
x: [B, L1, C].
context: [B, L2, C].
context_lens: [B].
"""
b, n, d = x.size(0), self.num_heads, self.head_dim
# compute query, key, value
q = self.norm_q(self.q(x)).view(b, -1, n, d)
k = self.norm_k(self.k(context)).view(b, -1, n, d)
v = self.v(context).view(b, -1, n, d)
# compute attention
x = flash_attention(q, k, v, k_lens=context_lens)
# output
x = x.flatten(2)
x = self.o(x)
return x
class WanI2VCrossAttention(WanSelfAttention):
def __init__(self,
dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
eps=1e-6):
super().__init__(dim, num_heads, window_size, qk_norm, eps)
self.k_img = nn.Linear(dim, dim)
self.v_img = nn.Linear(dim, dim)
# self.alpha = nn.Parameter(torch.zeros((1, )))
self.norm_k_img = WanRMSNorm(
dim, eps=eps) if qk_norm else nn.Identity()
def forward(self, x, context, context_lens):
"""
x: [B, L1, C].
context: [B, L2, C].
context_lens: [B].
"""
context_img = context[:, :257]
context = context[:, 257:]
b, n, d = x.size(0), self.num_heads, self.head_dim
# compute query, key, value
q = self.norm_q(self.q(x)).view(b, -1, n, d)
k = self.norm_k(self.k(context)).view(b, -1, n, d)
v = self.v(context).view(b, -1, n, d)
k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
v_img = self.v_img(context_img).view(b, -1, n, d)
img_x = flash_attention(q, k_img, v_img, k_lens=None)
# compute attention
x = flash_attention(q, k, v, k_lens=context_lens)
# output
x = x.flatten(2)
img_x = img_x.flatten(2)
x = x + img_x
x = self.o(x)
return x
WANX_CROSSATTENTION_CLASSES = {
't2v_cross_attn': WanT2VCrossAttention,
'i2v_cross_attn': WanI2VCrossAttention,
}
class WanAttentionBlock(nn.Module):
def __init__(self,
cross_attn_type,
dim,
ffn_dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=False,
eps=1e-6):
super().__init__()
self.dim = dim
self.ffn_dim = ffn_dim
self.num_heads = num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
# layers
self.norm1 = WanLayerNorm(dim, eps)
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
eps)
self.norm3 = WanLayerNorm(
dim, eps,
elementwise_affine=True) if cross_attn_norm else nn.Identity()
self.cross_attn = WANX_CROSSATTENTION_CLASSES[cross_attn_type](
dim, num_heads, (-1, -1), qk_norm, eps)
self.norm2 = WanLayerNorm(dim, eps)
self.ffn = nn.Sequential(
nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
nn.Linear(ffn_dim, dim))
# modulation
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
def forward(
self,
x,
e,
seq_lens,
grid_sizes,
freqs,
context,
context_lens,
):
assert e.dtype == torch.float32
with amp.autocast(dtype=torch.float32, device_type="cuda"):
e = (self.modulation.to(dtype=e.dtype, device=e.device) + e).chunk(6, dim=1)
assert e[0].dtype == torch.float32
# self-attention
y = self.self_attn(
self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes,
freqs)
with amp.autocast(dtype=torch.float32, device_type="cuda"):
x = x + y * e[2]
# cross-attention & ffn function
def cross_attn_ffn(x, context, context_lens, e):
x = x + self.cross_attn(self.norm3(x), context, context_lens)
y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
with amp.autocast(dtype=torch.float32, device_type="cuda"):
x = x + y * e[5]
return x
x = cross_attn_ffn(x, context, context_lens, e)
return x
class Head(nn.Module):
def __init__(self, dim, out_dim, patch_size, eps=1e-6):
super().__init__()
self.dim = dim
self.out_dim = out_dim
self.patch_size = patch_size
self.eps = eps
# layers
out_dim = math.prod(patch_size) * out_dim
self.norm = WanLayerNorm(dim, eps)
self.head = nn.Linear(dim, out_dim)
# modulation
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
def forward(self, x, e):
assert e.dtype == torch.float32
with amp.autocast(dtype=torch.float32, device_type="cuda"):
e = (self.modulation.to(dtype=e.dtype, device=e.device) + e.unsqueeze(1)).chunk(2, dim=1)
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
return x
class MLPProj(torch.nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
self.proj = torch.nn.Sequential(
torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
torch.nn.LayerNorm(out_dim))
def forward(self, image_embeds):
clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens
class WanModel(nn.Module):
def __init__(self,
model_type='t2v',
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=False,
eps=1e-6):
super().__init__()
assert model_type in ['t2v', 'i2v']
self.model_type = model_type
self.patch_size = patch_size
self.text_len = text_len
self.in_dim = in_dim
self.dim = dim
self.ffn_dim = ffn_dim
self.freq_dim = freq_dim
self.text_dim = text_dim
self.out_dim = out_dim
self.num_heads = num_heads
self.num_layers = num_layers
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
# embeddings
self.patch_embedding = nn.Conv3d(
in_dim, dim, kernel_size=patch_size, stride=patch_size)
self.text_embedding = nn.Sequential(
nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
nn.Linear(dim, dim))
self.time_embedding = nn.Sequential(
nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
# blocks
cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
self.blocks = nn.ModuleList([
WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
window_size, qk_norm, cross_attn_norm, eps)
for _ in range(num_layers)
])
# head
self.head = Head(dim, out_dim, patch_size, eps)
# buffers (don't use register_buffer otherwise dtype will be changed in to())
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
d = dim // num_heads
self.freqs = torch.cat([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6))
],
dim=1)
if model_type == 'i2v':
self.img_emb = MLPProj(1280, dim)
# initialize weights
self.init_weights()
def forward(
self,
x,
timestep,
context,
seq_len,
clip_fea=None,
y=None,
use_gradient_checkpointing=False,
**kwargs,
):
"""
x: A list of videos each with shape [C, T, H, W].
t: [B].
context: A list of text embeddings each with shape [L, C].
"""
if self.model_type == 'i2v':
assert clip_fea is not None and y is not None
# params
device = x[0].device
if self.freqs.device != device:
self.freqs = self.freqs.to(device)
if y is not None:
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
# embeddings
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
grid_sizes = torch.stack(
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
x = [u.flatten(2).transpose(1, 2) for u in x]
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
assert seq_lens.max() <= seq_len
x = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
dim=1) for u in x
])
# time embeddings
with amp.autocast(dtype=torch.float32, device_type="cuda"):
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, timestep).float())
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
assert e.dtype == torch.float32 and e0.dtype == torch.float32
# context
context_lens = None
context = self.text_embedding(
torch.stack([
torch.cat(
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
for u in context
]))
if clip_fea is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
# arguments
kwargs = dict(
e=e0,
seq_lens=seq_lens,
grid_sizes=grid_sizes,
freqs=self.freqs,
context=context,
context_lens=context_lens)
def create_custom_forward(module):
def custom_forward(*inputs, **kwargs):
return module(*inputs, **kwargs)
return custom_forward
for block in self.blocks:
if self.training and use_gradient_checkpointing:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, **kwargs,
use_reentrant=False,
)
else:
x = block(x, **kwargs)
# head
x = self.head(x, e)
# unpatchify
x = self.unpatchify(x, grid_sizes)
x = torch.stack(x).float()
return x
def unpatchify(self, x, grid_sizes):
c = self.out_dim
out = []
for u, v in zip(x, grid_sizes.tolist()):
u = u[:math.prod(v)].view(*v, *self.patch_size, c)
u = torch.einsum('fhwpqrc->cfphqwr', u)
u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
out.append(u)
return out
def init_weights(self):
# basic init
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
# init embeddings
nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
for m in self.text_embedding.modules():
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=.02)
for m in self.time_embedding.modules():
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=.02)
# init output layer
nn.init.zeros_(self.head.head.weight)
@staticmethod
def state_dict_converter():
return WanModelStateDictConverter()
class WanModelStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict
def from_civitai(self, state_dict):
if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
config = {
"model_type": "t2v",
"patch_size": (1, 2, 2),
"text_len": 512,
"in_dim": 16,
"dim": 1536,
"ffn_dim": 8960,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 12,
"num_layers": 30,
"window_size": (-1, -1),
"qk_norm": True,
"cross_attn_norm": True,
"eps": 1e-6,
}
elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
config = {
"model_type": "t2v",
"patch_size": (1, 2, 2),
"text_len": 512,
"in_dim": 16,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"window_size": (-1, -1),
"qk_norm": True,
"cross_attn_norm": True,
"eps": 1e-6,
}
elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
config = {
"model_type": "i2v",
"patch_size": (1, 2, 2),
"text_len": 512,
"in_dim": 36,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"window_size": (-1, -1),
"qk_norm": True,
"cross_attn_norm": True,
"eps": 1e-6,
}
else:
config = {}
return state_dict, config

View File

@@ -0,0 +1,904 @@
"""
Concise re-implementation of
``https://github.com/openai/CLIP'' and
``https://github.com/mlfoundations/open_clip''.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from .wan_video_dit import flash_attention
class SelfAttention(nn.Module):
def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.eps = eps
# layers
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask):
"""
x: [B, L, C].
"""
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
# compute attention
p = self.dropout.p if self.training else 0.0
x = F.scaled_dot_product_attention(q, k, v, mask, p)
x = x.permute(0, 2, 1, 3).reshape(b, s, c)
# output
x = self.o(x)
x = self.dropout(x)
return x
class AttentionBlock(nn.Module):
def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.post_norm = post_norm
self.eps = eps
# layers
self.attn = SelfAttention(dim, num_heads, dropout, eps)
self.norm1 = nn.LayerNorm(dim, eps=eps)
self.ffn = nn.Sequential(
nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
nn.Dropout(dropout))
self.norm2 = nn.LayerNorm(dim, eps=eps)
def forward(self, x, mask):
if self.post_norm:
x = self.norm1(x + self.attn(x, mask))
x = self.norm2(x + self.ffn(x))
else:
x = x + self.attn(self.norm1(x), mask)
x = x + self.ffn(self.norm2(x))
return x
class XLMRoberta(nn.Module):
"""
XLMRobertaModel with no pooler and no LM head.
"""
def __init__(self,
vocab_size=250002,
max_seq_len=514,
type_size=1,
pad_id=1,
dim=1024,
num_heads=16,
num_layers=24,
post_norm=True,
dropout=0.1,
eps=1e-5):
super().__init__()
self.vocab_size = vocab_size
self.max_seq_len = max_seq_len
self.type_size = type_size
self.pad_id = pad_id
self.dim = dim
self.num_heads = num_heads
self.num_layers = num_layers
self.post_norm = post_norm
self.eps = eps
# embeddings
self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
self.type_embedding = nn.Embedding(type_size, dim)
self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
self.dropout = nn.Dropout(dropout)
# blocks
self.blocks = nn.ModuleList([
AttentionBlock(dim, num_heads, post_norm, dropout, eps)
for _ in range(num_layers)
])
# norm layer
self.norm = nn.LayerNorm(dim, eps=eps)
def forward(self, ids):
"""
ids: [B, L] of torch.LongTensor.
"""
b, s = ids.shape
mask = ids.ne(self.pad_id).long()
# embeddings
x = self.token_embedding(ids) + \
self.type_embedding(torch.zeros_like(ids)) + \
self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
if self.post_norm:
x = self.norm(x)
x = self.dropout(x)
# blocks
mask = torch.where(
mask.view(b, 1, 1, s).gt(0), 0.0,
torch.finfo(x.dtype).min)
for block in self.blocks:
x = block(x, mask)
# output
if not self.post_norm:
x = self.norm(x)
return x
def xlm_roberta_large(pretrained=False,
return_tokenizer=False,
device='cpu',
**kwargs):
"""
XLMRobertaLarge adapted from Huggingface.
"""
# params
cfg = dict(
vocab_size=250002,
max_seq_len=514,
type_size=1,
pad_id=1,
dim=1024,
num_heads=16,
num_layers=24,
post_norm=True,
dropout=0.1,
eps=1e-5)
cfg.update(**kwargs)
# init model
if pretrained:
from sora import DOWNLOAD_TO_CACHE
# init a meta model
with torch.device('meta'):
model = XLMRoberta(**cfg)
# load checkpoint
model.load_state_dict(
torch.load(
DOWNLOAD_TO_CACHE('models/xlm_roberta/xlm_roberta_large.pth'),
map_location=device),
assign=True)
else:
# init a model on device
with torch.device(device):
model = XLMRoberta(**cfg)
# init tokenizer
if return_tokenizer:
from sora.data import HuggingfaceTokenizer
tokenizer = HuggingfaceTokenizer(
name='xlm-roberta-large',
seq_len=model.text_len,
clean='whitespace')
return model, tokenizer
else:
return model
def pos_interpolate(pos, seq_len):
if pos.size(1) == seq_len:
return pos
else:
src_grid = int(math.sqrt(pos.size(1)))
tar_grid = int(math.sqrt(seq_len))
n = pos.size(1) - src_grid * src_grid
return torch.cat([
pos[:, :n],
F.interpolate(
pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
0, 3, 1, 2),
size=(tar_grid, tar_grid),
mode='bicubic',
align_corners=False).flatten(2).transpose(1, 2)
],
dim=1)
class QuickGELU(nn.Module):
def forward(self, x):
return x * torch.sigmoid(1.702 * x)
class LayerNorm(nn.LayerNorm):
def forward(self, x):
return super().forward(x.float()).type_as(x)
class SelfAttention(nn.Module):
def __init__(self,
dim,
num_heads,
causal=False,
attn_dropout=0.0,
proj_dropout=0.0):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.causal = causal
self.attn_dropout = attn_dropout
self.proj_dropout = proj_dropout
# layers
self.to_qkv = nn.Linear(dim, dim * 3)
self.proj = nn.Linear(dim, dim)
def forward(self, x):
"""
x: [B, L, C].
"""
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
# compute attention
p = self.attn_dropout if self.training else 0.0
x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
x = x.reshape(b, s, c)
# output
x = self.proj(x)
x = F.dropout(x, self.proj_dropout, self.training)
return x
class SwiGLU(nn.Module):
def __init__(self, dim, mid_dim):
super().__init__()
self.dim = dim
self.mid_dim = mid_dim
# layers
self.fc1 = nn.Linear(dim, mid_dim)
self.fc2 = nn.Linear(dim, mid_dim)
self.fc3 = nn.Linear(mid_dim, dim)
def forward(self, x):
x = F.silu(self.fc1(x)) * self.fc2(x)
x = self.fc3(x)
return x
class AttentionBlock(nn.Module):
def __init__(self,
dim,
mlp_ratio,
num_heads,
post_norm=False,
causal=False,
activation='quick_gelu',
attn_dropout=0.0,
proj_dropout=0.0,
norm_eps=1e-5):
assert activation in ['quick_gelu', 'gelu', 'swi_glu']
super().__init__()
self.dim = dim
self.mlp_ratio = mlp_ratio
self.num_heads = num_heads
self.post_norm = post_norm
self.causal = causal
self.norm_eps = norm_eps
# layers
self.norm1 = LayerNorm(dim, eps=norm_eps)
self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
proj_dropout)
self.norm2 = LayerNorm(dim, eps=norm_eps)
if activation == 'swi_glu':
self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
else:
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
def forward(self, x):
if self.post_norm:
x = x + self.norm1(self.attn(x))
x = x + self.norm2(self.mlp(x))
else:
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
class AttentionPool(nn.Module):
def __init__(self,
dim,
mlp_ratio,
num_heads,
activation='gelu',
proj_dropout=0.0,
norm_eps=1e-5):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.mlp_ratio = mlp_ratio
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.proj_dropout = proj_dropout
self.norm_eps = norm_eps
# layers
gain = 1.0 / math.sqrt(dim)
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
self.to_q = nn.Linear(dim, dim)
self.to_kv = nn.Linear(dim, dim * 2)
self.proj = nn.Linear(dim, dim)
self.norm = LayerNorm(dim, eps=norm_eps)
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
def forward(self, x):
"""
x: [B, L, C].
"""
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
# compute attention
x = flash_attention(q, k, v, version=2)
x = x.reshape(b, 1, c)
# output
x = self.proj(x)
x = F.dropout(x, self.proj_dropout, self.training)
# mlp
x = x + self.mlp(self.norm(x))
return x[:, 0]
class VisionTransformer(nn.Module):
def __init__(self,
image_size=224,
patch_size=16,
dim=768,
mlp_ratio=4,
out_dim=512,
num_heads=12,
num_layers=12,
pool_type='token',
pre_norm=True,
post_norm=False,
activation='quick_gelu',
attn_dropout=0.0,
proj_dropout=0.0,
embedding_dropout=0.0,
norm_eps=1e-5):
if image_size % patch_size != 0:
print(
'[WARNING] image_size is not divisible by patch_size',
flush=True)
assert pool_type in ('token', 'token_fc', 'attn_pool')
out_dim = out_dim or dim
super().__init__()
self.image_size = image_size
self.patch_size = patch_size
self.num_patches = (image_size // patch_size)**2
self.dim = dim
self.mlp_ratio = mlp_ratio
self.out_dim = out_dim
self.num_heads = num_heads
self.num_layers = num_layers
self.pool_type = pool_type
self.post_norm = post_norm
self.norm_eps = norm_eps
# embeddings
gain = 1.0 / math.sqrt(dim)
self.patch_embedding = nn.Conv2d(
3,
dim,
kernel_size=patch_size,
stride=patch_size,
bias=not pre_norm)
if pool_type in ('token', 'token_fc'):
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
self.pos_embedding = nn.Parameter(gain * torch.randn(
1, self.num_patches +
(1 if pool_type in ('token', 'token_fc') else 0), dim))
self.dropout = nn.Dropout(embedding_dropout)
# transformer
self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
self.transformer = nn.Sequential(*[
AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
activation, attn_dropout, proj_dropout, norm_eps)
for _ in range(num_layers)
])
self.post_norm = LayerNorm(dim, eps=norm_eps)
# head
if pool_type == 'token':
self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
elif pool_type == 'token_fc':
self.head = nn.Linear(dim, out_dim)
elif pool_type == 'attn_pool':
self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
proj_dropout, norm_eps)
def forward(self, x, interpolation=False, use_31_block=False):
b = x.size(0)
# embeddings
x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
if self.pool_type in ('token', 'token_fc'):
x = torch.cat([self.cls_embedding.expand(b, -1, -1).to(dtype=x.dtype, device=x.device), x], dim=1)
if interpolation:
e = pos_interpolate(self.pos_embedding, x.size(1))
else:
e = self.pos_embedding
e = e.to(dtype=x.dtype, device=x.device)
x = self.dropout(x + e)
if self.pre_norm is not None:
x = self.pre_norm(x)
# transformer
if use_31_block:
x = self.transformer[:-1](x)
return x
else:
x = self.transformer(x)
return x
class CLIP(nn.Module):
def __init__(self,
embed_dim=512,
image_size=224,
patch_size=16,
vision_dim=768,
vision_mlp_ratio=4,
vision_heads=12,
vision_layers=12,
vision_pool='token',
vision_pre_norm=True,
vision_post_norm=False,
vocab_size=49408,
text_len=77,
text_dim=512,
text_mlp_ratio=4,
text_heads=8,
text_layers=12,
text_causal=True,
text_pool='argmax',
text_head_bias=False,
logit_bias=None,
activation='quick_gelu',
attn_dropout=0.0,
proj_dropout=0.0,
embedding_dropout=0.0,
norm_eps=1e-5):
super().__init__()
self.embed_dim = embed_dim
self.image_size = image_size
self.patch_size = patch_size
self.vision_dim = vision_dim
self.vision_mlp_ratio = vision_mlp_ratio
self.vision_heads = vision_heads
self.vision_layers = vision_layers
self.vision_pool = vision_pool
self.vision_pre_norm = vision_pre_norm
self.vision_post_norm = vision_post_norm
self.vocab_size = vocab_size
self.text_len = text_len
self.text_dim = text_dim
self.text_mlp_ratio = text_mlp_ratio
self.text_heads = text_heads
self.text_layers = text_layers
self.text_causal = text_causal
self.text_pool = text_pool
self.text_head_bias = text_head_bias
self.norm_eps = norm_eps
# models
self.visual = VisionTransformer(
image_size=image_size,
patch_size=patch_size,
dim=vision_dim,
mlp_ratio=vision_mlp_ratio,
out_dim=embed_dim,
num_heads=vision_heads,
num_layers=vision_layers,
pool_type=vision_pool,
pre_norm=vision_pre_norm,
post_norm=vision_post_norm,
activation=activation,
attn_dropout=attn_dropout,
proj_dropout=proj_dropout,
embedding_dropout=embedding_dropout,
norm_eps=norm_eps)
self.textual = TextTransformer(
vocab_size=vocab_size,
text_len=text_len,
dim=text_dim,
mlp_ratio=text_mlp_ratio,
out_dim=embed_dim,
num_heads=text_heads,
num_layers=text_layers,
causal=text_causal,
pool_type=text_pool,
head_bias=text_head_bias,
activation=activation,
attn_dropout=attn_dropout,
proj_dropout=proj_dropout,
embedding_dropout=embedding_dropout,
norm_eps=norm_eps)
self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
if logit_bias is not None:
self.logit_bias = nn.Parameter(logit_bias * torch.ones([]))
# initialize weights
self.init_weights()
def forward(self, imgs, txt_ids):
"""
imgs: [B, 3, H, W] of torch.float32.
- mean: [0.48145466, 0.4578275, 0.40821073]
- std: [0.26862954, 0.26130258, 0.27577711]
txt_ids: [B, L] of torch.long. Encoded by data.CLIPTokenizer.
"""
xi = self.visual(imgs)
xt = self.textual(txt_ids)
return xi, xt
def init_weights(self):
# embeddings
nn.init.normal_(self.textual.token_embedding.weight, std=0.02)
nn.init.normal_(self.visual.patch_embedding.weight, std=0.1)
# attentions
for modality in ['visual', 'textual']:
dim = self.vision_dim if modality == 'visual' else self.text_dim
transformer = getattr(self, modality).transformer
proj_gain = (1.0 / math.sqrt(dim)) * (
1.0 / math.sqrt(2 * len(transformer)))
attn_gain = 1.0 / math.sqrt(dim)
mlp_gain = 1.0 / math.sqrt(2.0 * dim)
for block in transformer:
nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain)
nn.init.normal_(block.attn.proj.weight, std=proj_gain)
nn.init.normal_(block.mlp[0].weight, std=mlp_gain)
nn.init.normal_(block.mlp[2].weight, std=proj_gain)
def param_groups(self):
groups = [{
'params': [
p for n, p in self.named_parameters()
if 'norm' in n or n.endswith('bias')
],
'weight_decay': 0.0
}, {
'params': [
p for n, p in self.named_parameters()
if not ('norm' in n or n.endswith('bias'))
]
}]
return groups
class XLMRobertaWithHead(XLMRoberta):
def __init__(self, **kwargs):
self.out_dim = kwargs.pop('out_dim')
super().__init__(**kwargs)
# head
mid_dim = (self.dim + self.out_dim) // 2
self.head = nn.Sequential(
nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
nn.Linear(mid_dim, self.out_dim, bias=False))
def forward(self, ids):
# xlm-roberta
x = super().forward(ids)
# average pooling
mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
x = (x * mask).sum(dim=1) / mask.sum(dim=1)
# head
x = self.head(x)
return x
class XLMRobertaCLIP(nn.Module):
def __init__(self,
embed_dim=1024,
image_size=224,
patch_size=14,
vision_dim=1280,
vision_mlp_ratio=4,
vision_heads=16,
vision_layers=32,
vision_pool='token',
vision_pre_norm=True,
vision_post_norm=False,
activation='gelu',
vocab_size=250002,
max_text_len=514,
type_size=1,
pad_id=1,
text_dim=1024,
text_heads=16,
text_layers=24,
text_post_norm=True,
text_dropout=0.1,
attn_dropout=0.0,
proj_dropout=0.0,
embedding_dropout=0.0,
norm_eps=1e-5):
super().__init__()
self.embed_dim = embed_dim
self.image_size = image_size
self.patch_size = patch_size
self.vision_dim = vision_dim
self.vision_mlp_ratio = vision_mlp_ratio
self.vision_heads = vision_heads
self.vision_layers = vision_layers
self.vision_pre_norm = vision_pre_norm
self.vision_post_norm = vision_post_norm
self.activation = activation
self.vocab_size = vocab_size
self.max_text_len = max_text_len
self.type_size = type_size
self.pad_id = pad_id
self.text_dim = text_dim
self.text_heads = text_heads
self.text_layers = text_layers
self.text_post_norm = text_post_norm
self.norm_eps = norm_eps
# models
self.visual = VisionTransformer(
image_size=image_size,
patch_size=patch_size,
dim=vision_dim,
mlp_ratio=vision_mlp_ratio,
out_dim=embed_dim,
num_heads=vision_heads,
num_layers=vision_layers,
pool_type=vision_pool,
pre_norm=vision_pre_norm,
post_norm=vision_post_norm,
activation=activation,
attn_dropout=attn_dropout,
proj_dropout=proj_dropout,
embedding_dropout=embedding_dropout,
norm_eps=norm_eps)
self.textual = None
self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
def forward(self, imgs, txt_ids):
"""
imgs: [B, 3, H, W] of torch.float32.
- mean: [0.48145466, 0.4578275, 0.40821073]
- std: [0.26862954, 0.26130258, 0.27577711]
txt_ids: [B, L] of torch.long.
Encoded by data.CLIPTokenizer.
"""
xi = self.visual(imgs)
xt = self.textual(txt_ids)
return xi, xt
def param_groups(self):
groups = [{
'params': [
p for n, p in self.named_parameters()
if 'norm' in n or n.endswith('bias')
],
'weight_decay': 0.0
}, {
'params': [
p for n, p in self.named_parameters()
if not ('norm' in n or n.endswith('bias'))
]
}]
return groups
def _clip(pretrained=False,
pretrained_name=None,
model_cls=CLIP,
return_transforms=False,
return_tokenizer=False,
tokenizer_padding='eos',
dtype=torch.float32,
device='cpu',
**kwargs):
# init model
if pretrained and pretrained_name:
from sora import BUCKET, DOWNLOAD_TO_CACHE
# init a meta model
with torch.device('meta'):
model = model_cls(**kwargs)
# checkpoint path
checkpoint = f'models/clip/{pretrained_name}'
if dtype in (torch.float16, torch.bfloat16):
suffix = '-' + {
torch.float16: 'fp16',
torch.bfloat16: 'bf16'
}[dtype]
if object_exists(BUCKET, f'{checkpoint}{suffix}.pth'):
checkpoint = f'{checkpoint}{suffix}'
checkpoint += '.pth'
# load
model.load_state_dict(
torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device),
assign=True,
strict=False)
else:
# init a model on device
with torch.device(device):
model = model_cls(**kwargs)
# set device
output = (model,)
# init transforms
if return_transforms:
# mean and std
if 'siglip' in pretrained_name.lower():
mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
else:
mean = [0.48145466, 0.4578275, 0.40821073]
std = [0.26862954, 0.26130258, 0.27577711]
# transforms
transforms = T.Compose([
T.Resize((model.image_size, model.image_size),
interpolation=T.InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=mean, std=std)
])
output += (transforms,)
# init tokenizer
if return_tokenizer:
from sora import data
if 'siglip' in pretrained_name.lower():
tokenizer = data.HuggingfaceTokenizer(
name=f'timm/{pretrained_name}',
seq_len=model.text_len,
clean='canonicalize')
elif 'xlm' in pretrained_name.lower():
tokenizer = data.HuggingfaceTokenizer(
name='xlm-roberta-large',
seq_len=model.max_text_len - 2,
clean='whitespace')
elif 'mba' in pretrained_name.lower():
tokenizer = data.HuggingfaceTokenizer(
name='facebook/xlm-roberta-xl',
seq_len=model.max_text_len - 2,
clean='whitespace')
else:
tokenizer = data.CLIPTokenizer(
seq_len=model.text_len, padding=tokenizer_padding)
output += (tokenizer,)
return output[0] if len(output) == 1 else output
def clip_xlm_roberta_vit_h_14(
pretrained=False,
pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
**kwargs):
cfg = dict(
embed_dim=1024,
image_size=224,
patch_size=14,
vision_dim=1280,
vision_mlp_ratio=4,
vision_heads=16,
vision_layers=32,
vision_pool='token',
activation='gelu',
vocab_size=250002,
max_text_len=514,
type_size=1,
pad_id=1,
text_dim=1024,
text_heads=16,
text_layers=24,
text_post_norm=True,
text_dropout=0.1,
attn_dropout=0.0,
proj_dropout=0.0,
embedding_dropout=0.0)
cfg.update(**kwargs)
return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
class WanImageEncoder(torch.nn.Module):
def __init__(self):
super().__init__()
# init model
self.model, self.transforms = clip_xlm_roberta_vit_h_14(
pretrained=False,
return_transforms=True,
return_tokenizer=False,
dtype=torch.float32,
device="cpu")
def encode_image(self, videos):
# preprocess
size = (self.model.image_size,) * 2
videos = torch.cat([
F.interpolate(
u,
size=size,
mode='bicubic',
align_corners=False) for u in videos
])
videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
# forward
out = self.model.visual(videos, use_31_block=True)
return out
@staticmethod
def state_dict_converter():
return WanImageEncoderStateDictConverter()
class WanImageEncoderStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict
def from_civitai(self, state_dict):
state_dict_ = {}
for name, param in state_dict.items():
if name.startswith("textual."):
continue
name = "model." + name
state_dict_[name] = param
return state_dict_

View File

@@ -206,7 +206,7 @@ def init_weights(m):
m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
class WanXTextEncoder(torch.nn.Module):
class WanTextEncoder(torch.nn.Module):
def __init__(self,
vocab=256384,
@@ -218,7 +218,7 @@ class WanXTextEncoder(torch.nn.Module):
num_buckets=32,
shared_pos=False,
dropout=0.1):
super(WanXTextEncoder, self).__init__()
super(WanTextEncoder, self).__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
@@ -252,3 +252,18 @@ class WanXTextEncoder(torch.nn.Module):
x = self.norm(x)
x = self.dropout(x)
return x
@staticmethod
def state_dict_converter():
return WanTextEncoderStateDictConverter()
class WanTextEncoderStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict
def from_civitai(self, state_dict):
return state_dict

View File

@@ -7,6 +7,15 @@ from tqdm import tqdm
CACHE_T = 2
def check_is_instance(model, module_class):
if isinstance(model, module_class):
return True
if hasattr(model, "module") and isinstance(model.module, module_class):
return True
return False
def block_causal_mask(x, block_size):
# params
b, n, s, _, device = *x.size(), x.device
@@ -205,7 +214,7 @@ class ResidualBlock(nn.Module):
def forward(self, x, feat_cache=None, feat_idx=[0]):
h = self.shortcut(x)
for layer in self.residual:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
@@ -342,14 +351,14 @@ class Encoder3d(nn.Module):
## middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
@@ -440,7 +449,7 @@ class Decoder3d(nn.Module):
## middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
@@ -454,7 +463,7 @@ class Decoder3d(nn.Module):
## head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
@@ -475,7 +484,7 @@ class Decoder3d(nn.Module):
def count_conv3d(model):
count = 0
for m in model.modules():
if isinstance(m, CausalConv3d):
if check_is_instance(m, CausalConv3d):
count += 1
return count
@@ -587,7 +596,7 @@ class VideoVAE_(nn.Module):
self._enc_feat_map = [None] * self._enc_conv_num
class WanXVideoVAE(nn.Module):
class WanVideoVAE(nn.Module):
def __init__(self, z_dim=16):
super().__init__()
@@ -747,19 +756,21 @@ class WanXVideoVAE(nn.Module):
return video.float().clamp_(-1, 1)
def encode(self, videos, device, tiled=False, tile_size=(272, 272), tile_stride=(144, 128)):
def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
videos = [video.to("cpu") for video in videos]
hidden_states = []
for video in videos:
video = video.unsqueeze(0)
if tiled:
assert tile_size[0] % self.upsampling_factor == 0 and tile_size[1] % self.upsampling_factor == 0, f"tile_size must be devisible by {self.upsampling_factor}"
tile_size = (tile_size[0] * 8, tile_size[1] * 8)
tile_stride = (tile_stride[0] * 8, tile_stride[1] * 8)
hidden_state = self.tiled_encode(video, device, tile_size, tile_stride)
else:
hidden_state = self.single_encode(video, device)
hidden_state = hidden_state.squeeze(0)
hidden_states.append(hidden_state)
hidden_states = torch.stack(hidden_states)
return hidden_states
@@ -774,21 +785,24 @@ class WanXVideoVAE(nn.Module):
video = self.single_decode(hidden_state, device)
video = video.squeeze(0)
videos.append(video)
videos = torch.stack(videos)
return videos
@staticmethod
def state_dict_converter():
return WanXVideoVAEStateDictConverter()
return WanVideoVAEStateDictConverter()
class WanXVideoVAEStateDictConverter:
class WanVideoVAEStateDictConverter:
def __init__(self):
pass
def from_civitai(self, state_dict):
state_dict_ = {}
for name in state_dict['model_state']:
state_dict_['model.' + name] = state_dict['model_state'][name]
if 'model_state' in state_dict:
state_dict = state_dict['model_state']
for name in state_dict:
state_dict_['model.' + name] = state_dict[name]
return state_dict_

View File

@@ -11,4 +11,5 @@ from .omnigen_image import OmnigenImagePipeline
from .pipeline_runner import SDVideoPipelineRunner
from .hunyuan_video import HunyuanVideoPipeline
from .step_video import StepVideoPipeline
from .wan_video import WanVideoPipeline
KolorsImagePipeline = SDXLImagePipeline

View File

@@ -16,7 +16,7 @@ class OmniGenCache(DynamicCache):
def __init__(self,
num_tokens_for_img: int, offload_kv_cache: bool=False) -> None:
if not torch.cuda.is_available():
print("No avaliable GPU, offload_kv_cache wiil be set to False, which will result in large memory usage and time cost when input multiple images!!!")
print("No available GPU, offload_kv_cache will be set to False, which will result in large memory usage and time cost when input multiple images!!!")
offload_kv_cache = False
raise RuntimeError("OffloadedCache can only be used with a GPU")
super().__init__()

View File

@@ -0,0 +1,276 @@
from ..models import ModelManager
from ..models.wan_video_dit import WanModel
from ..models.wan_video_text_encoder import WanTextEncoder
from ..models.wan_video_vae import WanVideoVAE
from ..models.wan_video_image_encoder import WanImageEncoder
from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline
from ..prompters import WanPrompter
import torch, os
from einops import rearrange
import numpy as np
from PIL import Image
from tqdm import tqdm
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
from ..models.wan_video_dit import WanLayerNorm, WanRMSNorm
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
class WanVideoPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16, tokenizer_path=None):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
self.text_encoder: WanTextEncoder = None
self.image_encoder: WanImageEncoder = None
self.dit: WanModel = None
self.vae: WanVideoVAE = None
self.model_names = ['text_encoder', 'dit', 'vae']
self.height_division_factor = 16
self.width_division_factor = 16
def enable_vram_management(self, num_persistent_param_in_dit=None):
dtype = next(iter(self.text_encoder.parameters())).dtype
enable_vram_management(
self.text_encoder,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Embedding: AutoWrappedModule,
T5RelativeEmbedding: AutoWrappedModule,
T5LayerNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
dtype = next(iter(self.dit.parameters())).dtype
enable_vram_management(
self.dit,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv3d: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
WanLayerNorm: AutoWrappedModule,
WanRMSNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=self.device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
max_num_param=num_persistent_param_in_dit,
overflow_module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
dtype = next(iter(self.vae.parameters())).dtype
enable_vram_management(
self.vae,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv2d: AutoWrappedModule,
RMS_norm: AutoWrappedModule,
CausalConv3d: AutoWrappedModule,
Upsample: AutoWrappedModule,
torch.nn.SiLU: AutoWrappedModule,
torch.nn.Dropout: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=self.device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
if self.image_encoder is not None:
dtype = next(iter(self.image_encoder.parameters())).dtype
enable_vram_management(
self.image_encoder,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv2d: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
self.enable_cpu_offload()
def fetch_models(self, model_manager: ModelManager):
text_encoder_model_and_path = model_manager.fetch_model("wan_video_text_encoder", require_model_path=True)
if text_encoder_model_and_path is not None:
self.text_encoder, tokenizer_path = text_encoder_model_and_path
self.prompter.fetch_models(self.text_encoder)
self.prompter.fetch_tokenizer(os.path.join(os.path.dirname(tokenizer_path), "google/umt5-xxl"))
self.dit = model_manager.fetch_model("wan_video_dit")
self.vae = model_manager.fetch_model("wan_video_vae")
self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
@staticmethod
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None):
if device is None: device = model_manager.device
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
pipe.fetch_models(model_manager)
return pipe
def denoising_model(self):
return self.dit
def encode_prompt(self, prompt, positive=True):
prompt_emb = self.prompter.encode_prompt(prompt, positive=positive)
return {"context": prompt_emb}
def encode_image(self, image, num_frames, height, width):
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
image = self.preprocess_image(image.resize((width, height))).to(self.device)
clip_context = self.image_encoder.encode_image([image])
msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
msk[:, 1:] = 0
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
msk = msk.transpose(1, 2)[0]
y = self.vae.encode([torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)], device=self.device)[0]
y = torch.concat([msk, y])
return {"clip_fea": clip_context, "y": [y]}
def tensor2video(self, frames):
frames = rearrange(frames, "C T H W -> T H W C")
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
frames = [Image.fromarray(frame) for frame in frames]
return frames
def prepare_extra_input(self, latents=None):
return {"seq_len": latents.shape[2] * latents.shape[3] * latents.shape[4] // 4}
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return frames
@torch.no_grad()
def __call__(
self,
prompt,
negative_prompt="",
input_image=None,
input_video=None,
denoising_strength=1.0,
seed=None,
rand_device="cpu",
height=480,
width=832,
num_frames=81,
cfg_scale=5.0,
num_inference_steps=50,
sigma_shift=5.0,
tiled=True,
tile_size=(30, 52),
tile_stride=(15, 26),
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Parameter check
height, width = self.check_resize_height_width(height, width)
if num_frames % 4 != 1:
num_frames = (num_frames + 2) // 4 * 4 + 1
print(f"Only `num_frames % 4 != 1` is acceptable. We round it up to {num_frames}.")
# Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength, shift=sigma_shift)
# Initialize noise
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32).to(self.device)
if input_video is not None:
self.load_models_to_device(['vae'])
input_video = self.preprocess_images(input_video)
input_video = torch.stack(input_video, dim=2)
latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=noise.dtype, device=noise.device)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = noise
# Encode prompts
self.load_models_to_device(["text_encoder"])
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
if cfg_scale != 1.0:
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
# Encode image
if input_image is not None and self.image_encoder is not None:
self.load_models_to_device(["image_encoder", "vae"])
image_emb = self.encode_image(input_image, num_frames, height, width)
else:
image_emb = {}
# Extra input
extra_input = self.prepare_extra_input(latents)
# Denoise
self.load_models_to_device(["dit"])
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=torch.float32, device=self.device)
# Inference
noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input)
if cfg_scale != 1.0:
noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
# Scheduler
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
# Decode
self.load_models_to_device(['vae'])
frames = self.decode_video(latents, **tiler_kwargs)
self.load_models_to_device([])
frames = self.tensor2video(frames[0])
return frames

View File

@@ -9,4 +9,4 @@ from .omost import OmostPromter
from .cog_prompter import CogPrompter
from .hunyuan_video_prompter import HunyuanVideoPrompter
from .stepvideo_prompter import StepVideoPrompter
from .wanx_prompter import WanXPrompter
from .wan_prompter import WanPrompter

View File

@@ -1,11 +1,10 @@
from .base_prompter import BasePrompter
from ..models.wanx_text_encoder import WanXTextEncoder
from ..models.wan_video_text_encoder import WanTextEncoder
from transformers import AutoTokenizer
import os, torch
import ftfy
import html
import string
import regex as re
@@ -14,11 +13,13 @@ def basic_clean(text):
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r'\s+', ' ', text)
text = text.strip()
return text
def canonicalize(text, keep_punctuation_exact_string=None):
text = text.replace('_', ' ')
if keep_punctuation_exact_string:
@@ -31,6 +32,7 @@ def canonicalize(text, keep_punctuation_exact_string=None):
text = re.sub(r'\s+', ' ', text)
return text.strip()
class HuggingfaceTokenizer:
def __init__(self, name, seq_len=None, clean=None, **kwargs):
@@ -78,21 +80,25 @@ class HuggingfaceTokenizer:
text = canonicalize(basic_clean(text))
return text
class WanXPrompter(BasePrompter):
class WanPrompter(BasePrompter):
def __init__(self, tokenizer_path=None, text_len=512):
if tokenizer_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_path = os.path.join(
base_path, "tokenizer_configs/hunyuan_dit/tokenizer")
super().__init__()
self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean='whitespace')
self.text_len = text_len
self.text_encoder = None
self.fetch_tokenizer(tokenizer_path)
def fetch_models(self, text_encoder: WanXTextEncoder = None):
def fetch_tokenizer(self, tokenizer_path=None):
if tokenizer_path is not None:
self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=self.text_len, clean='whitespace')
def fetch_models(self, text_encoder: WanTextEncoder = None):
self.text_encoder = text_encoder
def encode_prompt(self, prompt, device="cuda"):
def encode_prompt(self, prompt, positive=True, device="cuda"):
prompt = self.process_prompt(prompt, positive=positive)
ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True)
ids = ids.to(device)
mask = mask.to(device)
@@ -100,4 +106,3 @@ class WanXPrompter(BasePrompter):
prompt_emb = self.text_encoder(ids, mask)
prompt_emb = [u[:v] for u, v in zip(prompt_emb, seq_lens)]
return prompt_emb

View File

@@ -15,7 +15,9 @@ class FlowMatchScheduler():
self.set_timesteps(num_inference_steps)
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False):
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None):
if shift is not None:
self.shift = shift
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
if self.extra_one_step:
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]

View File

@@ -250,6 +250,17 @@ def add_general_parsers(parser):
default=None,
help="Pretrained LoRA path. Required if the training is resumed.",
)
parser.add_argument(
"--use_swanlab",
default=False,
action="store_true",
help="Whether to use SwanLab logger.",
)
parser.add_argument(
"--swanlab_mode",
default=None,
help="SwanLab mode (cloud or local).",
)
return parser
@@ -269,8 +280,21 @@ def launch_training_task(model, args):
batch_size=args.batch_size,
num_workers=args.dataloader_num_workers
)
# train
if args.use_swanlab:
from swanlab.integration.pytorch_lightning import SwanLabLogger
swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
swanlab_config.update(vars(args))
swanlab_logger = SwanLabLogger(
project="diffsynth_studio",
name="diffsynth_studio",
config=swanlab_config,
mode=args.swanlab_mode,
logdir=args.output_path,
)
logger = [swanlab_logger]
else:
logger = None
trainer = pl.Trainer(
max_epochs=args.max_epochs,
accelerator="gpu",
@@ -279,7 +303,8 @@ def launch_training_task(model, args):
strategy=args.training_strategy,
default_root_dir=args.output_path,
accumulate_grad_batches=args.accumulate_grad_batches,
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)]
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
logger=logger,
)
trainer.fit(model=model, train_dataloaders=train_loader)

View File

@@ -16,14 +16,14 @@ The IP-Adapter model based on Stable Diffusion XL is more powerful. You have the
* Content controlling (original usage of IP-Adapter)
|First, we generate a rabbit.|Next, enable IP-Adapter and let the rabbit jump.|For comparision, disable IP-Adapter to see the generated image.|
|First, we generate a rabbit.|Next, enable IP-Adapter and let the rabbit jump.|For comparison, disable IP-Adapter to see the generated image.|
|-|-|-|
|![rabbit](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/4b452634-ec57-414f-897a-f8c50c74a650)|![rabbit_to_jumping_rabbit](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/b93c5495-0b77-4d97-bcd3-3942858288f2)|![rabbit_to_jumping_rabbit_without_ipa](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/52f37195-65b3-4a38-8d9b-73df37311c15)|
* Style controlling (InstantStyle)
|First, we generate a rabbit.|Next, enable InstantStyle and convert the rabbit to a cat.|For comparision, disable IP-Adapter to see the generated image.|
|First, we generate a rabbit.|Next, enable InstantStyle and convert the rabbit to a cat.|For comparison, disable IP-Adapter to see the generated image.|
|-|-|-|
|![rabbit](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/4b452634-ec57-414f-897a-f8c50c74a650)|![rabbit_to_cat](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/a006b281-f643-4ea9-b0da-712289c96059)|![rabbit_to_cat_without_ipa](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/189bd11e-7a10-4c09-8554-0eebde9150fd)|

View File

@@ -1,18 +0,0 @@
import torch
from diffsynth.prompters import WanXPrompter
from diffsynth.models.wanx_text_encoder import WanXTextEncoder
prompter = WanXPrompter('models/WanX/google/umt5-xxl')
text_encoder = WanXTextEncoder()
text_encoder.load_state_dict(torch.load('models/WanX/models_t5_umt5-xxl-enc-bf16.pth', map_location='cpu'))
text_encoder = text_encoder.eval().requires_grad_(False).to(dtype=torch.bfloat16, device='cuda')
prompter.fetch_models(text_encoder)
prompt = '维京战士双手挥舞着大斧,对抗猛犸象,黄昏,雪地中,漫天飞雪'
neg_prompt = '色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走'
prompt_emb = prompter.encode_prompt(prompt)
neg_prompt_emb = prompter.encode_prompt(neg_prompt)
print(prompt_emb[0]) # torch.Size([31, 4096])
print(neg_prompt_emb[0]) # torch.Size([126, 4096])

View File

@@ -1,46 +0,0 @@
import torch
import torchvision
import imageio
from diffsynth import ModelManager
def save_video(tensor,
save_file=None,
fps=30,
nrow=8,
normalize=True,
value_range=(-1, 1)):
tensor = tensor.clamp(min(value_range), max(value_range))
tensor = torch.stack([
torchvision.utils.make_grid(
u, nrow=nrow, normalize=normalize, value_range=value_range)
for u in tensor.unbind(2)
],
dim=1).permute(1, 2, 3, 0) #frame, h, w, 3
tensor = (tensor * 255).type(torch.uint8).cpu()
# write video
writer = imageio.get_writer(
save_file, fps=fps, codec='libx264', quality=8)
for frame in tensor.numpy():
writer.append_data(frame)
writer.close()
torch.cuda.memory._record_memory_history()
model_manager = ModelManager(torch_dtype=torch.float, device="cuda")
model_manager.load_models([
"models/WanX/vae.pth",
])
vae = model_manager.fetch_model('wanxvideo_vae')
latents = [torch.load('sample.pt')]
videos = vae.decode(latents, device=latents[0].device, tiled=True)
back_encode = vae.encode(videos, device=latents[0].device, tiled=True)
videos_back_encode = vae.decode(back_encode, device=latents[0].device, tiled=False)
torch.cuda.memory._dump_snapshot("my_snapshot.pickle")
save_video(videos[0][None], save_file='example.mp4', fps=16, nrow=1)
save_video(videos_back_encode[0][None], save_file='example_backencode.mp4', fps=16, nrow=1)

View File

@@ -0,0 +1,15 @@
# Image Quality Metric
The image quality assessment functionality has been integrated into Diffsynth. We support the following models:
* [ImageReward](https://github.com/THUDM/ImageReward)
* [Aesthetic](https://github.com/christophschuhmann/improved-aesthetic-predictor)
* [PickScore](https://github.com/yuvalkirstain/pickscore)
* [CLIP](https://github.com/openai/CLIP)
* [HPSv2](https://github.com/tgxs002/HPSv2)
* [HPSv2.1](https://github.com/tgxs002/HPSv2)
* [MPS](https://github.com/Kwai-Kolors/MPS)
## Usage
See [`./image_quality_evaluation.py`](./image_quality_evaluation.py) for more details.

View File

@@ -0,0 +1,23 @@
from diffsynth.extensions.ImageQualityMetric import download_preference_model, load_preference_model
from modelscope import dataset_snapshot_download
from PIL import Image
# Download example image
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
allow_file_pattern="data/examples/ImageQualityMetric/image.jpg",
local_dir="./"
)
# Parameters
prompt = "an orange cat"
image = Image.open("data/examples/ImageQualityMetric/image.jpg")
device = "cuda"
cache_dir = "./models"
# Run preference models
for model_name in ["ImageReward", "Aesthetic", "PickScore", "CLIP", "HPSv2", "HPSv2.1", "MPS"]:
path = download_preference_model(model_name, cache_dir=cache_dir)
preference_model = load_preference_model(model_name, device=device, path=path)
print(model_name, preference_model.score(image, prompt))

View File

@@ -45,7 +45,7 @@ file_name,text
04.jpg,a dog
```
Note that if the model is Chinese model (for example, Hunyuan-DiT and Kolors), we recommand to use Chinese texts in the dataset. For example
Note that if the model is Chinese model (for example, Hunyuan-DiT and Kolors), we recommend to use Chinese texts in the dataset. For example
```
file_name,text
@@ -526,7 +526,7 @@ models/stable_diffusion_xl
└── sd_xl_base_1.0.safetensors
```
We observed that Stable Diffusion XL is not float16-safe, thus we recommand users to use float32.
We observed that Stable Diffusion XL is not float16-safe, thus we recommend users to use float32.
```
CUDA_VISIBLE_DEVICES="0" python examples/train/stable_diffusion_xl/train_sdxl_lora.py \

View File

@@ -41,7 +41,7 @@ def parse_args():
type=str,
default=None,
required=True,
help="Path to pretrained models, seperated by comma. For example, SD3: `models/stable_diffusion_3/sd3_medium_incl_clips_t5xxlfp16.safetensors`, SD3.5-large: `models/stable_diffusion_3/text_encoders/clip_g.safetensors,models/stable_diffusion_3/text_encoders/clip_l.safetensors,models/stable_diffusion_3/text_encoders/t5xxl_fp16.safetensors,models/stable_diffusion_3/sd3.5_large.safetensors`",
help="Path to pretrained models, separated by comma. For example, SD3: `models/stable_diffusion_3/sd3_medium_incl_clips_t5xxlfp16.safetensors`, SD3.5-large: `models/stable_diffusion_3/text_encoders/clip_g.safetensors,models/stable_diffusion_3/text_encoders/clip_l.safetensors,models/stable_diffusion_3/text_encoders/t5xxl_fp16.safetensors,models/stable_diffusion_3/sd3.5_large.safetensors`",
)
parser.add_argument(
"--lora_target_modules",

209
examples/wanvideo/README.md Normal file
View File

@@ -0,0 +1,209 @@
# Wan-Video
Wan-Video is a collection of video synthesis models open-sourced by Alibaba.
Before using this model, please install DiffSynth-Studio from **source code**.
```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
Wan-Video supports multiple Attention implementations. If you have installed any of the following Attention implementations, they will be enabled based on priority.
* [Flash Attention 3](https://github.com/Dao-AILab/flash-attention)
* [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)
* [Sage Attention](https://github.com/thu-ml/SageAttention)
* [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default. `torch>=2.5.0` is recommended.)
## Inference
### Wan-Video-1.3B-T2V
Wan-Video-1.3B-T2V supports text-to-video and video-to-video. See [`./wan_1.3b_text_to_video.py`](./wan_1.3b_text_to_video.py).
Required VRAM: 6G
https://github.com/user-attachments/assets/124397be-cd6a-4f29-a87c-e4c695aaabb8
Put sunglasses on the dog.
https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb
### Wan-Video-14B-T2V
Wan-Video-14B-T2V is an enhanced version of Wan-Video-1.3B-T2V, offering greater size and power. To utilize this model, you need additional VRAM. We recommend that users adjust the `torch_dtype` and `num_persistent_param_in_dit` settings to find an optimal balance between speed and VRAM requirements. See [`./wan_14b_text_to_video.py`](./wan_14b_text_to_video.py).
We present a detailed table here. The model is tested on a single A100.
|`torch_dtype`|`num_persistent_param_in_dit`|Speed|Required VRAM|Default Setting|
|-|-|-|-|-|
|torch.bfloat16|None (unlimited)|18.5s/it|40G||
|torch.bfloat16|7*10**9 (7B)|20.8s/it|24G||
|torch.bfloat16|0|23.4s/it|10G||
|torch.float8_e4m3fn|None (unlimited)|18.3s/it|24G|yes|
|torch.float8_e4m3fn|0|24.0s/it|10G||
https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f
### Wan-Video-14B-I2V
Wan-Video-14B-I2V adds the functionality of image-to-video based on Wan-Video-14B-T2V. The model size remains the same, therefore the speed and VRAM requirements are also consistent. See [`./wan_14b_image_to_video.py`](./wan_14b_image_to_video.py).
**In the sample code, we use the same settings as the T2V 14B model, with FP8 quantization enabled by default. However, we found that this model is more sensitive to precision, so when the generated video content experiences issues such as artifacts, please switch to bfloat16 precision and use the `num_persistent_param_in_dit` parameter to control VRAM usage.**
![Image](https://github.com/user-attachments/assets/adf8047f-7943-4aaa-a555-2b32dc415f39)
https://github.com/user-attachments/assets/c0bdd5ca-292f-45ed-b9bc-afe193156e75
## Train
We support Wan-Video LoRA training and full training. Here is a tutorial. This is an experimental feature. Below is a video sample generated from the character Keqing LoRA:
https://github.com/user-attachments/assets/9bd8e30b-97e8-44f9-bb6f-da004ba376a9
Step 1: Install additional packages
```
pip install peft lightning pandas
```
Step 2: Prepare your dataset
You need to manage the training videos as follows:
```
data/example_dataset/
├── metadata.csv
└── train
├── video_00001.mp4
└── image_00002.jpg
```
`metadata.csv`:
```
file_name,text
video_00001.mp4,"video description"
image_00002.jpg,"video description"
```
We support both images and videos. An image is treated as a single frame of video.
Step 3: Data process
```shell
CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \
--task data_process \
--dataset_path data/example_dataset \
--output_path ./models \
--text_encoder_path "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth" \
--vae_path "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth" \
--tiled \
--num_frames 81 \
--height 480 \
--width 832
```
After that, some cached files will be stored in the dataset folder.
```
data/example_dataset/
├── metadata.csv
└── train
├── video_00001.mp4
├── video_00001.mp4.tensors.pth
├── video_00002.mp4
└── video_00002.mp4.tensors.pth
```
Step 4: Train
LoRA training:
```shell
CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \
--task train \
--train_architecture lora \
--dataset_path data/example_dataset \
--output_path ./models \
--dit_path "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors" \
--steps_per_epoch 500 \
--max_epochs 10 \
--learning_rate 1e-4 \
--lora_rank 16 \
--lora_alpha 16 \
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
--accumulate_grad_batches 1 \
--use_gradient_checkpointing
```
Full training:
```shell
CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \
--task train \
--train_architecture full \
--dataset_path data/example_dataset \
--output_path ./models \
--dit_path "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors" \
--steps_per_epoch 500 \
--max_epochs 10 \
--learning_rate 1e-4 \
--accumulate_grad_batches 1 \
--use_gradient_checkpointing
```
Step 5: Test
Test LoRA:
```python
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
model_manager.load_models([
"models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
"models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
"models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
])
model_manager.load_lora("models/lightning_logs/version_1/checkpoints/epoch=0-step=500.ckpt", lora_alpha=1.0)
pipe = WanVideoPipeline.from_model_manager(model_manager, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None)
video = pipe(
prompt="...",
negative_prompt="...",
num_inference_steps=50,
seed=0, tiled=True
)
save_video(video, "video.mp4", fps=30, quality=5)
```
Test fine-tuned base model:
```python
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
model_manager.load_models([
"models/lightning_logs/version_1/checkpoints/epoch=0-step=500.ckpt",
"models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
"models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
])
pipe = WanVideoPipeline.from_model_manager(model_manager, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None)
video = pipe(
prompt="...",
negative_prompt="...",
num_inference_steps=50,
seed=0, tiled=True
)
save_video(video, "video.mp4", fps=30, quality=5)
```

View File

@@ -0,0 +1,527 @@
import torch, os, imageio, argparse
from torchvision.transforms import v2
from einops import rearrange
import lightning as pl
import pandas as pd
from diffsynth import WanVideoPipeline, ModelManager, load_state_dict
from peft import LoraConfig, inject_adapter_in_model
import torchvision
from PIL import Image
class TextVideoDataset(torch.utils.data.Dataset):
def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832):
metadata = pd.read_csv(metadata_path)
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
self.text = metadata["text"].to_list()
self.max_num_frames = max_num_frames
self.frame_interval = frame_interval
self.num_frames = num_frames
self.height = height
self.width = width
self.frame_process = v2.Compose([
v2.CenterCrop(size=(height, width)),
v2.Resize(size=(height, width), antialias=True),
v2.ToTensor(),
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
def crop_and_resize(self, image):
width, height = image.size
scale = max(self.width / width, self.height / height)
image = torchvision.transforms.functional.resize(
image,
(round(height*scale), round(width*scale)),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
)
return image
def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process):
reader = imageio.get_reader(file_path)
if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval:
reader.close()
return None
frames = []
for frame_id in range(num_frames):
frame = reader.get_data(start_frame_id + frame_id * interval)
frame = Image.fromarray(frame)
frame = self.crop_and_resize(frame)
frame = frame_process(frame)
frames.append(frame)
reader.close()
frames = torch.stack(frames, dim=0)
frames = rearrange(frames, "T C H W -> C T H W")
return frames
def load_video(self, file_path):
start_frame_id = torch.randint(0, self.max_num_frames - (self.num_frames - 1) * self.frame_interval, (1,))[0]
frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, self.frame_interval, self.num_frames, self.frame_process)
return frames
def is_image(self, file_path):
file_ext_name = file_path.split(".")[-1]
if file_ext_name.lower() in ["jpg", "png", "webp"]:
return True
return False
def load_image(self, file_path):
frame = Image.open(file_path).convert("RGB")
frame = self.crop_and_resize(frame)
frame = self.frame_process(frame)
frame = rearrange(frame, "C H W -> C 1 H W")
return frame
def __getitem__(self, data_id):
text = self.text[data_id]
path = self.path[data_id]
if self.is_image(path):
video = self.load_image(path)
else:
video = self.load_video(path)
data = {"text": text, "video": video, "path": path}
return data
def __len__(self):
return len(self.path)
class LightningModelForDataProcess(pl.LightningModule):
def __init__(self, text_encoder_path, vae_path, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
super().__init__()
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
model_manager.load_models([text_encoder_path, vae_path])
self.pipe = WanVideoPipeline.from_model_manager(model_manager)
self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
def test_step(self, batch, batch_idx):
text, video, path = batch["text"][0], batch["video"], batch["path"][0]
self.pipe.device = self.device
if video is not None:
prompt_emb = self.pipe.encode_prompt(text)
latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0]
data = {"latents": latents, "prompt_emb": prompt_emb}
torch.save(data, path + ".tensors.pth")
class TensorDataset(torch.utils.data.Dataset):
def __init__(self, base_path, metadata_path, steps_per_epoch):
metadata = pd.read_csv(metadata_path)
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
print(len(self.path), "videos in metadata.")
self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")]
print(len(self.path), "tensors cached in metadata.")
assert len(self.path) > 0
self.steps_per_epoch = steps_per_epoch
def __getitem__(self, index):
data_id = torch.randint(0, len(self.path), (1,))[0]
data_id = (data_id + index) % len(self.path) # For fixed seed.
path = self.path[data_id]
data = torch.load(path, weights_only=True, map_location="cpu")
return data
def __len__(self):
return self.steps_per_epoch
class LightningModelForTrain(pl.LightningModule):
def __init__(self, dit_path, learning_rate=1e-5, lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", use_gradient_checkpointing=True, pretrained_lora_path=None):
super().__init__()
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
model_manager.load_models([dit_path])
self.pipe = WanVideoPipeline.from_model_manager(model_manager)
self.pipe.scheduler.set_timesteps(1000, training=True)
self.freeze_parameters()
if train_architecture == "lora":
self.add_lora_to_model(
self.pipe.denoising_model(),
lora_rank=lora_rank,
lora_alpha=lora_alpha,
lora_target_modules=lora_target_modules,
init_lora_weights=init_lora_weights,
pretrained_lora_path=pretrained_lora_path,
)
else:
self.pipe.denoising_model().requires_grad_(True)
self.learning_rate = learning_rate
self.use_gradient_checkpointing = use_gradient_checkpointing
def freeze_parameters(self):
# Freeze parameters
self.pipe.requires_grad_(False)
self.pipe.eval()
self.pipe.denoising_model().train()
def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None):
# Add LoRA to UNet
self.lora_alpha = lora_alpha
if init_lora_weights == "kaiming":
init_lora_weights = True
lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
init_lora_weights=init_lora_weights,
target_modules=lora_target_modules.split(","),
)
model = inject_adapter_in_model(lora_config, model)
for param in model.parameters():
# Upcast LoRA parameters into fp32
if param.requires_grad:
param.data = param.to(torch.float32)
# Lora pretrained lora weights
if pretrained_lora_path is not None:
state_dict = load_state_dict(pretrained_lora_path)
if state_dict_converter is not None:
state_dict = state_dict_converter(state_dict)
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
all_keys = [i for i, _ in model.named_parameters()]
num_updated_keys = len(all_keys) - len(missing_keys)
num_unexpected_keys = len(unexpected_keys)
print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
def training_step(self, batch, batch_idx):
# Data
latents = batch["latents"].to(self.device)
prompt_emb = batch["prompt_emb"]
prompt_emb["context"] = [prompt_emb["context"][0][0].to(self.device)]
# Loss
noise = torch.randn_like(latents)
timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
extra_input = self.pipe.prepare_extra_input(latents)
noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
# Compute loss
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
noise_pred = self.pipe.denoising_model()(
noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
use_gradient_checkpointing=self.use_gradient_checkpointing
)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * self.pipe.scheduler.training_weight(timestep)
# Record log
self.log("train_loss", loss, prog_bar=True)
return loss
def configure_optimizers(self):
trainable_modules = filter(lambda p: p.requires_grad, self.pipe.denoising_model().parameters())
optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
return optimizer
def on_save_checkpoint(self, checkpoint):
checkpoint.clear()
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.denoising_model().named_parameters()))
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
state_dict = self.pipe.denoising_model().state_dict()
lora_state_dict = {}
for name, param in state_dict.items():
if name in trainable_param_names:
lora_state_dict[name] = param
checkpoint.update(lora_state_dict)
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--task",
type=str,
default="data_process",
required=True,
choices=["data_process", "train"],
help="Task. `data_process` or `train`.",
)
parser.add_argument(
"--dataset_path",
type=str,
default=None,
required=True,
help="The path of the Dataset.",
)
parser.add_argument(
"--output_path",
type=str,
default="./",
help="Path to save the model.",
)
parser.add_argument(
"--text_encoder_path",
type=str,
default=None,
help="Path of text encoder.",
)
parser.add_argument(
"--vae_path",
type=str,
default=None,
help="Path of VAE.",
)
parser.add_argument(
"--dit_path",
type=str,
default=None,
help="Path of DiT.",
)
parser.add_argument(
"--tiled",
default=False,
action="store_true",
help="Whether enable tile encode in VAE. This option can reduce VRAM required.",
)
parser.add_argument(
"--tile_size_height",
type=int,
default=34,
help="Tile size (height) in VAE.",
)
parser.add_argument(
"--tile_size_width",
type=int,
default=34,
help="Tile size (width) in VAE.",
)
parser.add_argument(
"--tile_stride_height",
type=int,
default=18,
help="Tile stride (height) in VAE.",
)
parser.add_argument(
"--tile_stride_width",
type=int,
default=16,
help="Tile stride (width) in VAE.",
)
parser.add_argument(
"--steps_per_epoch",
type=int,
default=500,
help="Number of steps per epoch.",
)
parser.add_argument(
"--num_frames",
type=int,
default=81,
help="Number of frames.",
)
parser.add_argument(
"--height",
type=int,
default=480,
help="Image height.",
)
parser.add_argument(
"--width",
type=int,
default=832,
help="Image width.",
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=1,
help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-5,
help="Learning rate.",
)
parser.add_argument(
"--accumulate_grad_batches",
type=int,
default=1,
help="The number of batches in gradient accumulation.",
)
parser.add_argument(
"--max_epochs",
type=int,
default=1,
help="Number of epochs.",
)
parser.add_argument(
"--lora_target_modules",
type=str,
default="q,k,v,o,ffn.0,ffn.2",
help="Layers with LoRA modules.",
)
parser.add_argument(
"--init_lora_weights",
type=str,
default="kaiming",
choices=["gaussian", "kaiming"],
help="The initializing method of LoRA weight.",
)
parser.add_argument(
"--training_strategy",
type=str,
default="auto",
choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"],
help="Training strategy",
)
parser.add_argument(
"--lora_rank",
type=int,
default=4,
help="The dimension of the LoRA update matrices.",
)
parser.add_argument(
"--lora_alpha",
type=float,
default=4.0,
help="The weight of the LoRA update matrices.",
)
parser.add_argument(
"--use_gradient_checkpointing",
default=False,
action="store_true",
help="Whether to use gradient checkpointing.",
)
parser.add_argument(
"--train_architecture",
type=str,
default="lora",
choices=["lora", "full"],
help="Model structure to train. LoRA training or full training.",
)
parser.add_argument(
"--pretrained_lora_path",
type=str,
default=None,
help="Pretrained LoRA path. Required if the training is resumed.",
)
parser.add_argument(
"--use_swanlab",
default=False,
action="store_true",
help="Whether to use SwanLab logger.",
)
parser.add_argument(
"--swanlab_mode",
default=None,
help="SwanLab mode (cloud or local).",
)
args = parser.parse_args()
return args
def data_process(args):
dataset = TextVideoDataset(
args.dataset_path,
os.path.join(args.dataset_path, "metadata.csv"),
max_num_frames=args.num_frames,
frame_interval=1,
num_frames=args.num_frames,
height=args.height,
width=args.width
)
dataloader = torch.utils.data.DataLoader(
dataset,
shuffle=False,
batch_size=1,
num_workers=args.dataloader_num_workers
)
model = LightningModelForDataProcess(
text_encoder_path=args.text_encoder_path,
vae_path=args.vae_path,
tiled=args.tiled,
tile_size=(args.tile_size_height, args.tile_size_width),
tile_stride=(args.tile_stride_height, args.tile_stride_width),
)
trainer = pl.Trainer(
accelerator="gpu",
devices="auto",
default_root_dir=args.output_path,
)
trainer.test(model, dataloader)
def train(args):
dataset = TensorDataset(
args.dataset_path,
os.path.join(args.dataset_path, "metadata.csv"),
steps_per_epoch=args.steps_per_epoch,
)
dataloader = torch.utils.data.DataLoader(
dataset,
shuffle=True,
batch_size=1,
num_workers=args.dataloader_num_workers
)
model = LightningModelForTrain(
dit_path=args.dit_path,
learning_rate=args.learning_rate,
train_architecture=args.train_architecture,
lora_rank=args.lora_rank,
lora_alpha=args.lora_alpha,
lora_target_modules=args.lora_target_modules,
init_lora_weights=args.init_lora_weights,
use_gradient_checkpointing=args.use_gradient_checkpointing,
pretrained_lora_path=args.pretrained_lora_path,
)
if args.use_swanlab:
from swanlab.integration.pytorch_lightning import SwanLabLogger
swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
swanlab_config.update(vars(args))
swanlab_logger = SwanLabLogger(
project="wan",
name="wan",
config=swanlab_config,
mode=args.swanlab_mode,
logdir=args.output_path,
)
logger = [swanlab_logger]
else:
logger = None
trainer = pl.Trainer(
max_epochs=args.max_epochs,
accelerator="gpu",
devices="auto",
strategy=args.training_strategy,
default_root_dir=args.output_path,
accumulate_grad_batches=args.accumulate_grad_batches,
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
logger=logger,
)
trainer.fit(model, dataloader)
if __name__ == '__main__':
args = parse_args()
if args.task == "data_process":
data_process(args)
elif args.task == "train":
train(args)

View File

@@ -0,0 +1,40 @@
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
from modelscope import snapshot_download
# Download models
snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", local_dir="models/Wan-AI/Wan2.1-T2V-1.3B")
# Load models
model_manager = ModelManager(device="cpu")
model_manager.load_models(
[
"models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
"models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
"models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
],
torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None)
# Text-to-video
video = pipe(
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
num_inference_steps=50,
seed=0, tiled=True
)
save_video(video, "video1.mp4", fps=15, quality=5)
# Video-to-video
video = VideoData("video1.mp4", height=480, width=832)
video = pipe(
prompt="纪实摄影风格画面,一只活泼的小狗戴着黑色墨镜在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,戴着黑色墨镜,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
input_video=video, denoising_strength=0.7,
num_inference_steps=50,
seed=1, tiled=True
)
save_video(video, "video2.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,48 @@
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
from modelscope import snapshot_download, dataset_snapshot_download
from PIL import Image
# Download models
snapshot_download("Wan-AI/Wan2.1-I2V-14B-480P", local_dir="models/Wan-AI/Wan2.1-I2V-14B-480P")
# Load models
model_manager = ModelManager(device="cpu")
model_manager.load_models(
[
[
"models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00001-of-00007.safetensors",
"models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00002-of-00007.safetensors",
"models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00003-of-00007.safetensors",
"models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00004-of-00007.safetensors",
"models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00005-of-00007.safetensors",
"models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00006-of-00007.safetensors",
"models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00007-of-00007.safetensors",
],
"models/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
"models/Wan-AI/Wan2.1-I2V-14B-480P/models_t5_umt5-xxl-enc-bf16.pth",
"models/Wan-AI/Wan2.1-I2V-14B-480P/Wan2.1_VAE.pth",
],
torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
# Download example image
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=f"data/examples/wan/input_image.jpg"
)
image = Image.open("data/examples/wan/input_image.jpg")
# Image-to-video
video = pipe(
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
input_image=image,
num_inference_steps=50,
seed=0, tiled=True
)
save_video(video, "video.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,36 @@
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
from modelscope import snapshot_download
# Download models
snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B")
# Load models
model_manager = ModelManager(device="cpu")
model_manager.load_models(
[
[
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
],
"models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
"models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth",
],
torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
# Text-to-video
video = pipe(
prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
num_inference_steps=50,
seed=0, tiled=True
)
save_video(video, "video1.mp4", fps=25, quality=5)

View File

@@ -10,3 +10,4 @@ einops
sentencepiece
protobuf
modelscope
ftfy

View File

@@ -14,7 +14,7 @@ else:
setup(
name="diffsynth",
version="1.1.1",
version="1.1.2",
description="Enjoy the magic of Diffusion models!",
author="Artiprocher",
packages=find_packages(),