compatibility update

This commit is contained in:
Artiprocher
2023-12-23 20:13:41 +08:00
parent b30d0fa412
commit 66b3e995c2
27 changed files with 1051 additions and 398 deletions

View File

@@ -1,5 +1,5 @@
from transformers import CLIPTokenizer
from ..models import SDTextEncoder, SDXLTextEncoder, SDXLTextEncoder2, load_state_dict
from transformers import CLIPTokenizer, AutoTokenizer
from ..models import SDTextEncoder, SDXLTextEncoder, SDXLTextEncoder2
import torch, os
@@ -35,14 +35,30 @@ def tokenize_long_prompt(tokenizer, prompt):
return input_ids
def search_for_embeddings(state_dict):
embeddings = []
for k in state_dict:
if isinstance(state_dict[k], torch.Tensor):
embeddings.append(state_dict[k])
elif isinstance(state_dict[k], dict):
embeddings += search_for_embeddings(state_dict[k])
return embeddings
class BeautifulPrompt:
def __init__(self, tokenizer_path="configs/beautiful_prompt/tokenizer", model=None):
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
self.model = model
self.template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:'
def __call__(self, raw_prompt):
model_input = self.template.format(raw_prompt=raw_prompt)
input_ids = self.tokenizer.encode(model_input, return_tensors='pt').to(self.model.device)
outputs = self.model.generate(
input_ids,
max_new_tokens=384,
do_sample=True,
temperature=0.9,
top_k=50,
top_p=0.95,
repetition_penalty=1.1,
num_return_sequences=1
)
prompt = raw_prompt + ", " + self.tokenizer.batch_decode(
outputs[:, input_ids.size(1):],
skip_special_tokens=True
)[0].strip()
return prompt
class SDPrompter:
@@ -50,11 +66,20 @@ class SDPrompter:
# We use the tokenizer implemented by transformers
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
self.keyword_dict = {}
self.beautiful_prompt: BeautifulPrompt = None
def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda"):
def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True):
# Textual Inversion
for keyword in self.keyword_dict:
if keyword in prompt:
prompt = prompt.replace(keyword, self.keyword_dict[keyword])
# Beautiful Prompt
if positive and self.beautiful_prompt is not None:
prompt = self.beautiful_prompt(prompt)
print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
prompt_emb = text_encoder(input_ids, clip_skip=clip_skip)
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
@@ -70,6 +95,16 @@ class SDPrompter:
self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
self.tokenizer.add_tokens(additional_tokens)
def load_beautiful_prompt(self, model, model_path):
model_folder = os.path.dirname(model_path)
self.beautiful_prompt = BeautifulPrompt(tokenizer_path=model_folder, model=model)
if model_folder.endswith("v2"):
self.beautiful_prompt.template = """Converts a simple image description into a prompt. \
Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
but make sure there is a correlation between the input and output.\n\
### Input: {raw_prompt}\n### Output:"""
class SDXLPrompter:
def __init__(
@@ -80,6 +115,8 @@ class SDXLPrompter:
# We use the tokenizer implemented by transformers
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
self.keyword_dict = {}
self.beautiful_prompt: BeautifulPrompt = None
def encode_prompt(
self,
@@ -88,8 +125,19 @@ class SDXLPrompter:
prompt,
clip_skip=1,
clip_skip_2=2,
positive=True,
device="cuda"
):
# Textual Inversion
for keyword in self.keyword_dict:
if keyword in prompt:
prompt = prompt.replace(keyword, self.keyword_dict[keyword])
# Beautiful Prompt
if positive and self.beautiful_prompt is not None:
prompt = self.beautiful_prompt(prompt)
print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
# 1
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
prompt_emb_1 = text_encoder(input_ids, clip_skip=clip_skip)
@@ -105,3 +153,22 @@ class SDXLPrompter:
add_text_embeds = add_text_embeds[0:1]
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
return add_text_embeds, prompt_emb
def load_textual_inversion(self, textual_inversion_dict):
self.keyword_dict = {}
additional_tokens = []
for keyword in textual_inversion_dict:
tokens, _ = textual_inversion_dict[keyword]
additional_tokens += tokens
self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
self.tokenizer.add_tokens(additional_tokens)
def load_beautiful_prompt(self, model, model_path):
model_folder = os.path.dirname(model_path)
self.beautiful_prompt = BeautifulPrompt(tokenizer_path=model_folder, model=model)
if model_folder.endswith("v2"):
self.beautiful_prompt.template = """Converts a simple image description into a prompt. \
Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
but make sure there is a correlation between the input and output.\n\
### Input: {raw_prompt}\n### Output:"""