diff --git a/diffsynth/pipelines/sd3_image.py b/diffsynth/pipelines/sd3_image.py index e6e72e9..9a4109e 100644 --- a/diffsynth/pipelines/sd3_image.py +++ b/diffsynth/pipelines/sd3_image.py @@ -59,9 +59,9 @@ class SD3ImagePipeline(BasePipeline): return image - def encode_prompt(self, prompt, positive=True): + def encode_prompt(self, prompt, positive=True, t5_sequence_length=77): prompt_emb, pooled_prompt_emb = self.prompter.encode_prompt( - prompt, device=self.device, positive=positive + prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length ) return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb} @@ -84,6 +84,7 @@ class SD3ImagePipeline(BasePipeline): height=1024, width=1024, num_inference_steps=20, + t5_sequence_length=77, tiled=False, tile_size=128, tile_stride=64, @@ -109,9 +110,9 @@ class SD3ImagePipeline(BasePipeline): # Encode prompts self.load_models_to_device(['text_encoder_1', 'text_encoder_2', 'text_encoder_3']) - prompt_emb_posi = self.encode_prompt(prompt, positive=True) - prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False) - prompt_emb_locals = [self.encode_prompt(prompt_local) for prompt_local in local_prompts] + prompt_emb_posi = self.encode_prompt(prompt, positive=True, t5_sequence_length=t5_sequence_length) + prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) + prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts] # Denoise self.load_models_to_device(['dit']) diff --git a/diffsynth/prompters/sd3_prompter.py b/diffsynth/prompters/sd3_prompter.py index 73e05b7..ecf9bca 100644 --- a/diffsynth/prompters/sd3_prompter.py +++ b/diffsynth/prompters/sd3_prompter.py @@ -67,7 +67,8 @@ class SD3Prompter(BasePrompter): self, prompt, positive=True, - device="cuda" + device="cuda", + t5_sequence_length=77, ): prompt = self.process_prompt(prompt, positive=positive) @@ -77,9 +78,9 @@ class SD3Prompter(BasePrompter): # T5 if self.text_encoder_3 is None: - prompt_emb_3 = torch.zeros((prompt_emb_1.shape[0], 256, 4096), dtype=prompt_emb_1.dtype, device=device) + prompt_emb_3 = torch.zeros((prompt_emb_1.shape[0], t5_sequence_length, 4096), dtype=prompt_emb_1.dtype, device=device) else: - prompt_emb_3 = self.encode_prompt_using_t5(prompt, self.text_encoder_3, self.tokenizer_3, 256, device) + prompt_emb_3 = self.encode_prompt_using_t5(prompt, self.text_encoder_3, self.tokenizer_3, t5_sequence_length, device) prompt_emb_3 = prompt_emb_3.to(prompt_emb_1.dtype) # float32 -> float16 # Merge