support SD3 LoRA

This commit is contained in:
Artiprocher
2024-07-10 10:07:02 +08:00
parent 8113f95278
commit 979a8814f1
13 changed files with 1030 additions and 32 deletions

View File

@@ -69,7 +69,7 @@ class SD3Prompter(Prompter):
# T5
if text_encoder_3 is None:
prompt_emb_3 = torch.zeros((1, 256, 4096), dtype=prompt_emb_1.dtype, device=device)
prompt_emb_3 = torch.zeros((prompt_emb_1.shape[0], 256, 4096), dtype=prompt_emb_1.dtype, device=device)
else:
prompt_emb_3 = self.encode_prompt_using_t5(pure_prompt, text_encoder_3, self.tokenizer_3, 256, device)
prompt_emb_3 = prompt_emb_3.to(prompt_emb_1.dtype) # float32 -> float16