mirror of https://github.com/modelscope/DiffSynth-Studio.git
support SD3 textual inversion
@@ -20,7 +20,7 @@ class SD3Prompter(Prompter):
         base_path = os.path.dirname(os.path.dirname(__file__))
         tokenizer_3_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_3/tokenizer_3")
         super().__init__()
-        self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
+        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_1_path)
         self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
         self.tokenizer_3 = T5TokenizerFast.from_pretrained(tokenizer_3_path)

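Note on the rename above: `self.tokenizer_1` becomes `self.tokenizer`, presumably so that shared textual-inversion handling in the `Prompter` base class can find the first CLIP tokenizer under a common attribute name. For orientation, a minimal sketch of how textual-inversion embeddings are commonly attached to a CLIP tokenizer and text encoder; the model name, placeholder tokens, and loading details are illustrative assumptions, not DiffSynth-Studio's implementation:

    # Illustrative sketch only: standard transformers calls, not DiffSynth-Studio API.
    import torch
    from transformers import CLIPTokenizer, CLIPTextModel

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    # One placeholder token per learned embedding vector.
    placeholder_tokens = ["<dog-toy-0>", "<dog-toy-1>"]
    tokenizer.add_tokens(placeholder_tokens)               # extend the vocabulary
    text_encoder.resize_token_embeddings(len(tokenizer))   # grow the embedding table

    # Copy the learned vectors (e.g. loaded from a checkpoint) into the new rows.
    learned_vectors = torch.randn(2, text_encoder.config.hidden_size)  # stand-in data
    token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)
    with torch.no_grad():
        text_encoder.get_input_embeddings().weight[token_ids] = learned_vectors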
@@ -61,17 +61,17 @@ class SD3Prompter(Prompter):
         positive=True,
         device="cuda"
     ):
-        prompt = self.process_prompt(prompt, positive=positive)
+        prompt, pure_prompt = self.process_prompt(prompt, positive=positive, require_pure_prompt=True)

         # CLIP
-        pooled_prompt_emb_1, prompt_emb_1 = self.encode_prompt_using_clip(prompt, text_encoder_1, self.tokenizer_1, 77, device)
-        pooled_prompt_emb_2, prompt_emb_2 = self.encode_prompt_using_clip(prompt, text_encoder_2, self.tokenizer_2, 77, device)
+        pooled_prompt_emb_1, prompt_emb_1 = self.encode_prompt_using_clip(prompt, text_encoder_1, self.tokenizer, 77, device)
+        pooled_prompt_emb_2, prompt_emb_2 = self.encode_prompt_using_clip(pure_prompt, text_encoder_2, self.tokenizer_2, 77, device)

         # T5
         if text_encoder_3 is None:
             prompt_emb_3 = torch.zeros((1, 256, 4096), dtype=prompt_emb_1.dtype, device=device)
         else:
-            prompt_emb_3 = self.encode_prompt_using_t5(prompt, text_encoder_3, self.tokenizer_3, 256, device)
+            prompt_emb_3 = self.encode_prompt_using_t5(pure_prompt, text_encoder_3, self.tokenizer_3, 256, device)
         prompt_emb_3 = prompt_emb_3.to(prompt_emb_1.dtype) # float32 -> float16

         # Merge
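Note on the hunk above: the prompt with textual-inversion triggers expanded (`prompt`) goes only to the first CLIP encoder, whose tokenizer knows the placeholder tokens, while the second CLIP encoder and T5 receive `pure_prompt` with the triggers stripped, since their vocabularies never learned those tokens and would split a raw trigger word into unrelated sub-tokens. When `text_encoder_3` is absent, a zero tensor of matching shape stands in for the T5 embedding so the downstream merge is unchanged. The `# Merge` step itself is elided in this hunk; for reference, a hedged sketch of the standard SD3 merge recipe (shapes follow the SD3 formulation and are assumptions about this repository's code):

    # Hedged sketch of a standard SD3-style merge; shapes are assumptions.
    import torch
    import torch.nn.functional as F

    prompt_emb_1 = torch.randn(1, 77, 768)     # CLIP-L hidden states
    prompt_emb_2 = torch.randn(1, 77, 1280)    # CLIP-G hidden states
    prompt_emb_3 = torch.randn(1, 256, 4096)   # T5-XXL hidden states

    # Concatenate the CLIP sequences on the channel axis, pad to T5 width,
    # then concatenate with T5 along the sequence axis.
    clip_emb = torch.cat([prompt_emb_1, prompt_emb_2], dim=-1)   # (1, 77, 2048)
    clip_emb = F.pad(clip_emb, (0, 4096 - clip_emb.shape[-1]))   # (1, 77, 4096)
    prompt_emb = torch.cat([clip_emb, prompt_emb_3], dim=-2)     # (1, 333, 4096)

    # The pooled CLIP outputs form a single conditioning vector.
    pooled = torch.cat([torch.randn(1, 768), torch.randn(1, 1280)], dim=-1)  # (1, 2048)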
@@ -111,14 +111,27 @@ but make sure there is a correlation between the input and output.\n\
         if "beautiful_prompt" in model_manager.model:
             self.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])

-    def process_prompt(self, prompt, positive=True):
+    def add_textual_inversion_tokens(self, prompt):
+        for keyword in self.keyword_dict:
+            if keyword in prompt:
+                prompt = prompt.replace(keyword, self.keyword_dict[keyword])
+        return prompt
+
+    def del_textual_inversion_tokens(self, prompt):
+        for keyword in self.keyword_dict:
+            if keyword in prompt:
+                prompt = prompt.replace(keyword, "")
+        return prompt
+
+    def process_prompt(self, prompt, positive=True, require_pure_prompt=False):
+        prompt, pure_prompt = self.add_textual_inversion_tokens(prompt), self.del_textual_inversion_tokens(prompt)
         if positive and self.translator is not None:
             prompt = self.translator(prompt)
             print(f"Your prompt is translated: \"{prompt}\"")
         if positive and self.beautiful_prompt is not None:
             prompt = self.beautiful_prompt(prompt)
             print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
-        return prompt
+        if require_pure_prompt:
+            return prompt, pure_prompt
+        else:
+            return prompt

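Note on the new helpers: `add_textual_inversion_tokens` rewrites each trigger keyword into its placeholder-token expansion via `self.keyword_dict`, while `del_textual_inversion_tokens` strips the keyword entirely; `process_prompt` computes both variants up front and returns the pair only when `require_pure_prompt=True`, so existing callers that expect a single string keep working. A standalone usage sketch (the example keyword and its expansion are hypothetical, and how `keyword_dict` is populated lies outside this diff):

    # Standalone illustration; "<dog-toy>" and its expansion are made up.
    keyword_dict = {"<dog-toy>": "<dog-toy-0> <dog-toy-1>"}
    prompt = "a photo of <dog-toy> on the beach"

    expanded = prompt
    stripped = prompt
    for keyword, expansion in keyword_dict.items():
        expanded = expanded.replace(keyword, expansion)  # for the encoder that knows the tokens
        stripped = stripped.replace(keyword, "")         # for encoders that never learned them

    print(expanded)  # a photo of <dog-toy-0> <dog-toy-1> on the beach
    print(stripped)  # a photo of  on the beach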