support SD3 textual inversion

This commit is contained in:
Artiprocher
2024-07-05 13:36:54 +08:00
parent 9920b8d975
commit 518c6d6ac3
5 changed files with 70 additions and 13 deletions

View File

@@ -20,7 +20,7 @@ class SD3Prompter(Prompter):
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_3_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_3/tokenizer_3")
super().__init__()
self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_1_path)
self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
self.tokenizer_3 = T5TokenizerFast.from_pretrained(tokenizer_3_path)
@@ -61,17 +61,17 @@ class SD3Prompter(Prompter):
positive=True,
device="cuda"
):
prompt = self.process_prompt(prompt, positive=positive)
prompt, pure_prompt = self.process_prompt(prompt, positive=positive, require_pure_prompt=True)
# CLIP
pooled_prompt_emb_1, prompt_emb_1 = self.encode_prompt_using_clip(prompt, text_encoder_1, self.tokenizer_1, 77, device)
pooled_prompt_emb_2, prompt_emb_2 = self.encode_prompt_using_clip(prompt, text_encoder_2, self.tokenizer_2, 77, device)
pooled_prompt_emb_1, prompt_emb_1 = self.encode_prompt_using_clip(prompt, text_encoder_1, self.tokenizer, 77, device)
pooled_prompt_emb_2, prompt_emb_2 = self.encode_prompt_using_clip(pure_prompt, text_encoder_2, self.tokenizer_2, 77, device)
# T5
if text_encoder_3 is None:
prompt_emb_3 = torch.zeros((1, 256, 4096), dtype=prompt_emb_1.dtype, device=device)
else:
prompt_emb_3 = self.encode_prompt_using_t5(prompt, text_encoder_3, self.tokenizer_3, 256, device)
prompt_emb_3 = self.encode_prompt_using_t5(pure_prompt, text_encoder_3, self.tokenizer_3, 256, device)
prompt_emb_3 = prompt_emb_3.to(prompt_emb_1.dtype) # float32 -> float16
# Merge