From e076e668271cdb8d4e2884eda02fa49fc6ec5c78 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Sun, 21 Jan 2024 22:36:03 +0800 Subject: [PATCH] prompt processing --- README.md | 16 +++ diffsynth/models/__init__.py | 15 +++ diffsynth/pipelines/stable_diffusion.py | 9 +- diffsynth/pipelines/stable_diffusion_video.py | 9 +- diffsynth/pipelines/stable_diffusion_xl.py | 13 +- diffsynth/prompts/__init__.py | 112 +++++++++--------- examples/sd_prompt_refining.py | 31 +++++ examples/sd_text_to_image.py | 2 +- examples/sdxl_text_to_image.py | 2 +- .../Put BeautifulPrompt models here.txt | 0 .../translator/Put translator models here.txt | 0 11 files changed, 134 insertions(+), 75 deletions(-) create mode 100644 examples/sd_prompt_refining.py create mode 100644 models/BeautifulPrompt/Put BeautifulPrompt models here.txt create mode 100644 models/translator/Put translator models here.txt diff --git a/README.md b/README.md index 59df0b9..28b49f2 100644 --- a/README.md +++ b/README.md @@ -71,3 +71,19 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/8f556355-4079-44 We provide an example for video stylization. In this pipeline, the rendered video is completely different from the original video, thus we need a powerful deflickering algorithm. We use FastBlend to implement the deflickering module. Please see `examples/sd_video_rerender.py` for more details. https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea + +### Example 7: Prompt Processing + +If you are not native English user, we provide translation service for you. Our prompter can translate other language to English and refine it using "BeautifulPrompt" models. Please see `examples/sd_prompt_refining.py` for more details. + +Prompt: "一个漂亮的女孩". The [translation model](https://huggingface.co/Helsinki-NLP/opus-mt-en-zh) will translate it to English. + +|seed=0|seed=1|seed=2|seed=3| +|-|-|-|-| +|![0_](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/ebb25ca8-7ce1-4d9e-8081-59a867c70c4d)|![1_](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/a7e79853-3c1a-471a-9c58-c209ec4b76dd)|![2_](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/a292b959-a121-481f-b79c-61cc3346f810)|![3_](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/1c19b54e-5a6f-4d48-960b-a7b2b149bb4c)| + +Prompt: "一个漂亮的女孩". The [translation model](https://huggingface.co/Helsinki-NLP/opus-mt-en-zh) will translate it to English. Then the [refining model](https://huggingface.co/alibaba-pai/pai-bloom-1b1-text2prompt-sd) will refine the translated prompt for better visual quality. + +|seed=0|seed=1|seed=2|seed=3| +|-|-|-|-| +|![0](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/778b1bd9-44e0-46ac-a99c-712b3fc9aaa4)|![1](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/c03479b8-2082-4c6e-8e1c-3582b98686f6)|![2](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/edb33d21-3288-4a55-96ca-a4bfe1b50b00)|![3](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/7848cfc1-cad5-4848-8373-41d24e98e584)| diff --git a/diffsynth/models/__init__.py b/diffsynth/models/__init__.py index 35c6472..99b113f 100644 --- a/diffsynth/models/__init__.py +++ b/diffsynth/models/__init__.py @@ -55,6 +55,10 @@ class ModelManager: param_name = "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_ff_net_2.lora_up.weight" return param_name in state_dict + def is_translator(self, state_dict): + param_name = "model.encoder.layers.5.self_attn_layer_norm.weight" + return param_name in state_dict and len(state_dict) == 254 + def load_stable_diffusion(self, state_dict, components=None, file_path=""): component_dict = { "text_encoder": SDTextEncoder, @@ -147,6 +151,15 @@ class ModelManager: SDLoRA().add_lora_to_text_encoder(self.model["text_encoder"], state_dict, alpha=alpha, device=self.device) SDLoRA().add_lora_to_unet(self.model["unet"], state_dict, alpha=alpha, device=self.device) + def load_translator(self, state_dict, file_path=""): + # This model is lightweight, we do not place it on GPU. + component = "translator" + from transformers import AutoModelForSeq2SeqLM + model_folder = os.path.dirname(file_path) + model = AutoModelForSeq2SeqLM.from_pretrained(model_folder).eval() + self.model[component] = model + self.model_path[component] = file_path + def search_for_embeddings(self, state_dict): embeddings = [] for k in state_dict: @@ -190,6 +203,8 @@ class ModelManager: self.load_beautiful_prompt(state_dict, file_path=file_path) elif self.is_RIFE(state_dict): self.load_RIFE(state_dict, file_path=file_path) + elif self.is_translator(state_dict): + self.load_translator(state_dict, file_path=file_path) def load_models(self, file_path_list, lora_alphas=[]): for file_path in file_path_list: diff --git a/diffsynth/pipelines/stable_diffusion.py b/diffsynth/pipelines/stable_diffusion.py index d82d480..1f0faf4 100644 --- a/diffsynth/pipelines/stable_diffusion.py +++ b/diffsynth/pipelines/stable_diffusion.py @@ -31,8 +31,6 @@ class SDImagePipeline(torch.nn.Module): self.unet = model_manager.unet self.vae_decoder = model_manager.vae_decoder self.vae_encoder = model_manager.vae_encoder - # load textual inversion - self.prompter.load_textual_inversion(model_manager.textual_inversion_dict) def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]): @@ -47,9 +45,8 @@ class SDImagePipeline(torch.nn.Module): self.controlnet = MultiControlNetManager(controlnet_units) - def fetch_beautiful_prompt(self, model_manager: ModelManager): - if "beautiful_prompt" in model_manager.model: - self.prompter.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"]) + def fetch_prompter(self, model_manager: ModelManager): + self.prompter.load_from_model_manager(model_manager) @staticmethod @@ -59,7 +56,7 @@ class SDImagePipeline(torch.nn.Module): torch_dtype=model_manager.torch_dtype, ) pipe.fetch_main_models(model_manager) - pipe.fetch_beautiful_prompt(model_manager) + pipe.fetch_prompter(model_manager) pipe.fetch_controlnet_models(model_manager, controlnet_config_units) return pipe diff --git a/diffsynth/pipelines/stable_diffusion_video.py b/diffsynth/pipelines/stable_diffusion_video.py index 7dd35a3..36c3127 100644 --- a/diffsynth/pipelines/stable_diffusion_video.py +++ b/diffsynth/pipelines/stable_diffusion_video.py @@ -82,8 +82,6 @@ class SDVideoPipeline(torch.nn.Module): self.unet = model_manager.unet self.vae_decoder = model_manager.vae_decoder self.vae_encoder = model_manager.vae_encoder - # load textual inversion - self.prompter.load_textual_inversion(model_manager.textual_inversion_dict) def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]): @@ -103,9 +101,8 @@ class SDVideoPipeline(torch.nn.Module): self.motion_modules = model_manager.motion_modules - def fetch_beautiful_prompt(self, model_manager: ModelManager): - if "beautiful_prompt" in model_manager.model: - self.prompter.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"]) + def fetch_prompter(self, model_manager: ModelManager): + self.prompter.load_from_model_manager(model_manager) @staticmethod @@ -117,7 +114,7 @@ class SDVideoPipeline(torch.nn.Module): ) pipe.fetch_main_models(model_manager) pipe.fetch_motion_modules(model_manager) - pipe.fetch_beautiful_prompt(model_manager) + pipe.fetch_prompter(model_manager) pipe.fetch_controlnet_models(model_manager, controlnet_config_units) return pipe diff --git a/diffsynth/pipelines/stable_diffusion_xl.py b/diffsynth/pipelines/stable_diffusion_xl.py index f9dd481..0fec886 100644 --- a/diffsynth/pipelines/stable_diffusion_xl.py +++ b/diffsynth/pipelines/stable_diffusion_xl.py @@ -39,9 +39,8 @@ class SDXLImagePipeline(torch.nn.Module): pass - def fetch_beautiful_prompt(self, model_manager: ModelManager): - if "beautiful_prompt" in model_manager.model: - self.prompter.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"]) + def fetch_prompter(self, model_manager: ModelManager): + self.prompter.load_from_model_manager(model_manager) @staticmethod @@ -51,7 +50,7 @@ class SDXLImagePipeline(torch.nn.Module): torch_dtype=model_manager.torch_dtype, ) pipe.fetch_main_models(model_manager) - pipe.fetch_beautiful_prompt(model_manager) + pipe.fetch_prompter(model_manager) pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units) return pipe @@ -106,7 +105,8 @@ class SDXLImagePipeline(torch.nn.Module): self.text_encoder_2, prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, - device=self.device + device=self.device, + positive=True, ) if cfg_scale != 1.0: add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt( @@ -114,7 +114,8 @@ class SDXLImagePipeline(torch.nn.Module): self.text_encoder_2, negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, - device=self.device + device=self.device, + positive=False, ) # Prepare scheduler diff --git a/diffsynth/prompts/__init__.py b/diffsynth/prompts/__init__.py index be94f8b..464cfe8 100644 --- a/diffsynth/prompts/__init__.py +++ b/diffsynth/prompts/__init__.py @@ -1,5 +1,5 @@ from transformers import CLIPTokenizer, AutoTokenizer -from ..models import SDTextEncoder, SDXLTextEncoder, SDXLTextEncoder2 +from ..models import SDTextEncoder, SDXLTextEncoder, SDXLTextEncoder2, ModelManager import torch, os @@ -59,33 +59,27 @@ class BeautifulPrompt: skip_special_tokens=True )[0].strip() return prompt + +class Translator: + def __init__(self, tokenizer_path="configs/translator/tokenizer", model=None): + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + self.model = model -class SDPrompter: - def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"): - # We use the tokenizer implemented by transformers - self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path) + def __call__(self, prompt): + input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.model.device) + output_ids = self.model.generate(input_ids) + prompt = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] + return prompt + + +class Prompter: + def __init__(self): + self.tokenizer: CLIPTokenizer = None self.keyword_dict = {} + self.translator: Translator = None self.beautiful_prompt: BeautifulPrompt = None - - def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True): - # Textual Inversion - for keyword in self.keyword_dict: - if keyword in prompt: - prompt = prompt.replace(keyword, self.keyword_dict[keyword]) - - # Beautiful Prompt - if positive and self.beautiful_prompt is not None: - prompt = self.beautiful_prompt(prompt) - print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"") - - input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device) - prompt_emb = text_encoder(input_ids, clip_skip=clip_skip) - prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1)) - - return prompt_emb - def load_textual_inversion(self, textual_inversion_dict): self.keyword_dict = {} additional_tokens = [] @@ -105,18 +99,53 @@ or use a number to specify the weight. You should add appropriate words to make but make sure there is a correlation between the input and output.\n\ ### Input: {raw_prompt}\n### Output:""" + def load_translator(self, model, model_path): + model_folder = os.path.dirname(model_path) + self.translator = Translator(tokenizer_path=model_folder, model=model) -class SDXLPrompter: + def load_from_model_manager(self, model_manager: ModelManager): + self.load_textual_inversion(model_manager.textual_inversion_dict) + if "translator" in model_manager.model: + self.load_translator(model_manager.model["translator"], model_manager.model_path["translator"]) + if "beautiful_prompt" in model_manager.model: + self.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"]) + + def process_prompt(self, prompt, positive=True): + for keyword in self.keyword_dict: + if keyword in prompt: + prompt = prompt.replace(keyword, self.keyword_dict[keyword]) + if positive and self.translator is not None: + prompt = self.translator(prompt) + print(f"Your prompt is translated: \"{prompt}\"") + if positive and self.beautiful_prompt is not None: + prompt = self.beautiful_prompt(prompt) + print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"") + return prompt + + +class SDPrompter(Prompter): + def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"): + super().__init__() + self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path) + + def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True): + prompt = self.process_prompt(prompt, positive=positive) + input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device) + prompt_emb = text_encoder(input_ids, clip_skip=clip_skip) + prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1)) + + return prompt_emb + + +class SDXLPrompter(Prompter): def __init__( self, tokenizer_path="configs/stable_diffusion/tokenizer", tokenizer_2_path="configs/stable_diffusion_xl/tokenizer_2" ): - # We use the tokenizer implemented by transformers + super().__init__() self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path) self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path) - self.keyword_dict = {} - self.beautiful_prompt: BeautifulPrompt = None def encode_prompt( self, @@ -128,15 +157,7 @@ class SDXLPrompter: positive=True, device="cuda" ): - # Textual Inversion - for keyword in self.keyword_dict: - if keyword in prompt: - prompt = prompt.replace(keyword, self.keyword_dict[keyword]) - - # Beautiful Prompt - if positive and self.beautiful_prompt is not None: - prompt = self.beautiful_prompt(prompt) - print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"") + prompt = self.process_prompt(prompt, positive=positive) # 1 input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device) @@ -153,22 +174,3 @@ class SDXLPrompter: add_text_embeds = add_text_embeds[0:1] prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1)) return add_text_embeds, prompt_emb - - def load_textual_inversion(self, textual_inversion_dict): - self.keyword_dict = {} - additional_tokens = [] - for keyword in textual_inversion_dict: - tokens, _ = textual_inversion_dict[keyword] - additional_tokens += tokens - self.keyword_dict[keyword] = " " + " ".join(tokens) + " " - self.tokenizer.add_tokens(additional_tokens) - - def load_beautiful_prompt(self, model, model_path): - model_folder = os.path.dirname(model_path) - self.beautiful_prompt = BeautifulPrompt(tokenizer_path=model_folder, model=model) - if model_folder.endswith("v2"): - self.beautiful_prompt.template = """Converts a simple image description into a prompt. \ -Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \ -or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \ -but make sure there is a correlation between the input and output.\n\ -### Input: {raw_prompt}\n### Output:""" diff --git a/examples/sd_prompt_refining.py b/examples/sd_prompt_refining.py new file mode 100644 index 0000000..bd76804 --- /dev/null +++ b/examples/sd_prompt_refining.py @@ -0,0 +1,31 @@ +from diffsynth import ModelManager, SDXLImagePipeline +import torch + + +# Download models +# `models/stable_diffusion_xl/sd_xl_base_1.0.safetensors`: [link](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors) +# `models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd/`: [link](https://huggingface.co/alibaba-pai/pai-bloom-1b1-text2prompt-sd) +# `models/translator/opus-mt-zh-en/`: [link](https://huggingface.co/Helsinki-NLP/opus-mt-en-zh) + + +# Load models +model_manager = ModelManager(torch_dtype=torch.float16, device="cuda") +model_manager.load_textual_inversions("models/textual_inversion") +model_manager.load_models([ + "models/stable_diffusion_xl/sd_xl_base_1.0.safetensors", + "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd/model.safetensors", + "models/translator/opus-mt-zh-en/pytorch_model.bin" +]) +pipe = SDXLImagePipeline.from_model_manager(model_manager) + +prompt = "一个漂亮的女孩" +negative_prompt = "" + +for seed in range(4): + torch.manual_seed(seed) + image = pipe( + prompt=prompt, negative_prompt=negative_prompt, + height=1024, width=1024, + num_inference_steps=30 + ) + image.save(f"{seed}.jpg") diff --git a/examples/sd_text_to_image.py b/examples/sd_text_to_image.py index 76e65d3..5f32bcd 100644 --- a/examples/sd_text_to_image.py +++ b/examples/sd_text_to_image.py @@ -35,7 +35,7 @@ pipe = SDImagePipeline.from_model_manager( ) prompt = "masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait," -negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,", +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," torch.manual_seed(0) image = pipe( diff --git a/examples/sdxl_text_to_image.py b/examples/sdxl_text_to_image.py index 16df873..3d3e362 100644 --- a/examples/sdxl_text_to_image.py +++ b/examples/sdxl_text_to_image.py @@ -12,7 +12,7 @@ model_manager.load_models(["models/stable_diffusion_xl/bluePencilXL_v200.safeten pipe = SDXLImagePipeline.from_model_manager(model_manager) prompt = "masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait," -negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,", +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," torch.manual_seed(0) image = pipe( diff --git a/models/BeautifulPrompt/Put BeautifulPrompt models here.txt b/models/BeautifulPrompt/Put BeautifulPrompt models here.txt new file mode 100644 index 0000000..e69de29 diff --git a/models/translator/Put translator models here.txt b/models/translator/Put translator models here.txt new file mode 100644 index 0000000..e69de29