add omost.py + omost_flux_example

This commit is contained in:
tc2000731
2024-09-03 19:40:40 +08:00
parent fe485b3fa1
commit 0b066d3cb4
7 changed files with 398 additions and 13 deletions

View File

@@ -37,14 +37,20 @@ def tokenize_long_prompt(tokenizer, prompt, max_length=None):
class BasePrompter:
    """Base class for prompt processors holding chains of refiners and extenders."""

    def __init__(self, refiners=None, extenders=None):
        """Initialize the prompter.

        Args:
            refiners: optional list of refiner callables, applied in order.
            extenders: optional list of extender callables, applied in order.
        """
        # Use None sentinels instead of `[]` defaults: a mutable default list
        # would be shared across every instance created without arguments.
        self.refiners = [] if refiners is None else refiners
        self.extenders = [] if extenders is None else extenders
def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=None):
    """Instantiate each refiner class from the model manager and register it.

    Args:
        model_manager: ModelManager used to construct each refiner via its
            `from_model_manager` factory.
        refiner_classes: optional iterable of classes exposing a
            `from_model_manager(model_manager)` constructor.
    """
    # `or []` keeps the no-argument call a no-op without a shared mutable default.
    for refiner_class in refiner_classes or []:
        refiner = refiner_class.from_model_manager(model_manager)
        self.refiners.append(refiner)
def load_prompt_extenders(self, model_manager: ModelManager, extender_classes=None):
    """Instantiate each extender class from the model manager and register it.

    Mirrors `load_prompt_refiners`, but appends to `self.extenders`.

    Args:
        model_manager: ModelManager used to construct each extender via its
            `from_model_manager` factory.
        extender_classes: optional iterable of classes exposing a
            `from_model_manager(model_manager)` constructor.
    """
    # `or []` keeps the no-argument call a no-op without a shared mutable default.
    for extender_class in extender_classes or []:
        extender = extender_class.from_model_manager(model_manager)
        self.extenders.append(extender)
@torch.no_grad()
@@ -55,3 +61,10 @@ class BasePrompter:
for refiner in self.refiners:
prompt = refiner(prompt, positive=positive)
return prompt
@torch.no_grad()
def extend_prompt(self, prompt: str, positive=True):
    """Run the prompt through each registered extender in order.

    Args:
        prompt: raw prompt text.
        positive: kept for signature symmetry with the refiner path;
            NOTE(review): currently unused by the extender loop — confirm
            whether extenders should receive it.

    Returns:
        The dict produced by the last extender; the chain is seeded with
        ``{"prompt": prompt}``, so with no extenders that seed is returned.
    """
    extended_prompt = {"prompt": prompt}
    for extender in self.extenders:
        extended_prompt = extender(extended_prompt)
    return extended_prompt