prompt processing

This commit is contained in:
Artiprocher
2024-01-21 22:36:03 +08:00
parent 22328f48ca
commit e076e66827
11 changed files with 134 additions and 75 deletions

View File

@@ -71,3 +71,19 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/8f556355-4079-44
We provide an example for video stylization. In this pipeline, the rendered video is completely different from the original video, thus we need a powerful deflickering algorithm. We use FastBlend to implement the deflickering module. Please see `examples/sd_video_rerender.py` for more details.
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
### Example 7: Prompt Processing
If you are not a native English speaker, we provide a translation service for you. Our prompter can translate other languages to English and refine them using "BeautifulPrompt" models. Please see `examples/sd_prompt_refining.py` for more details.
Prompt: "一个漂亮的女孩". The [translation model](https://huggingface.co/Helsinki-NLP/opus-mt-zh-en) will translate it to English.
|seed=0|seed=1|seed=2|seed=3|
|-|-|-|-|
|![0_](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/ebb25ca8-7ce1-4d9e-8081-59a867c70c4d)|![1_](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/a7e79853-3c1a-471a-9c58-c209ec4b76dd)|![2_](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/a292b959-a121-481f-b79c-61cc3346f810)|![3_](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/1c19b54e-5a6f-4d48-960b-a7b2b149bb4c)|
Prompt: "一个漂亮的女孩". The [translation model](https://huggingface.co/Helsinki-NLP/opus-mt-zh-en) will translate it to English. Then the [refining model](https://huggingface.co/alibaba-pai/pai-bloom-1b1-text2prompt-sd) will refine the translated prompt for better visual quality.
|seed=0|seed=1|seed=2|seed=3|
|-|-|-|-|
|![0](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/778b1bd9-44e0-46ac-a99c-712b3fc9aaa4)|![1](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/c03479b8-2082-4c6e-8e1c-3582b98686f6)|![2](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/edb33d21-3288-4a55-96ca-a4bfe1b50b00)|![3](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/7848cfc1-cad5-4848-8373-41d24e98e584)|

View File

@@ -55,6 +55,10 @@ class ModelManager:
param_name = "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_ff_net_2.lora_up.weight"
return param_name in state_dict
def is_translator(self, state_dict):
    # A Marian opus-mt translation checkpoint is identified by one of its
    # encoder layer-norm weights together with its exact tensor count (254).
    key = "model.encoder.layers.5.self_attn_layer_norm.weight"
    return len(state_dict) == 254 and key in state_dict
def load_stable_diffusion(self, state_dict, components=None, file_path=""):
component_dict = {
"text_encoder": SDTextEncoder,
@@ -147,6 +151,15 @@ class ModelManager:
SDLoRA().add_lora_to_text_encoder(self.model["text_encoder"], state_dict, alpha=alpha, device=self.device)
SDLoRA().add_lora_to_unet(self.model["unet"], state_dict, alpha=alpha, device=self.device)
def load_translator(self, state_dict, file_path=""):
    # The translation model is lightweight, so it is kept on CPU instead
    # of being moved to the GPU like the diffusion components.
    from transformers import AutoModelForSeq2SeqLM
    model_folder = os.path.dirname(file_path)
    translator = AutoModelForSeq2SeqLM.from_pretrained(model_folder).eval()
    self.model["translator"] = translator
    self.model_path["translator"] = file_path
def search_for_embeddings(self, state_dict):
embeddings = []
for k in state_dict:
@@ -190,6 +203,8 @@ class ModelManager:
self.load_beautiful_prompt(state_dict, file_path=file_path)
elif self.is_RIFE(state_dict):
self.load_RIFE(state_dict, file_path=file_path)
elif self.is_translator(state_dict):
self.load_translator(state_dict, file_path=file_path)
def load_models(self, file_path_list, lora_alphas=[]):
for file_path in file_path_list:

View File

@@ -31,8 +31,6 @@ class SDImagePipeline(torch.nn.Module):
self.unet = model_manager.unet
self.vae_decoder = model_manager.vae_decoder
self.vae_encoder = model_manager.vae_encoder
# load textual inversion
self.prompter.load_textual_inversion(model_manager.textual_inversion_dict)
def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
@@ -47,9 +45,8 @@ class SDImagePipeline(torch.nn.Module):
self.controlnet = MultiControlNetManager(controlnet_units)
def fetch_beautiful_prompt(self, model_manager: ModelManager):
if "beautiful_prompt" in model_manager.model:
self.prompter.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
def fetch_prompter(self, model_manager: ModelManager):
self.prompter.load_from_model_manager(model_manager)
@staticmethod
@@ -59,7 +56,7 @@ class SDImagePipeline(torch.nn.Module):
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_main_models(model_manager)
pipe.fetch_beautiful_prompt(model_manager)
pipe.fetch_prompter(model_manager)
pipe.fetch_controlnet_models(model_manager, controlnet_config_units)
return pipe

View File

@@ -82,8 +82,6 @@ class SDVideoPipeline(torch.nn.Module):
self.unet = model_manager.unet
self.vae_decoder = model_manager.vae_decoder
self.vae_encoder = model_manager.vae_encoder
# load textual inversion
self.prompter.load_textual_inversion(model_manager.textual_inversion_dict)
def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
@@ -103,9 +101,8 @@ class SDVideoPipeline(torch.nn.Module):
self.motion_modules = model_manager.motion_modules
def fetch_beautiful_prompt(self, model_manager: ModelManager):
if "beautiful_prompt" in model_manager.model:
self.prompter.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
def fetch_prompter(self, model_manager: ModelManager):
self.prompter.load_from_model_manager(model_manager)
@staticmethod
@@ -117,7 +114,7 @@ class SDVideoPipeline(torch.nn.Module):
)
pipe.fetch_main_models(model_manager)
pipe.fetch_motion_modules(model_manager)
pipe.fetch_beautiful_prompt(model_manager)
pipe.fetch_prompter(model_manager)
pipe.fetch_controlnet_models(model_manager, controlnet_config_units)
return pipe

View File

@@ -39,9 +39,8 @@ class SDXLImagePipeline(torch.nn.Module):
pass
def fetch_beautiful_prompt(self, model_manager: ModelManager):
if "beautiful_prompt" in model_manager.model:
self.prompter.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
def fetch_prompter(self, model_manager: ModelManager):
self.prompter.load_from_model_manager(model_manager)
@staticmethod
@@ -51,7 +50,7 @@ class SDXLImagePipeline(torch.nn.Module):
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_main_models(model_manager)
pipe.fetch_beautiful_prompt(model_manager)
pipe.fetch_prompter(model_manager)
pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units)
return pipe
@@ -106,7 +105,8 @@ class SDXLImagePipeline(torch.nn.Module):
self.text_encoder_2,
prompt,
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
device=self.device
device=self.device,
positive=True,
)
if cfg_scale != 1.0:
add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt(
@@ -114,7 +114,8 @@ class SDXLImagePipeline(torch.nn.Module):
self.text_encoder_2,
negative_prompt,
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
device=self.device
device=self.device,
positive=False,
)
# Prepare scheduler

