Support configurable T5 sequence length (replace hard-coded 256 with a `t5_sequence_length` parameter)

This commit is contained in:
Artiprocher
2024-09-30 14:45:30 +08:00
parent d91c603875
commit c414f4cb12
2 changed files with 9 additions and 7 deletions

View File

@@ -56,7 +56,8 @@ class FluxPrompter(BasePrompter):
self,
prompt,
positive=True,
device="cuda"
device="cuda",
t5_sequence_length=256,
):
prompt = self.process_prompt(prompt, positive=positive)
@@ -64,7 +65,7 @@ class FluxPrompter(BasePrompter):
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, self.text_encoder_1, self.tokenizer_1, 77, device)
# T5
prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder_2, self.tokenizer_2, 256, device)
prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder_2, self.tokenizer_2, t5_sequence_length, device)
# text_ids
text_ids = torch.zeros(prompt_emb.shape[0], prompt_emb.shape[1], 3).to(device=device, dtype=prompt_emb.dtype)