From c414f4cb129721b5dcaaa3c76b2d3faa036806f9 Mon Sep 17 00:00:00 2001
From: Artiprocher
Date: Mon, 30 Sep 2024 14:45:30 +0800
Subject: [PATCH] support t5 sequence length

---
 diffsynth/pipelines/flux_image.py    | 11 ++++++-----
 diffsynth/prompters/flux_prompter.py |  5 +++--
 2 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py
index d014a92..5cd57f1 100644
--- a/diffsynth/pipelines/flux_image.py
+++ b/diffsynth/pipelines/flux_image.py
@@ -58,9 +58,9 @@ class FluxImagePipeline(BasePipeline):
         return image
     
-    def encode_prompt(self, prompt, positive=True):
+    def encode_prompt(self, prompt, positive=True, t5_sequence_length=256):
         prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
-            prompt, device=self.device, positive=positive
+            prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length
         )
         return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
     
     
@@ -86,6 +86,7 @@ class FluxImagePipeline(BasePipeline):
         height=1024,
         width=1024,
         num_inference_steps=30,
+        t5_sequence_length=256,
         tiled=False,
         tile_size=128,
         tile_stride=64,
@@ -113,10 +114,10 @@ class FluxImagePipeline(BasePipeline):
         prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales)
 
         # Encode prompts
-        prompt_emb_posi = self.encode_prompt(prompt, positive=True)
+        prompt_emb_posi = self.encode_prompt(prompt, t5_sequence_length=t5_sequence_length)
         if cfg_scale != 1.0:
-            prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
-        prompt_emb_locals = [self.encode_prompt(prompt_local) for prompt_local in local_prompts]
+            prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length)
+        prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
 
         # Extra input
         extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
diff --git a/diffsynth/prompters/flux_prompter.py b/diffsynth/prompters/flux_prompter.py
index 656f43b..07d6dc7 100644
--- a/diffsynth/prompters/flux_prompter.py
+++ b/diffsynth/prompters/flux_prompter.py
@@ -56,7 +56,8 @@
         self,
         prompt,
         positive=True,
-        device="cuda"
+        device="cuda",
+        t5_sequence_length=256,
     ):
         prompt = self.process_prompt(prompt, positive=positive)
@@ -64,7 +65,7 @@
         pooled_prompt_emb = self.encode_prompt_using_clip(prompt, self.text_encoder_1, self.tokenizer_1, 77, device)
 
         # T5
-        prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder_2, self.tokenizer_2, 256, device)
+        prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder_2, self.tokenizer_2, t5_sequence_length, device)
 
         # text_ids
         text_ids = torch.zeros(prompt_emb.shape[0], prompt_emb.shape[1], 3).to(device=device, dtype=prompt_emb.dtype)