mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
prompt processing
This commit is contained in:
16
README.md
16
README.md
@@ -71,3 +71,19 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/8f556355-4079-44
|
||||
We provide an example for video stylization. In this pipeline, the rendered video is completely different from the original video, thus we need a powerful deflickering algorithm. We use FastBlend to implement the deflickering module. Please see `examples/sd_video_rerender.py` for more details.
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
|
||||
|
||||
### Example 7: Prompt Processing
|
||||
|
||||
If you are not native English user, we provide translation service for you. Our prompter can translate other language to English and refine it using "BeautifulPrompt" models. Please see `examples/sd_prompt_refining.py` for more details.
|
||||
|
||||
Prompt: "一个漂亮的女孩". The [translation model](https://huggingface.co/Helsinki-NLP/opus-mt-en-zh) will translate it to English.
|
||||
|
||||
|seed=0|seed=1|seed=2|seed=3|
|
||||
|-|-|-|-|
|
||||
|||||
|
||||
|
||||
Prompt: "一个漂亮的女孩". The [translation model](https://huggingface.co/Helsinki-NLP/opus-mt-en-zh) will translate it to English. Then the [refining model](https://huggingface.co/alibaba-pai/pai-bloom-1b1-text2prompt-sd) will refine the translated prompt for better visual quality.
|
||||
|
||||
|seed=0|seed=1|seed=2|seed=3|
|
||||
|-|-|-|-|
|
||||
|||||
|
||||
|
||||
@@ -55,6 +55,10 @@ class ModelManager:
|
||||
param_name = "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_ff_net_2.lora_up.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_translator(self, state_dict):
|
||||
param_name = "model.encoder.layers.5.self_attn_layer_norm.weight"
|
||||
return param_name in state_dict and len(state_dict) == 254
|
||||
|
||||
def load_stable_diffusion(self, state_dict, components=None, file_path=""):
|
||||
component_dict = {
|
||||
"text_encoder": SDTextEncoder,
|
||||
@@ -147,6 +151,15 @@ class ModelManager:
|
||||
SDLoRA().add_lora_to_text_encoder(self.model["text_encoder"], state_dict, alpha=alpha, device=self.device)
|
||||
SDLoRA().add_lora_to_unet(self.model["unet"], state_dict, alpha=alpha, device=self.device)
|
||||
|
||||
def load_translator(self, state_dict, file_path=""):
|
||||
# This model is lightweight, we do not place it on GPU.
|
||||
component = "translator"
|
||||
from transformers import AutoModelForSeq2SeqLM
|
||||
model_folder = os.path.dirname(file_path)
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(model_folder).eval()
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def search_for_embeddings(self, state_dict):
|
||||
embeddings = []
|
||||
for k in state_dict:
|
||||
@@ -190,6 +203,8 @@ class ModelManager:
|
||||
self.load_beautiful_prompt(state_dict, file_path=file_path)
|
||||
elif self.is_RIFE(state_dict):
|
||||
self.load_RIFE(state_dict, file_path=file_path)
|
||||
elif self.is_translator(state_dict):
|
||||
self.load_translator(state_dict, file_path=file_path)
|
||||
|
||||
def load_models(self, file_path_list, lora_alphas=[]):
|
||||
for file_path in file_path_list:
|
||||
|
||||
@@ -31,8 +31,6 @@ class SDImagePipeline(torch.nn.Module):
|
||||
self.unet = model_manager.unet
|
||||
self.vae_decoder = model_manager.vae_decoder
|
||||
self.vae_encoder = model_manager.vae_encoder
|
||||
# load textual inversion
|
||||
self.prompter.load_textual_inversion(model_manager.textual_inversion_dict)
|
||||
|
||||
|
||||
def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
|
||||
@@ -47,9 +45,8 @@ class SDImagePipeline(torch.nn.Module):
|
||||
self.controlnet = MultiControlNetManager(controlnet_units)
|
||||
|
||||
|
||||
def fetch_beautiful_prompt(self, model_manager: ModelManager):
|
||||
if "beautiful_prompt" in model_manager.model:
|
||||
self.prompter.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
|
||||
def fetch_prompter(self, model_manager: ModelManager):
|
||||
self.prompter.load_from_model_manager(model_manager)
|
||||
|
||||
|
||||
@staticmethod
|
||||
@@ -59,7 +56,7 @@ class SDImagePipeline(torch.nn.Module):
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_main_models(model_manager)
|
||||
pipe.fetch_beautiful_prompt(model_manager)
|
||||
pipe.fetch_prompter(model_manager)
|
||||
pipe.fetch_controlnet_models(model_manager, controlnet_config_units)
|
||||
return pipe
|
||||
|
||||
|
||||
@@ -82,8 +82,6 @@ class SDVideoPipeline(torch.nn.Module):
|
||||
self.unet = model_manager.unet
|
||||
self.vae_decoder = model_manager.vae_decoder
|
||||
self.vae_encoder = model_manager.vae_encoder
|
||||
# load textual inversion
|
||||
self.prompter.load_textual_inversion(model_manager.textual_inversion_dict)
|
||||
|
||||
|
||||
def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
|
||||
@@ -103,9 +101,8 @@ class SDVideoPipeline(torch.nn.Module):
|
||||
self.motion_modules = model_manager.motion_modules
|
||||
|
||||
|
||||
def fetch_beautiful_prompt(self, model_manager: ModelManager):
|
||||
if "beautiful_prompt" in model_manager.model:
|
||||
self.prompter.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
|
||||
def fetch_prompter(self, model_manager: ModelManager):
|
||||
self.prompter.load_from_model_manager(model_manager)
|
||||
|
||||
|
||||
@staticmethod
|
||||
@@ -117,7 +114,7 @@ class SDVideoPipeline(torch.nn.Module):
|
||||
)
|
||||
pipe.fetch_main_models(model_manager)
|
||||
pipe.fetch_motion_modules(model_manager)
|
||||
pipe.fetch_beautiful_prompt(model_manager)
|
||||
pipe.fetch_prompter(model_manager)
|
||||
pipe.fetch_controlnet_models(model_manager, controlnet_config_units)
|
||||
return pipe
|
||||
|
||||
|
||||
@@ -39,9 +39,8 @@ class SDXLImagePipeline(torch.nn.Module):
|
||||
pass
|
||||
|
||||
|
||||
def fetch_beautiful_prompt(self, model_manager: ModelManager):
|
||||
if "beautiful_prompt" in model_manager.model:
|
||||
self.prompter.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
|
||||
def fetch_prompter(self, model_manager: ModelManager):
|
||||
self.prompter.load_from_model_manager(model_manager)
|
||||
|
||||
|
||||
@staticmethod
|
||||
@@ -51,7 +50,7 @@ class SDXLImagePipeline(torch.nn.Module):
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_main_models(model_manager)
|
||||
pipe.fetch_beautiful_prompt(model_manager)
|
||||
pipe.fetch_prompter(model_manager)
|
||||
pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units)
|
||||
return pipe
|
||||
|
||||
@@ -106,7 +105,8 @@ class SDXLImagePipeline(torch.nn.Module):
|
||||
self.text_encoder_2,
|
||||
prompt,
|
||||
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
|
||||
device=self.device
|
||||
device=self.device,
|
||||
positive=True,
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt(
|
||||
@@ -114,7 +114,8 @@ class SDXLImagePipeline(torch.nn.Module):
|
||||
self.text_encoder_2,
|
||||
negative_prompt,
|
||||
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
|
||||
device=self.device
|
||||
device=self.device,
|
||||
positive=False,
|
||||
)
|
||||
|
||||
# Prepare scheduler
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from transformers import CLIPTokenizer, AutoTokenizer
|
||||
from ..models import SDTextEncoder, SDXLTextEncoder, SDXLTextEncoder2
|
||||
from ..models import SDTextEncoder, SDXLTextEncoder, SDXLTextEncoder2, ModelManager
|
||||
import torch, os
|
||||
|
||||
|
||||
@@ -59,33 +59,27 @@ class BeautifulPrompt:
|
||||
skip_special_tokens=True
|
||||
)[0].strip()
|
||||
return prompt
|
||||
|
||||
|
||||
class Translator:
|
||||
def __init__(self, tokenizer_path="configs/translator/tokenizer", model=None):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
self.model = model
|
||||
|
||||
class SDPrompter:
|
||||
def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"):
|
||||
# We use the tokenizer implemented by transformers
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||
def __call__(self, prompt):
|
||||
input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.model.device)
|
||||
output_ids = self.model.generate(input_ids)
|
||||
prompt = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
|
||||
return prompt
|
||||
|
||||
|
||||
class Prompter:
|
||||
def __init__(self):
|
||||
self.tokenizer: CLIPTokenizer = None
|
||||
self.keyword_dict = {}
|
||||
self.translator: Translator = None
|
||||
self.beautiful_prompt: BeautifulPrompt = None
|
||||
|
||||
|
||||
def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True):
|
||||
# Textual Inversion
|
||||
for keyword in self.keyword_dict:
|
||||
if keyword in prompt:
|
||||
prompt = prompt.replace(keyword, self.keyword_dict[keyword])
|
||||
|
||||
# Beautiful Prompt
|
||||
if positive and self.beautiful_prompt is not None:
|
||||
prompt = self.beautiful_prompt(prompt)
|
||||
print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
|
||||
|
||||
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
|
||||
prompt_emb = text_encoder(input_ids, clip_skip=clip_skip)
|
||||
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
|
||||
|
||||
return prompt_emb
|
||||
|
||||
def load_textual_inversion(self, textual_inversion_dict):
|
||||
self.keyword_dict = {}
|
||||
additional_tokens = []
|
||||
@@ -105,18 +99,53 @@ or use a number to specify the weight. You should add appropriate words to make
|
||||
but make sure there is a correlation between the input and output.\n\
|
||||
### Input: {raw_prompt}\n### Output:"""
|
||||
|
||||
def load_translator(self, model, model_path):
|
||||
model_folder = os.path.dirname(model_path)
|
||||
self.translator = Translator(tokenizer_path=model_folder, model=model)
|
||||
|
||||
class SDXLPrompter:
|
||||
def load_from_model_manager(self, model_manager: ModelManager):
|
||||
self.load_textual_inversion(model_manager.textual_inversion_dict)
|
||||
if "translator" in model_manager.model:
|
||||
self.load_translator(model_manager.model["translator"], model_manager.model_path["translator"])
|
||||
if "beautiful_prompt" in model_manager.model:
|
||||
self.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
|
||||
|
||||
def process_prompt(self, prompt, positive=True):
|
||||
for keyword in self.keyword_dict:
|
||||
if keyword in prompt:
|
||||
prompt = prompt.replace(keyword, self.keyword_dict[keyword])
|
||||
if positive and self.translator is not None:
|
||||
prompt = self.translator(prompt)
|
||||
print(f"Your prompt is translated: \"{prompt}\"")
|
||||
if positive and self.beautiful_prompt is not None:
|
||||
prompt = self.beautiful_prompt(prompt)
|
||||
print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
|
||||
return prompt
|
||||
|
||||
|
||||
class SDPrompter(Prompter):
|
||||
def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"):
|
||||
super().__init__()
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||
|
||||
def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True):
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
|
||||
prompt_emb = text_encoder(input_ids, clip_skip=clip_skip)
|
||||
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
|
||||
|
||||
return prompt_emb
|
||||
|
||||
|
||||
class SDXLPrompter(Prompter):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer_path="configs/stable_diffusion/tokenizer",
|
||||
tokenizer_2_path="configs/stable_diffusion_xl/tokenizer_2"
|
||||
):
|
||||
# We use the tokenizer implemented by transformers
|
||||
super().__init__()
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||
self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
|
||||
self.keyword_dict = {}
|
||||
self.beautiful_prompt: BeautifulPrompt = None
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
@@ -128,15 +157,7 @@ class SDXLPrompter:
|
||||
positive=True,
|
||||
device="cuda"
|
||||
):
|
||||
# Textual Inversion
|
||||
for keyword in self.keyword_dict:
|
||||
if keyword in prompt:
|
||||
prompt = prompt.replace(keyword, self.keyword_dict[keyword])
|
||||
|
||||
# Beautiful Prompt
|
||||
if positive and self.beautiful_prompt is not None:
|
||||
prompt = self.beautiful_prompt(prompt)
|
||||
print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
|
||||
# 1
|
||||
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
|
||||
@@ -153,22 +174,3 @@ class SDXLPrompter:
|
||||
add_text_embeds = add_text_embeds[0:1]
|
||||
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
|
||||
return add_text_embeds, prompt_emb
|
||||
|
||||
def load_textual_inversion(self, textual_inversion_dict):
|
||||
self.keyword_dict = {}
|
||||
additional_tokens = []
|
||||
for keyword in textual_inversion_dict:
|
||||
tokens, _ = textual_inversion_dict[keyword]
|
||||
additional_tokens += tokens
|
||||
self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
|
||||
self.tokenizer.add_tokens(additional_tokens)
|
||||
|
||||
def load_beautiful_prompt(self, model, model_path):
|
||||
model_folder = os.path.dirname(model_path)
|
||||
self.beautiful_prompt = BeautifulPrompt(tokenizer_path=model_folder, model=model)
|
||||
if model_folder.endswith("v2"):
|
||||
self.beautiful_prompt.template = """Converts a simple image description into a prompt. \
|
||||
Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
|
||||
or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
|
||||
but make sure there is a correlation between the input and output.\n\
|
||||
### Input: {raw_prompt}\n### Output:"""
|
||||
|
||||
31
examples/sd_prompt_refining.py
Normal file
31
examples/sd_prompt_refining.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from diffsynth import ModelManager, SDXLImagePipeline
|
||||
import torch
|
||||
|
||||
|
||||
# Download models
|
||||
# `models/stable_diffusion_xl/sd_xl_base_1.0.safetensors`: [link](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors)
|
||||
# `models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd/`: [link](https://huggingface.co/alibaba-pai/pai-bloom-1b1-text2prompt-sd)
|
||||
# `models/translator/opus-mt-zh-en/`: [link](https://huggingface.co/Helsinki-NLP/opus-mt-en-zh)
|
||||
|
||||
|
||||
# Load models
|
||||
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
|
||||
model_manager.load_textual_inversions("models/textual_inversion")
|
||||
model_manager.load_models([
|
||||
"models/stable_diffusion_xl/sd_xl_base_1.0.safetensors",
|
||||
"models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd/model.safetensors",
|
||||
"models/translator/opus-mt-zh-en/pytorch_model.bin"
|
||||
])
|
||||
pipe = SDXLImagePipeline.from_model_manager(model_manager)
|
||||
|
||||
prompt = "一个漂亮的女孩"
|
||||
negative_prompt = ""
|
||||
|
||||
for seed in range(4):
|
||||
torch.manual_seed(seed)
|
||||
image = pipe(
|
||||
prompt=prompt, negative_prompt=negative_prompt,
|
||||
height=1024, width=1024,
|
||||
num_inference_steps=30
|
||||
)
|
||||
image.save(f"{seed}.jpg")
|
||||
@@ -35,7 +35,7 @@ pipe = SDImagePipeline.from_model_manager(
|
||||
)
|
||||
|
||||
prompt = "masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait,"
|
||||
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,",
|
||||
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
|
||||
|
||||
torch.manual_seed(0)
|
||||
image = pipe(
|
||||
|
||||
@@ -12,7 +12,7 @@ model_manager.load_models(["models/stable_diffusion_xl/bluePencilXL_v200.safeten
|
||||
pipe = SDXLImagePipeline.from_model_manager(model_manager)
|
||||
|
||||
prompt = "masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait,"
|
||||
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,",
|
||||
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
|
||||
|
||||
torch.manual_seed(0)
|
||||
image = pipe(
|
||||
|
||||
0
models/translator/Put translator models here.txt
Normal file
0
models/translator/Put translator models here.txt
Normal file
Reference in New Issue
Block a user