From 518c6d6ac33efd9d2dd60ab30d996f6eac4a104b Mon Sep 17 00:00:00 2001
From: Artiprocher
Date: Fri, 5 Jul 2024 13:36:54 +0800
Subject: [PATCH] support SD3 textual inversion

---
 diffsynth/models/__init__.py                       | 20 +++++++++---
 diffsynth/models/sd3_text_encoder.py               |  4 +--
 diffsynth/prompts/sd3_prompter.py                  | 10 +++---
 diffsynth/prompts/utils.py                         | 17 ++++++++--
 .../sd3_text_to_image_textual_inversion.py         | 32 +++++++++++++++++++
 5 files changed, 70 insertions(+), 13 deletions(-)
 create mode 100644 examples/image_synthesis/sd3_text_to_image_textual_inversion.py

diff --git a/diffsynth/models/__init__.py b/diffsynth/models/__init__.py
index 7f1917c..ad6befa 100644
--- a/diffsynth/models/__init__.py
+++ b/diffsynth/models/__init__.py
@@ -567,10 +567,22 @@ class ModelManager:
             if component == "sd3_text_encoder_3":
                 if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight" not in state_dict:
                     continue
-            self.model[component] = component_dict[component]()
-            self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
-            self.model[component].to(self.torch_dtype).to(self.device)
-            self.model_path[component] = file_path
+            elif component == "sd3_text_encoder_1":
+                # Add additional token embeddings to text encoder
+                token_embeddings = [state_dict["text_encoders.clip_l.transformer.text_model.embeddings.token_embedding.weight"]]
+                for keyword in self.textual_inversion_dict:
+                    _, embeddings = self.textual_inversion_dict[keyword]
+                    token_embeddings.append(embeddings.to(dtype=token_embeddings[0].dtype))
+                token_embeddings = torch.concat(token_embeddings, dim=0)
+                state_dict["text_encoders.clip_l.transformer.text_model.embeddings.token_embedding.weight"] = token_embeddings
+                self.model[component] = component_dict[component](vocab_size=token_embeddings.shape[0])
+                self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
+                self.model[component].to(self.torch_dtype).to(self.device)
+            else:
+                self.model[component] = component_dict[component]()
+                self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
+                self.model[component].to(self.torch_dtype).to(self.device)
+            self.model_path[component] = file_path
 
     def load_stable_diffusion_3_t5(self, state_dict, file_path=""):
         component = "sd3_text_encoder_3"
diff --git a/diffsynth/models/sd3_text_encoder.py b/diffsynth/models/sd3_text_encoder.py
index 287cb38..bb0fc6d 100644
--- a/diffsynth/models/sd3_text_encoder.py
+++ b/diffsynth/models/sd3_text_encoder.py
@@ -5,8 +5,8 @@ from .sdxl_text_encoder import SDXLTextEncoder2, SDXLTextEncoder2StateDictConver
 
 
 class SD3TextEncoder1(SDTextEncoder):
-    def __init__(self):
-        super().__init__()
+    def __init__(self, vocab_size=49408):
+        super().__init__(vocab_size=vocab_size)
 
     def forward(self, input_ids, clip_skip=2):
         embeds = self.token_embedding(input_ids) + self.position_embeds
diff --git a/diffsynth/prompts/sd3_prompter.py b/diffsynth/prompts/sd3_prompter.py
index 060457a..5bc252e 100644
--- a/diffsynth/prompts/sd3_prompter.py
+++ b/diffsynth/prompts/sd3_prompter.py
@@ -20,7 +20,7 @@ class SD3Prompter(Prompter):
         base_path = os.path.dirname(os.path.dirname(__file__))
         tokenizer_3_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_3/tokenizer_3")
         super().__init__()
-        self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
+        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_1_path)
         self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
         self.tokenizer_3 = T5TokenizerFast.from_pretrained(tokenizer_3_path)
 
@@ -61,17 +61,17 @@ class SD3Prompter(Prompter):
         positive=True,
         device="cuda"
     ):
-        prompt = self.process_prompt(prompt, positive=positive)
+        prompt, pure_prompt = self.process_prompt(prompt, positive=positive, require_pure_prompt=True)
 
         # CLIP
-        pooled_prompt_emb_1, prompt_emb_1 = self.encode_prompt_using_clip(prompt, text_encoder_1, self.tokenizer_1, 77, device)
-        pooled_prompt_emb_2, prompt_emb_2 = self.encode_prompt_using_clip(prompt, text_encoder_2, self.tokenizer_2, 77, device)
+        pooled_prompt_emb_1, prompt_emb_1 = self.encode_prompt_using_clip(prompt, text_encoder_1, self.tokenizer, 77, device)
+        pooled_prompt_emb_2, prompt_emb_2 = self.encode_prompt_using_clip(pure_prompt, text_encoder_2, self.tokenizer_2, 77, device)
 
         # T5
         if text_encoder_3 is None:
             prompt_emb_3 = torch.zeros((1, 256, 4096), dtype=prompt_emb_1.dtype, device=device)
         else:
-            prompt_emb_3 = self.encode_prompt_using_t5(prompt, text_encoder_3, self.tokenizer_3, 256, device)
+            prompt_emb_3 = self.encode_prompt_using_t5(pure_prompt, text_encoder_3, self.tokenizer_3, 256, device)
             prompt_emb_3 = prompt_emb_3.to(prompt_emb_1.dtype) # float32 -> float16
 
         # Merge
diff --git a/diffsynth/prompts/utils.py b/diffsynth/prompts/utils.py
index 291dc5a..45a879d 100644
--- a/diffsynth/prompts/utils.py
+++ b/diffsynth/prompts/utils.py
@@ -111,14 +111,27 @@ but make sure there is a correlation between the input and output.\n\
         if "beautiful_prompt" in model_manager.model:
             self.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
 
-    def process_prompt(self, prompt, positive=True):
+    def add_textual_inversion_tokens(self, prompt):
         for keyword in self.keyword_dict:
             if keyword in prompt:
                 prompt = prompt.replace(keyword, self.keyword_dict[keyword])
+        return prompt
+
+    def del_textual_inversion_tokens(self, prompt):
+        for keyword in self.keyword_dict:
+            if keyword in prompt:
+                prompt = prompt.replace(keyword, "")
+        return prompt
+
+    def process_prompt(self, prompt, positive=True, require_pure_prompt=False):
+        prompt, pure_prompt = self.add_textual_inversion_tokens(prompt), self.del_textual_inversion_tokens(prompt)
         if positive and self.translator is not None:
             prompt = self.translator(prompt)
             print(f"Your prompt is translated: \"{prompt}\"")
         if positive and self.beautiful_prompt is not None:
             prompt = self.beautiful_prompt(prompt)
             print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
-        return prompt
+        if require_pure_prompt:
+            return prompt, pure_prompt
+        else:
+            return prompt
diff --git a/examples/image_synthesis/sd3_text_to_image_textual_inversion.py b/examples/image_synthesis/sd3_text_to_image_textual_inversion.py
new file mode 100644
index 0000000..1cf5256
--- /dev/null
+++ b/examples/image_synthesis/sd3_text_to_image_textual_inversion.py
@@ -0,0 +1,32 @@
+from diffsynth import ModelManager, SD3ImagePipeline, download_models, load_state_dict
+import torch
+
+
+# Download models (automatically)
+# `models/stable_diffusion_3/sd3_medium_incl_clips.safetensors`: [link](https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/sd3_medium_incl_clips.safetensors)
+# `models/textual_inversion/verybadimagenegative_v1.3.pt`: [link](https://civitai.com/api/download/models/25820?type=Model&format=PickleTensor&size=full&fp=fp16)
+download_models(["StableDiffusion3_without_T5", "TextualInversion_VeryBadImageNegative_v1.3"])
+model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
+model_manager.load_textual_inversions("models/textual_inversion")
+model_manager.load_models(["models/stable_diffusion_3/sd3_medium_incl_clips.safetensors"])
+pipe = SD3ImagePipeline.from_model_manager(model_manager)
+
+
+for seed in range(4):
+    torch.manual_seed(seed)
+    image = pipe(
+        prompt="a girl, highly detailed, absurd res, perfect image",
+        negative_prompt="verybadimagenegative_v1.3",
+        cfg_scale=4.5,
+        num_inference_steps=50, width=1024, height=1024,
+    )
+    image.save(f"image_with_textual_inversion_{seed}.jpg")
+
+    torch.manual_seed(seed)
+    image = pipe(
+        prompt="a girl, highly detailed, absurd res, perfect image",
+        negative_prompt="",
+        cfg_scale=4.5,
+        num_inference_steps=50, width=1024, height=1024,
+    )
+    image.save(f"image_without_textual_inversion_{seed}.jpg")
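
A note on how the pieces fit together (a hedged sketch, not part of the commit): `load_textual_inversions` populates `textual_inversion_dict`, whose embeddings are concatenated onto the CLIP-L token-embedding table when `load_models` builds `sd3_text_encoder_1` with the enlarged `vocab_size`; on the prompt side, the base `Prompter`'s `keyword_dict` presumably maps each trigger word to the placeholder tokens that point at those appended rows. The new `require_pure_prompt` path then feeds the keyword-expanded prompt only to the extended CLIP-L encoder, while CLIP-G and T5, whose vocabularies are unchanged, receive a "pure" prompt with the trigger stripped. The standalone Python snippet below illustrates that split; `split_prompt` and the `<ti_*>` placeholder strings are hypothetical, not DiffSynth APIs.

    # Illustrative sketch of the expanded/pure prompt split (assumed placeholder names, not DiffSynth internals).
    # keyword_dict maps a textual-inversion trigger word to the placeholder tokens whose
    # embeddings were appended to the extended CLIP-L vocabulary.
    def split_prompt(prompt, keyword_dict):
        expanded, pure = prompt, prompt
        for keyword, placeholder_tokens in keyword_dict.items():
            expanded = expanded.replace(keyword, placeholder_tokens)  # sent to the extended CLIP-L encoder
            pure = pure.replace(keyword, "")                          # sent to CLIP-G and T5
        return expanded, pure

    expanded, pure = split_prompt(
        "lowres, blurry, verybadimagenegative_v1.3",
        {"verybadimagenegative_v1.3": "<ti_0> <ti_1> <ti_2>"},  # hypothetical trigger-to-placeholder mapping
    )
    print(expanded)  # lowres, blurry, <ti_0> <ti_1> <ti_2>
    print(pure)      # lowres, blurry,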