Merge pull request #229 from modelscope/flux-enhance

support t5 sequence length
Zhongjie Duan
2024-09-30 15:33:51 +08:00
committed by GitHub
2 changed files with 9 additions and 7 deletions
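
This change threads a `t5_sequence_length` keyword from `FluxImagePipeline.__call__` down through the prompter to the T5 text encoder, so prompts are no longer hard-limited to 256 T5 tokens. A minimal usage sketch follows; the pipeline construction mirrors DiffSynth-Studio's usual `ModelManager` pattern and is assumed here (model paths elided), while the `t5_sequence_length` keyword itself comes from this PR:

```python
# Usage sketch. Pipeline construction is assumed, not part of this diff;
# only the t5_sequence_length kwarg is introduced by this PR.
import torch
from diffsynth import ModelManager, FluxImagePipeline

model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
model_manager.load_models([...])  # FLUX checkpoint paths go here (elided)
pipe = FluxImagePipeline.from_model_manager(model_manager)

image = pipe(
    prompt="a very long, detailed prompt that would overflow 256 T5 tokens ...",
    t5_sequence_length=512,  # new in this PR; defaults to 256, the old fixed value
    num_inference_steps=30,
    height=1024,
    width=1024,
)
image.save("image.jpg")
```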


@@ -58,9 +58,9 @@ class FluxImagePipeline(BasePipeline):
         return image
 
-    def encode_prompt(self, prompt, positive=True):
+    def encode_prompt(self, prompt, positive=True, t5_sequence_length=256):
         prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
-            prompt, device=self.device, positive=positive
+            prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length
         )
         return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
@@ -86,6 +86,7 @@ class FluxImagePipeline(BasePipeline):
         height=1024,
         width=1024,
         num_inference_steps=30,
+        t5_sequence_length=256,
         tiled=False,
         tile_size=128,
         tile_stride=64,
@@ -113,10 +114,10 @@ class FluxImagePipeline(BasePipeline):
         prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales)
 
         # Encode prompts
-        prompt_emb_posi = self.encode_prompt(prompt, positive=True)
+        prompt_emb_posi = self.encode_prompt(prompt, t5_sequence_length=t5_sequence_length)
         if cfg_scale != 1.0:
-            prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
-        prompt_emb_locals = [self.encode_prompt(prompt_local) for prompt_local in local_prompts]
+            prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length)
+        prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
 
         # Extra input
         extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
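
Passing the same `t5_sequence_length` to the positive, negative, and local prompt encodings keeps every conditioning tensor shape-compatible, which matters because classifier-free guidance mixes the two branches elementwise. For reference, the standard CFG combination looks like this (general background, not code from this diff):

```python
import torch

def classifier_free_guidance(noise_pred_posi: torch.Tensor,
                             noise_pred_nega: torch.Tensor,
                             cfg_scale: float) -> torch.Tensor:
    # Standard CFG mix; valid only if both branches share a shape, which is
    # why every encode_prompt call above receives the same t5_sequence_length.
    return noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
```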


@@ -56,7 +56,8 @@ class FluxPrompter(BasePrompter):
         self,
         prompt,
         positive=True,
-        device="cuda"
+        device="cuda",
+        t5_sequence_length=256,
     ):
         prompt = self.process_prompt(prompt, positive=positive)
@@ -64,7 +65,7 @@ class FluxPrompter(BasePrompter):
         pooled_prompt_emb = self.encode_prompt_using_clip(prompt, self.text_encoder_1, self.tokenizer_1, 77, device)
 
         # T5
-        prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder_2, self.tokenizer_2, 256, device)
+        prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder_2, self.tokenizer_2, t5_sequence_length, device)
 
         # text_ids
         text_ids = torch.zeros(prompt_emb.shape[0], prompt_emb.shape[1], 3).to(device=device, dtype=prompt_emb.dtype)
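
Note that the CLIP branch keeps its fixed 77-token window; only the T5 branch becomes configurable. The body of `encode_prompt_using_t5` is not part of this diff; a plausible sketch of how such a helper consumes the length argument (Hugging Face-style tokenizer and encoder calls are an assumption here, not taken from this repository):

```python
import torch

def encode_prompt_using_t5(prompt, text_encoder, tokenizer, max_length, device):
    # Hypothetical sketch: pad/truncate to a fixed sequence length, then run
    # the T5 encoder. The repository's real helper may differ in details.
    input_ids = tokenizer(
        prompt,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length,  # now driven by t5_sequence_length, not a hard-coded 256
    ).input_ids.to(device)
    prompt_emb = text_encoder(input_ids)  # (batch, max_length, hidden_dim)
    return prompt_emb
```

Since FLUX's transformer attends jointly over text and image tokens, larger values of `t5_sequence_length` raise attention cost; the default stays at 256, so existing callers see no behavior change.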