mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 08:40:47 +00:00
add new quality metric
This commit is contained in:
@@ -13,8 +13,16 @@ from transformers import BertTokenizer
|
||||
from .vit import VisionTransformer, interpolate_pos_embed
|
||||
|
||||
|
||||
def default_bert():
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
project_root = os.path.abspath(os.path.join(current_dir, '../../../../'))
|
||||
model_path = os.path.join(project_root, 'models', 'QualityMetric')
|
||||
return os.path.join(model_path, "bert-base-uncased")
|
||||
|
||||
bert_model_path = default_bert()
|
||||
|
||||
def init_tokenizer():
|
||||
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||
tokenizer = BertTokenizer.from_pretrained(bert_model_path)
|
||||
tokenizer.add_special_tokens({'bos_token':'[DEC]'})
|
||||
tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
|
||||
tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
|
||||
|
||||
Reference in New Issue
Block a user