This commit is contained in:
Artiprocher
2024-09-04 12:48:32 +08:00
parent 0b066d3cb4
commit d70cd04b15
7 changed files with 36 additions and 43 deletions

View File

@@ -50,4 +50,13 @@ class BasePipeline(torch.nn.Module):
noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals]
noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales)
return noise_pred
def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
extended_prompt_dict = self.prompter.extend_prompt(prompt)
prompt = extended_prompt_dict.get("prompt", prompt)
local_prompts += extended_prompt_dict.get("prompts", [])
masks += extended_prompt_dict.get("masks", [])
mask_scales += [5.0] * len(extended_prompt_dict.get("masks", []))
return prompt, local_prompts, masks, mask_scales