mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 17:38:10 +00:00
rebuild base modules
This commit is contained in:
77
diffsynth/prompters/prompt_refiners.py
Normal file
77
diffsynth/prompters/prompt_refiners.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from transformers import AutoTokenizer
|
||||
from ..models.model_manager import ModelManager
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
class BeautifulPrompt(torch.nn.Module):
|
||||
def __init__(self, tokenizer_path=None, model=None, template=""):
|
||||
super().__init__()
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
self.model = model
|
||||
self.template = template
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_nameger: ModelManager):
|
||||
model, model_path = model_nameger.fetch_model("beautiful_prompt", require_model_path=True)
|
||||
template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:'
|
||||
if model_path.endswith("v2"):
|
||||
template = """Converts a simple image description into a prompt. \
|
||||
Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
|
||||
or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
|
||||
but make sure there is a correlation between the input and output.\n\
|
||||
### Input: {raw_prompt}\n### Output:"""
|
||||
beautiful_prompt = BeautifulPrompt(
|
||||
tokenizer_path=model_path,
|
||||
model=model,
|
||||
template=template
|
||||
)
|
||||
return beautiful_prompt
|
||||
|
||||
|
||||
def __call__(self, raw_prompt, positive=True, **kwargs):
|
||||
if positive:
|
||||
model_input = self.template.format(raw_prompt=raw_prompt)
|
||||
input_ids = self.tokenizer.encode(model_input, return_tensors='pt').to(self.model.device)
|
||||
outputs = self.model.generate(
|
||||
input_ids,
|
||||
max_new_tokens=384,
|
||||
do_sample=True,
|
||||
temperature=0.9,
|
||||
top_k=50,
|
||||
top_p=0.95,
|
||||
repetition_penalty=1.1,
|
||||
num_return_sequences=1
|
||||
)
|
||||
prompt = raw_prompt + ", " + self.tokenizer.batch_decode(
|
||||
outputs[:, input_ids.size(1):],
|
||||
skip_special_tokens=True
|
||||
)[0].strip()
|
||||
print(f"Your prompt is refined by BeautifulPrompt: {prompt}")
|
||||
return prompt
|
||||
else:
|
||||
return raw_prompt
|
||||
|
||||
|
||||
|
||||
class Translator(torch.nn.Module):
|
||||
def __init__(self, tokenizer_path=None, model=None):
|
||||
super().__init__()
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
self.model = model
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_nameger: ModelManager):
|
||||
model, model_path = model_nameger.fetch_model("translator", require_model_path=True)
|
||||
translator = Translator(tokenizer_path=model_path, model=model)
|
||||
return translator
|
||||
|
||||
|
||||
def __call__(self, prompt, **kwargs):
|
||||
input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.model.device)
|
||||
output_ids = self.model.generate(input_ids)
|
||||
prompt = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
|
||||
print(f"Your prompt is translated: {prompt}")
|
||||
return prompt
|
||||
Reference in New Issue
Block a user