# DiffSynth-Studio/diffsynth/prompts/sdxl_tokenizer.py
from transformers import CLIPTokenizer
from .sd_tokenizer import SDTokenizer
class SDXLTokenizer(SDTokenizer):
    """First text tokenizer for Stable Diffusion XL.

    SDXL uses two CLIP text encoders; this tokenizer feeds the first one
    and behaves identically to the base Stable Diffusion tokenizer, hence
    the bare subclass.
    """
    # NOTE(review): the original defined __init__ solely to call
    # super().__init__() with no arguments; that override was redundant,
    # so it has been removed — the inherited constructor is identical.
class SDXLTokenizer2:
    """Second text tokenizer for Stable Diffusion XL (feeds the second
    CLIP text encoder).

    Wraps a Hugging Face ``CLIPTokenizer`` and supports prompts longer
    than the tokenizer's ``model_max_length`` by padding the token
    sequence up to the next multiple of that length and reshaping it
    into several fixed-length "sentences".
    """

    def __init__(self, tokenizer_path="configs/stable_diffusion_xl/tokenizer_2"):
        # We use the tokenizer implemented by transformers
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)

    def __call__(self, prompt):
        """Tokenize ``prompt`` into a tensor of shape (num_sentence, length).

        ``length`` is the tokenizer's ``model_max_length``;
        ``num_sentence`` is the smallest number of fixed-length chunks
        needed to hold the entire (untruncated) prompt.
        """
        length = self.tokenizer.model_max_length
        # Temporarily lift the length limit so the first pass can measure
        # the full token count without triggering a truncation warning.
        self.tokenizer.model_max_length = 99999999
        try:
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
        finally:
            # Restore the limit even if tokenization raises, so the shared
            # tokenizer is never left with the inflated max length.
            # (The original only restored it on the success path.)
            self.tokenizer.model_max_length = length
        # Round the real token count up to the next multiple of `length`.
        max_length = (input_ids.shape[1] + length - 1) // length * length
        # Tokenize again, padded to the fixed total length.
        input_ids = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding="max_length",
            max_length=max_length,
            truncation=True,
        ).input_ids
        # Reshape into (num_sentence, length) rows for the text encoder.
        num_sentence = input_ids.shape[1] // length
        return input_ids.reshape((num_sentence, length))