This commit is contained in:
Artiprocher
2024-09-04 12:48:32 +08:00
parent 0b066d3cb4
commit d70cd04b15
7 changed files with 36 additions and 43 deletions

View File

@@ -37,12 +37,12 @@ def tokenize_long_prompt(tokenizer, prompt, max_length=None):
class BasePrompter:
def __init__(self, refiners=[],extenders = []):
def __init__(self, refiners=[], extenders=[]):
self.refiners = refiners
self.extenders = extenders
def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]): # manager
def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]):
for refiner_class in refiner_classes:
refiner = refiner_class.from_model_manager(model_manager)
self.refiners.append(refiner)
@@ -63,7 +63,7 @@ class BasePrompter:
return prompt
@torch.no_grad()
def extend_prompt(self,prompt:str,positive = True):
def extend_prompt(self, prompt:str, positive=True):
extended_prompt = dict(prompt=prompt)
for extender in self.extenders:
extended_prompt = extender(extended_prompt)