View File

@@ -1,5 +1,5 @@
from transformers import CLIPTokenizer, AutoTokenizer
from ..models import SDTextEncoder, SDXLTextEncoder, SDXLTextEncoder2
from ..models import SDTextEncoder, SDXLTextEncoder, SDXLTextEncoder2, ModelManager
import torch, os
@@ -59,33 +59,27 @@ class BeautifulPrompt:
skip_special_tokens=True
)[0].strip()
return prompt
class Translator:
def __init__(self, tokenizer_path="configs/translator/tokenizer", model=None):
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
self.model = model
class SDPrompter:
def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"):
# We use the tokenizer implemented by transformers
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
def __call__(self, prompt):
    # Encode the source prompt, run seq2seq generation on the model's
    # device, and decode the single generated candidate.
    token_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.model.device)
    generated_ids = self.model.generate(token_ids)
    return self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
class Prompter:
def __init__(self):
    # All components start empty; they are populated later via the
    # load_* methods / load_from_model_manager.
    self.keyword_dict = {}
    self.tokenizer: CLIPTokenizer = None
    self.translator: Translator = None
    self.beautiful_prompt: BeautifulPrompt = None
def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True):
# Textual Inversion
for keyword in self.keyword_dict:
if keyword in prompt:
prompt = prompt.replace(keyword, self.keyword_dict[keyword])
# Beautiful Prompt
if positive and self.beautiful_prompt is not None:
prompt = self.beautiful_prompt(prompt)
print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
prompt_emb = text_encoder(input_ids, clip_skip=clip_skip)
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
return prompt_emb
def load_textual_inversion(self, textual_inversion_dict):
self.keyword_dict = {}
additional_tokens = []
@@ -105,18 +99,53 @@ or use a number to specify the weight. You should add appropriate words to make
but make sure there is a correlation between the input and output.\n\
### Input: {raw_prompt}\n### Output:"""
def load_translator(self, model, model_path):
    # The Marian tokenizer files live in the same folder as the checkpoint.
    folder = os.path.dirname(model_path)
    self.translator = Translator(tokenizer_path=folder, model=model)
class SDXLPrompter:
def load_from_model_manager(self, model_manager: ModelManager):
    # Pull every prompt-processing component that the manager has loaded:
    # textual inversions always, translator / BeautifulPrompt if present.
    self.load_textual_inversion(model_manager.textual_inversion_dict)
    for component, loader in (
        ("translator", self.load_translator),
        ("beautiful_prompt", self.load_beautiful_prompt),
    ):
        if component in model_manager.model:
            loader(model_manager.model[component], model_manager.model_path[component])
def process_prompt(self, prompt, positive=True):
    # Expand textual-inversion trigger keywords into their token sequences.
    for keyword, tokens in self.keyword_dict.items():
        if keyword in prompt:
            prompt = prompt.replace(keyword, tokens)
    # Translation and BeautifulPrompt refinement apply only to the
    # positive prompt; the negative prompt is left untouched.
    if positive:
        if self.translator is not None:
            prompt = self.translator(prompt)
            print(f"Your prompt is translated: \"{prompt}\"")
        if self.beautiful_prompt is not None:
            prompt = self.beautiful_prompt(prompt)
            print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
    return prompt
