mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 17:38:10 +00:00
Merge pull request #260 from mi804/sd3.5
Replace the hard-coded T5 sequence length (256) with a configurable t5_sequence_length parameter that defaults to 77
This commit is contained in:
@@ -59,9 +59,9 @@ class SD3ImagePipeline(BasePipeline):
|
|||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
def encode_prompt(self, prompt, positive=True):
|
def encode_prompt(self, prompt, positive=True, t5_sequence_length=77):
|
||||||
prompt_emb, pooled_prompt_emb = self.prompter.encode_prompt(
|
prompt_emb, pooled_prompt_emb = self.prompter.encode_prompt(
|
||||||
prompt, device=self.device, positive=positive
|
prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length
|
||||||
)
|
)
|
||||||
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb}
|
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb}
|
||||||
|
|
||||||
@@ -84,6 +84,7 @@ class SD3ImagePipeline(BasePipeline):
|
|||||||
height=1024,
|
height=1024,
|
||||||
width=1024,
|
width=1024,
|
||||||
num_inference_steps=20,
|
num_inference_steps=20,
|
||||||
|
t5_sequence_length=77,
|
||||||
tiled=False,
|
tiled=False,
|
||||||
tile_size=128,
|
tile_size=128,
|
||||||
tile_stride=64,
|
tile_stride=64,
|
||||||
@@ -109,9 +110,9 @@ class SD3ImagePipeline(BasePipeline):
|
|||||||
|
|
||||||
# Encode prompts
|
# Encode prompts
|
||||||
self.load_models_to_device(['text_encoder_1', 'text_encoder_2', 'text_encoder_3'])
|
self.load_models_to_device(['text_encoder_1', 'text_encoder_2', 'text_encoder_3'])
|
||||||
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
|
prompt_emb_posi = self.encode_prompt(prompt, positive=True, t5_sequence_length=t5_sequence_length)
|
||||||
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
|
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length)
|
||||||
prompt_emb_locals = [self.encode_prompt(prompt_local) for prompt_local in local_prompts]
|
prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
|
||||||
|
|
||||||
# Denoise
|
# Denoise
|
||||||
self.load_models_to_device(['dit'])
|
self.load_models_to_device(['dit'])
|
||||||
|
|||||||
@@ -67,7 +67,8 @@ class SD3Prompter(BasePrompter):
|
|||||||
self,
|
self,
|
||||||
prompt,
|
prompt,
|
||||||
positive=True,
|
positive=True,
|
||||||
device="cuda"
|
device="cuda",
|
||||||
|
t5_sequence_length=77,
|
||||||
):
|
):
|
||||||
prompt = self.process_prompt(prompt, positive=positive)
|
prompt = self.process_prompt(prompt, positive=positive)
|
||||||
|
|
||||||
@@ -77,9 +78,9 @@ class SD3Prompter(BasePrompter):
|
|||||||
|
|
||||||
# T5
|
# T5
|
||||||
if self.text_encoder_3 is None:
|
if self.text_encoder_3 is None:
|
||||||
prompt_emb_3 = torch.zeros((prompt_emb_1.shape[0], 256, 4096), dtype=prompt_emb_1.dtype, device=device)
|
prompt_emb_3 = torch.zeros((prompt_emb_1.shape[0], t5_sequence_length, 4096), dtype=prompt_emb_1.dtype, device=device)
|
||||||
else:
|
else:
|
||||||
prompt_emb_3 = self.encode_prompt_using_t5(prompt, self.text_encoder_3, self.tokenizer_3, 256, device)
|
prompt_emb_3 = self.encode_prompt_using_t5(prompt, self.text_encoder_3, self.tokenizer_3, t5_sequence_length, device)
|
||||||
prompt_emb_3 = prompt_emb_3.to(prompt_emb_1.dtype) # float32 -> float16
|
prompt_emb_3 = prompt_emb_3.to(prompt_emb_1.dtype) # float32 -> float16
|
||||||
|
|
||||||
# Merge
|
# Merge
|
||||||
|
|||||||
Reference in New Issue
Block a user