Files
DiffSynth-Studio/diffsynth/prompts/__init__.py
2023-12-21 17:11:58 +08:00

108 lines
3.8 KiB
Python

from transformers import CLIPTokenizer
from ..models import SDTextEncoder, SDXLTextEncoder, SDXLTextEncoder2, load_state_dict
import torch, os
def tokenize_long_prompt(tokenizer, prompt):
    """Tokenize ``prompt`` into rows of ``tokenizer.model_max_length`` ids.

    The prompt is first tokenized without a length limit to measure its real
    length. That length is rounded up to the next multiple of
    ``model_max_length``. The prompt is then tokenized again, padded to that
    length, and the single resulting row is reshaped into one row per chunk.

    Args:
        tokenizer: A transformers-style tokenizer (e.g. ``CLIPTokenizer``) with
            a writable ``model_max_length``; calling it returns an object with
            an ``input_ids`` attribute.
        prompt: The text to tokenize.

    Returns:
        ``input_ids`` reshaped to ``(num_chunks, model_max_length)``. Chunks
        after the first are plain slices of the token stream; no special
        tokens are re-inserted at chunk boundaries.
    """
    length = tokenizer.model_max_length
    # Temporarily lift the limit so the tokenizer does not warn about prompts
    # longer than one chunk while we measure the real length. Restore it in
    # `finally` so an exception cannot leave the shared tokenizer with the
    # inflated limit.
    tokenizer.model_max_length = 99999999
    try:
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    finally:
        tokenizer.model_max_length = length
    # Round the real token count up to a whole number of chunks.
    max_length = (input_ids.shape[1] + length - 1) // length * length
    # Tokenize again, padded (and truncated) to exactly that chunk-aligned length.
    input_ids = tokenizer(
        prompt,
        return_tensors="pt",
        padding="max_length",
        max_length=max_length,
        truncation=True,
    ).input_ids
    # Split the single padded row into one row per chunk for the text encoder.
    num_sentence = input_ids.shape[1] // length
    return input_ids.reshape((num_sentence, length))
def search_for_embeddings(state_dict):
    """Return every ``torch.Tensor`` stored in a possibly nested dict.

    Values are visited in insertion order. A value that is itself a ``dict``
    is searched recursively, and its tensors are placed at that position in
    the result. Values that are neither tensors nor dicts are skipped.
    """
    def _walk(node):
        # Depth-first generator over the tensors reachable from `node`.
        for value in node.values():
            if isinstance(value, torch.Tensor):
                yield value
            elif isinstance(value, dict):
                yield from _walk(value)

    return list(_walk(state_dict))
class SDPrompter:
    """Turns text prompts into Stable Diffusion text-encoder embeddings.

    After ``load_textual_inversion`` has been called, each registered keyword
    found in a prompt is expanded into its textual-inversion tokens before
    tokenization.
    """

    def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"):
        # We use the tokenizer implemented by transformers.
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
        # Maps each textual-inversion keyword to a space-padded string of its tokens.
        self.keyword_dict = {}

    def _expand_keywords(self, prompt):
        """Replace every registered keyword in ``prompt`` with its token string.

        All keywords are substituted in a single left-to-right pass, and
        longer keywords are tried first at each position. This means a
        keyword that is a substring of another keyword, or of text inserted
        by an earlier replacement, is never substituted a second time.
        (Applying ``str.replace`` once per keyword, in sequence, can do that.)
        """
        if not self.keyword_dict:
            return prompt
        import re
        keywords = sorted(self.keyword_dict, key=len, reverse=True)
        pattern = "|".join(re.escape(keyword) for keyword in keywords)
        # A callable replacement inserts the token string literally
        # (no backslash/group processing), matching str.replace semantics.
        return re.sub(pattern, lambda match: self.keyword_dict[match.group(0)], prompt)

    def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda"):
        """Encode ``prompt`` with ``text_encoder``.

        The prompt is split into fixed-length chunks by
        ``tokenize_long_prompt``. The per-chunk encoder outputs are then
        concatenated along the token axis.

        Returns:
            The encoder output reshaped to
            ``(1, out.shape[0] * out.shape[1], -1)``.
        """
        prompt = self._expand_keywords(prompt)
        input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
        prompt_emb = text_encoder(input_ids, clip_skip=clip_skip)
        # Merge the chunk axis into the token axis and add a leading axis of size 1.
        prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0] * prompt_emb.shape[1], -1))
        return prompt_emb

    def load_textual_inversion(self, textual_inversion_dict):
        """Register textual-inversion keywords and add their tokens to the tokenizer.

        Args:
            textual_inversion_dict: Maps each keyword to a pair whose first
                element is the sequence of token strings for that keyword.
                The second element is not used here.

        Any previously registered keywords are replaced. The tokenizer is not
        reset, so tokens added by earlier calls are not removed from it.
        """
        self.keyword_dict = {}
        additional_tokens = []
        for keyword, (tokens, _) in textual_inversion_dict.items():
            additional_tokens += tokens
            # Pad with spaces so the tokens are separated from neighbouring text.
            self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
        self.tokenizer.add_tokens(additional_tokens)
class SDXLPrompter:
    """Encodes prompts for SDXL with its two CLIP text encoders."""

    def __init__(
        self,
        tokenizer_path="configs/stable_diffusion/tokenizer",
        tokenizer_2_path="configs/stable_diffusion_xl/tokenizer_2"
    ):
        # Both tokenizers are transformers' CLIPTokenizer, one per text encoder.
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
        self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)

    def encode_prompt(
        self,
        text_encoder: SDXLTextEncoder,
        text_encoder_2: SDXLTextEncoder2,
        prompt,
        clip_skip=1,
        clip_skip_2=2,
        device="cuda"
    ):
        """Run ``prompt`` through both text encoders.

        Each encoder gets its own chunked tokenization of the prompt (see
        ``tokenize_long_prompt``). The two hidden-state outputs are joined on
        the last axis. The chunk and token axes are then flattened into one
        token axis, with a leading axis of size 1.

        Returns:
            ``(add_text_embeds, prompt_emb)``:
            - ``add_text_embeds`` is the first row (``[0:1]``) of the second
              encoder's first output. For prompts longer than one chunk, only
              the first chunk therefore contributes.
            - ``prompt_emb`` is the merged hidden states, reshaped to
              ``(1, shape[0] * shape[1], -1)``.
        """
        ids_first = tokenize_long_prompt(self.tokenizer, prompt).to(device)
        hidden_first = text_encoder(ids_first, clip_skip=clip_skip)

        ids_second = tokenize_long_prompt(self.tokenizer_2, prompt).to(device)
        pooled, hidden_second = text_encoder_2(ids_second, clip_skip=clip_skip_2)

        # Feature-wise merge; assumes both tokenizations yield the same
        # number of chunks (not checked here).
        merged = torch.cat([hidden_first, hidden_second], dim=-1)
        flat_tokens = merged.shape[0] * merged.shape[1]
        merged = merged.reshape((1, flat_tokens, -1))
        return pooled[0:1], merged