class SDPrompter(Prompter):
    """Prompter for Stable Diffusion 1.x: a single CLIP tokenizer/encoder."""

    def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)

    def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True):
        # Keyword expansion / translation / refinement first, then encode.
        prompt = self.process_prompt(prompt, positive=positive)
        input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
        embeddings = text_encoder(input_ids, clip_skip=clip_skip)
        # Fold the chunked long-prompt embeddings into one batch row.
        return embeddings.reshape((1, embeddings.shape[0] * embeddings.shape[1], -1))
class SDXLPrompter(Prompter):
def __init__(
    self,
    tokenizer_path="configs/stable_diffusion/tokenizer",
    tokenizer_2_path="configs/stable_diffusion_xl/tokenizer_2"
):
    """SDXL prompter: loads both CLIP tokenizers from their config folders."""
    # We use the tokenizer implemented by transformers.
    super().__init__()
    self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
    self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
    # NOTE: keyword_dict and beautiful_prompt are already initialized by
    # Prompter.__init__; the redundant re-initialization was removed.
def encode_prompt(
self,
@@ -128,15 +157,7 @@ class SDXLPrompter:
positive=True,
device="cuda"
):
# Textual Inversion
for keyword in self.keyword_dict:
if keyword in prompt:
prompt = prompt.replace(keyword, self.keyword_dict[keyword])
# Beautiful Prompt
if positive and self.beautiful_prompt is not None:
prompt = self.beautiful_prompt(prompt)
print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
prompt = self.process_prompt(prompt, positive=positive)
# 1
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
@@ -153,22 +174,3 @@ class SDXLPrompter:
add_text_embeds = add_text_embeds[0:1]
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
return add_text_embeds, prompt_emb
def load_textual_inversion(self, textual_inversion_dict):
    # Map each trigger keyword to its placeholder-token string (padded
    # with spaces so replacement never glues tokens to adjacent words),
    # and register every placeholder token with the tokenizer.
    self.keyword_dict = {}
    all_tokens = []
    for keyword, (tokens, _) in textual_inversion_dict.items():
        all_tokens.extend(tokens)
        self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
    self.tokenizer.add_tokens(all_tokens)
def load_beautiful_prompt(self, model, model_path):
    # Attach a BeautifulPrompt refiner; its tokenizer files are expected
    # to sit in the same folder as the checkpoint file.
    model_folder = os.path.dirname(model_path)
    self.beautiful_prompt = BeautifulPrompt(tokenizer_path=model_folder, model=model)
    # The v2 checkpoint was trained with a different instruction template,
    # detected here by the folder-name suffix — TODO confirm this naming
    # convention holds for all v2 releases.
    if model_folder.endswith("v2"):
        self.beautiful_prompt.template = """Converts a simple image description into a prompt. \
Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
but make sure there is a correlation between the input and output.\n\
### Input: {raw_prompt}\n### Output:"""

View File

@@ -0,0 +1,31 @@
from diffsynth import ModelManager, SDXLImagePipeline
import torch


# Download models
# `models/stable_diffusion_xl/sd_xl_base_1.0.safetensors`: [link](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors)
# `models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd/`: [link](https://huggingface.co/alibaba-pai/pai-bloom-1b1-text2prompt-sd)
# `models/translator/opus-mt-zh-en/`: [link](https://huggingface.co/Helsinki-NLP/opus-mt-zh-en)
# (fix: the link previously pointed at the en->zh model, but this example
#  needs the Chinese-to-English direction.)

# Load models: base SDXL, the prompt refiner, and the zh->en translator.
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
model_manager.load_textual_inversions("models/textual_inversion")
model_manager.load_models([
    "models/stable_diffusion_xl/sd_xl_base_1.0.safetensors",
    "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd/model.safetensors",
    "models/translator/opus-mt-zh-en/pytorch_model.bin",
])
pipe = SDXLImagePipeline.from_model_manager(model_manager)

# The Chinese prompt is translated to English, then refined, by the prompter.
prompt = "一个漂亮的女孩"
negative_prompt = ""

# One image per seed, seeded for reproducibility.
for seed in range(4):
    torch.manual_seed(seed)
    image = pipe(
        prompt=prompt, negative_prompt=negative_prompt,
        height=1024, width=1024,
        num_inference_steps=30
    )
    image.save(f"{seed}.jpg")

View File

@@ -35,7 +35,7 @@ pipe = SDImagePipeline.from_model_manager(
)
prompt = "masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait,"
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,",
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
torch.manual_seed(0)
image = pipe(

View File

@@ -12,7 +12,7 @@ model_manager.load_models(["models/stable_diffusion_xl/bluePencilXL_v200.safeten
pipe = SDXLImagePipeline.from_model_manager(model_manager)
prompt = "masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait,"
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,",
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
torch.manual_seed(0)
image = pipe(