support SD3 textual inversion

This commit is contained in:
Artiprocher
2024-07-05 13:36:54 +08:00
parent 9920b8d975
commit 518c6d6ac3
5 changed files with 70 additions and 13 deletions

View File

@@ -20,7 +20,7 @@ class SD3Prompter(Prompter):
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_3_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_3/tokenizer_3")
super().__init__()
self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_1_path)
self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
self.tokenizer_3 = T5TokenizerFast.from_pretrained(tokenizer_3_path)
@@ -61,17 +61,17 @@ class SD3Prompter(Prompter):
positive=True,
device="cuda"
):
prompt = self.process_prompt(prompt, positive=positive)
prompt, pure_prompt = self.process_prompt(prompt, positive=positive, require_pure_prompt=True)
# CLIP
pooled_prompt_emb_1, prompt_emb_1 = self.encode_prompt_using_clip(prompt, text_encoder_1, self.tokenizer_1, 77, device)
pooled_prompt_emb_2, prompt_emb_2 = self.encode_prompt_using_clip(prompt, text_encoder_2, self.tokenizer_2, 77, device)
pooled_prompt_emb_1, prompt_emb_1 = self.encode_prompt_using_clip(prompt, text_encoder_1, self.tokenizer, 77, device)
pooled_prompt_emb_2, prompt_emb_2 = self.encode_prompt_using_clip(pure_prompt, text_encoder_2, self.tokenizer_2, 77, device)
# T5
if text_encoder_3 is None:
prompt_emb_3 = torch.zeros((1, 256, 4096), dtype=prompt_emb_1.dtype, device=device)
else:
prompt_emb_3 = self.encode_prompt_using_t5(prompt, text_encoder_3, self.tokenizer_3, 256, device)
prompt_emb_3 = self.encode_prompt_using_t5(pure_prompt, text_encoder_3, self.tokenizer_3, 256, device)
prompt_emb_3 = prompt_emb_3.to(prompt_emb_1.dtype) # float32 -> float16
# Merge

View File

@@ -111,14 +111,27 @@ but make sure there is a correlation between the input and output.\n\
if "beautiful_prompt" in model_manager.model:
self.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
def process_prompt(self, prompt, positive=True):
def add_textual_inversion_tokens(self, prompt):
    """Expand every registered textual-inversion keyword found in *prompt*.

    ``self.keyword_dict`` maps a keyword to the string that should replace it.
    All keywords are substituted in a single left-to-right pass, preferring
    the longest keyword at each position, so that:

    * a keyword that is a substring of another keyword (``"cat"`` vs
      ``"cat2"``) cannot clobber the longer one, regardless of dict order;
    * text produced by one substitution is never re-scanned and rewritten
      by another keyword.

    Args:
        prompt: The prompt text to process.

    Returns:
        The prompt with each keyword occurrence replaced by its mapped value.
        The prompt is returned unchanged when no keyword is registered.
    """
    # Local import keeps this method self-contained.
    import re

    # An empty keyword would match between every character; ignore it.
    keywords = sorted((k for k in self.keyword_dict if k), key=len, reverse=True)
    if not keywords:
        return prompt
    # Regex alternation tries alternatives in order, so longest-first ordering
    # yields the longest keyword match at each position.
    pattern = re.compile("|".join(re.escape(k) for k in keywords))
    return pattern.sub(lambda match: self.keyword_dict[match.group(0)], prompt)
def del_textual_inversion_tokens(self, prompt):
    """Strip every registered textual-inversion keyword from *prompt*.

    Keywords are the keys of ``self.keyword_dict``. They are removed in a
    single left-to-right pass, preferring the longest keyword at each
    position, so a keyword that is a substring of another (``"cat"`` vs
    ``"cat2"``) cannot leave a fragment of the longer one behind.

    Args:
        prompt: The prompt text to process.

    Returns:
        The prompt with all keyword occurrences removed. Surrounding
        whitespace is left untouched.
    """
    # Local import keeps this method self-contained.
    import re

    # Skip empty keywords: they carry nothing to delete.
    keywords = sorted((k for k in self.keyword_dict if k), key=len, reverse=True)
    if not keywords:
        return prompt
    # Longest-first alternation gives the longest keyword match per position.
    return re.sub("|".join(re.escape(k) for k in keywords), "", prompt)
def process_prompt(self, prompt, positive=True, require_pure_prompt=False):
    """Prepare a user prompt for encoding.

    The prompt first has its textual-inversion keywords expanded. For
    positive prompts it is then passed through ``self.translator`` and
    ``self.beautiful_prompt``, in that order, whenever each is set.

    Args:
        prompt: The raw prompt text.
        positive: Translation and refinement are applied only when this is
            true.
        require_pure_prompt: When true, also return the "pure" prompt, i.e.
            the raw prompt with textual-inversion keywords removed.

    Returns:
        The processed prompt, or a ``(prompt, pure_prompt)`` tuple when
        ``require_pure_prompt`` is true. The pure prompt is derived from the
        raw input and is neither translated nor refined.
    """
    raw_prompt = prompt
    # Both variants are derived from the untouched input, before any
    # translation or refinement.
    prompt = self.add_textual_inversion_tokens(raw_prompt)
    pure_prompt = self.del_textual_inversion_tokens(raw_prompt)
    if positive and self.translator is not None:
        prompt = self.translator(prompt)
        print(f"Your prompt is translated: \"{prompt}\"")
    if positive and self.beautiful_prompt is not None:
        prompt = self.beautiful_prompt(prompt)
        print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
    # No unconditional return before this point: the pure prompt must stay
    # reachable for callers that request it.
    if require_pure_prompt:
        return prompt, pure_prompt
    return prompt