mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
add omost.py + omost_flux_example
This commit is contained in:
@@ -25,7 +25,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
return self.dit
|
||||
|
||||
|
||||
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
|
||||
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[], prompt_extender_classes=[]):
|
||||
self.text_encoder_1 = model_manager.fetch_model("flux_text_encoder_1")
|
||||
self.text_encoder_2 = model_manager.fetch_model("flux_text_encoder_2")
|
||||
self.dit = model_manager.fetch_model("flux_dit")
|
||||
@@ -33,15 +33,16 @@ class FluxImagePipeline(BasePipeline):
|
||||
self.vae_encoder = model_manager.fetch_model("flux_vae_encoder")
|
||||
self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
|
||||
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
|
||||
self.prompter.load_prompt_extenders(model_manager,prompt_extender_classes)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]):
|
||||
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[],prompt_extender_classes=[]):
|
||||
pipe = FluxImagePipeline(
|
||||
device=model_manager.device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_models(model_manager, prompt_refiner_classes)
|
||||
pipe.fetch_models(model_manager, prompt_refiner_classes,prompt_extender_classes)
|
||||
return pipe
|
||||
|
||||
|
||||
@@ -105,6 +106,14 @@ class FluxImagePipeline(BasePipeline):
|
||||
else:
|
||||
latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
# Extend prompt
|
||||
if len(self.prompter.extenders) > 0:
|
||||
extended_prompt_dict = self.prompter.extend_prompt(prompt)
|
||||
prompt = extended_prompt_dict.get("prompt", prompt)
|
||||
local_prompts += extended_prompt_dict.get("prompts", [])
|
||||
masks += extended_prompt_dict.get("masks",[])
|
||||
mask_scales += [5.0 for _ in range(len(extended_prompt_dict.get("masks",[])))]
|
||||
|
||||
# Encode prompts
|
||||
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
|
||||
if cfg_scale != 1.0:
|
||||
|
||||
Reference in New Issue
Block a user