add omost.py + omost_flux_example

This commit is contained in:
tc2000731
2024-09-03 19:40:40 +08:00
parent fe485b3fa1
commit 0b066d3cb4
7 changed files with 398 additions and 13 deletions

View File

@@ -25,7 +25,7 @@ class FluxImagePipeline(BasePipeline):
return self.dit
-def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
+def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[], prompt_extender_classes=[]):
self.text_encoder_1 = model_manager.fetch_model("flux_text_encoder_1")
self.text_encoder_2 = model_manager.fetch_model("flux_text_encoder_2")
self.dit = model_manager.fetch_model("flux_dit")
@@ -33,15 +33,16 @@ class FluxImagePipeline(BasePipeline):
self.vae_encoder = model_manager.fetch_model("flux_vae_encoder")
self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
+self.prompter.load_prompt_extenders(model_manager,prompt_extender_classes)
@staticmethod
-def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]):
+def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[],prompt_extender_classes=[]):
pipe = FluxImagePipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype,
)
-pipe.fetch_models(model_manager, prompt_refiner_classes)
+pipe.fetch_models(model_manager, prompt_refiner_classes,prompt_extender_classes)
return pipe
@@ -105,6 +106,14 @@ class FluxImagePipeline(BasePipeline):
else:
latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
+# Extend prompt
+if len(self.prompter.extenders) > 0:
+extended_prompt_dict = self.prompter.extend_prompt(prompt)
+prompt = extended_prompt_dict.get("prompt", prompt)
+local_prompts += extended_prompt_dict.get("prompts", [])
+masks += extended_prompt_dict.get("masks",[])
+mask_scales += [5.0 for _ in range(len(extended_prompt_dict.get("masks",[])))]
# Encode prompts
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
if cfg_scale != 1.0: