support SD3 LoRA

This commit is contained in:
Artiprocher
2024-07-10 10:07:02 +08:00
parent 8113f95278
commit 979a8814f1
13 changed files with 1030 additions and 32 deletions

View File

@@ -69,7 +69,7 @@ class SD3Prompter(Prompter):
# T5
if text_encoder_3 is None:
prompt_emb_3 = torch.zeros((1, 256, 4096), dtype=prompt_emb_1.dtype, device=device)
prompt_emb_3 = torch.zeros((prompt_emb_1.shape[0], 256, 4096), dtype=prompt_emb_1.dtype, device=device)
else:
prompt_emb_3 = self.encode_prompt_using_t5(pure_prompt, text_encoder_3, self.tokenizer_3, 256, device)
prompt_emb_3 = prompt_emb_3.to(prompt_emb_1.dtype) # float32 -> float16

View File

@@ -124,6 +124,13 @@ but make sure there is a correlation between the input and output.\n\
return prompt
def process_prompt(self, prompt, positive=True, require_pure_prompt=False):
if isinstance(prompt, list):
prompt = [self.process_prompt(prompt_, positive=positive, require_pure_prompt=require_pure_prompt) for prompt_ in prompt]
if require_pure_prompt:
prompt, pure_prompt = [i[0] for i in prompt], [i[1] for i in prompt]
return prompt, pure_prompt
else:
return prompt
prompt, pure_prompt = self.add_textual_inversion_tokens(prompt), self.del_textual_inversion_tokens(prompt)
if positive and self.translator is not None:
prompt = self.translator(prompt)