mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:32:27 +00:00
support SD3 textual inversion
This commit is contained in:
@@ -567,10 +567,22 @@ class ModelManager:
|
||||
if component == "sd3_text_encoder_3":
|
||||
if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight" not in state_dict:
|
||||
continue
|
||||
self.model[component] = component_dict[component]()
|
||||
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
||||
self.model[component].to(self.torch_dtype).to(self.device)
|
||||
self.model_path[component] = file_path
|
||||
elif component == "sd3_text_encoder_1":
|
||||
# Add additional token embeddings to text encoder
|
||||
token_embeddings = [state_dict["text_encoders.clip_l.transformer.text_model.embeddings.token_embedding.weight"]]
|
||||
for keyword in self.textual_inversion_dict:
|
||||
_, embeddings = self.textual_inversion_dict[keyword]
|
||||
token_embeddings.append(embeddings.to(dtype=token_embeddings[0].dtype))
|
||||
token_embeddings = torch.concat(token_embeddings, dim=0)
|
||||
state_dict["text_encoders.clip_l.transformer.text_model.embeddings.token_embedding.weight"] = token_embeddings
|
||||
self.model[component] = component_dict[component](vocab_size=token_embeddings.shape[0])
|
||||
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
||||
self.model[component].to(self.torch_dtype).to(self.device)
|
||||
else:
|
||||
self.model[component] = component_dict[component]()
|
||||
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
||||
self.model[component].to(self.torch_dtype).to(self.device)
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_stable_diffusion_3_t5(self, state_dict, file_path=""):
|
||||
component = "sd3_text_encoder_3"
|
||||
|
||||
@@ -5,8 +5,8 @@ from .sdxl_text_encoder import SDXLTextEncoder2, SDXLTextEncoder2StateDictConver
|
||||
|
||||
|
||||
class SD3TextEncoder1(SDTextEncoder):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self, vocab_size=49408):
|
||||
super().__init__(vocab_size=vocab_size)
|
||||
|
||||
def forward(self, input_ids, clip_skip=2):
|
||||
embeds = self.token_embedding(input_ids) + self.position_embeds
|
||||
|
||||
@@ -20,7 +20,7 @@ class SD3Prompter(Prompter):
|
||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||
tokenizer_3_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_3/tokenizer_3")
|
||||
super().__init__()
|
||||
self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_1_path)
|
||||
self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
|
||||
self.tokenizer_3 = T5TokenizerFast.from_pretrained(tokenizer_3_path)
|
||||
|
||||
@@ -61,17 +61,17 @@ class SD3Prompter(Prompter):
|
||||
positive=True,
|
||||
device="cuda"
|
||||
):
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
prompt, pure_prompt = self.process_prompt(prompt, positive=positive, require_pure_prompt=True)
|
||||
|
||||
# CLIP
|
||||
pooled_prompt_emb_1, prompt_emb_1 = self.encode_prompt_using_clip(prompt, text_encoder_1, self.tokenizer_1, 77, device)
|
||||
pooled_prompt_emb_2, prompt_emb_2 = self.encode_prompt_using_clip(prompt, text_encoder_2, self.tokenizer_2, 77, device)
|
||||
pooled_prompt_emb_1, prompt_emb_1 = self.encode_prompt_using_clip(prompt, text_encoder_1, self.tokenizer, 77, device)
|
||||
pooled_prompt_emb_2, prompt_emb_2 = self.encode_prompt_using_clip(pure_prompt, text_encoder_2, self.tokenizer_2, 77, device)
|
||||
|
||||
# T5
|
||||
if text_encoder_3 is None:
|
||||
prompt_emb_3 = torch.zeros((1, 256, 4096), dtype=prompt_emb_1.dtype, device=device)
|
||||
else:
|
||||
prompt_emb_3 = self.encode_prompt_using_t5(prompt, text_encoder_3, self.tokenizer_3, 256, device)
|
||||
prompt_emb_3 = self.encode_prompt_using_t5(pure_prompt, text_encoder_3, self.tokenizer_3, 256, device)
|
||||
prompt_emb_3 = prompt_emb_3.to(prompt_emb_1.dtype) # float32 -> float16
|
||||
|
||||
# Merge
|
||||
|
||||
@@ -111,14 +111,27 @@ but make sure there is a correlation between the input and output.\n\
|
||||
if "beautiful_prompt" in model_manager.model:
|
||||
self.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
|
||||
|
||||
def process_prompt(self, prompt, positive=True):
|
||||
def add_textual_inversion_tokens(self, prompt):
|
||||
for keyword in self.keyword_dict:
|
||||
if keyword in prompt:
|
||||
prompt = prompt.replace(keyword, self.keyword_dict[keyword])
|
||||
return prompt
|
||||
|
||||
def del_textual_inversion_tokens(self, prompt):
|
||||
for keyword in self.keyword_dict:
|
||||
if keyword in prompt:
|
||||
prompt = prompt.replace(keyword, "")
|
||||
return prompt
|
||||
|
||||
def process_prompt(self, prompt, positive=True, require_pure_prompt=False):
|
||||
prompt, pure_prompt = self.add_textual_inversion_tokens(prompt), self.del_textual_inversion_tokens(prompt)
|
||||
if positive and self.translator is not None:
|
||||
prompt = self.translator(prompt)
|
||||
print(f"Your prompt is translated: \"{prompt}\"")
|
||||
if positive and self.beautiful_prompt is not None:
|
||||
prompt = self.beautiful_prompt(prompt)
|
||||
print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
|
||||
return prompt
|
||||
if require_pure_prompt:
|
||||
return prompt, pure_prompt
|
||||
else:
|
||||
return prompt
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
from diffsynth import ModelManager, SD3ImagePipeline, download_models, load_state_dict
|
||||
import torch
|
||||
|
||||
|
||||
# Download models (automatically)
|
||||
# `models/stable_diffusion_3/sd3_medium_incl_clips.safetensors`: [link](https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/sd3_medium_incl_clips.safetensors)
|
||||
# `models/textual_inversion/verybadimagenegative_v1.3.pt`: [link](https://civitai.com/api/download/models/25820?type=Model&format=PickleTensor&size=full&fp=fp16)
|
||||
download_models(["StableDiffusion3_without_T5", "TextualInversion_VeryBadImageNegative_v1.3"])
|
||||
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
|
||||
model_manager.load_textual_inversions("models/textual_inversion")
|
||||
model_manager.load_models(["models/stable_diffusion_3/sd3_medium_incl_clips.safetensors"])
|
||||
pipe = SD3ImagePipeline.from_model_manager(model_manager)
|
||||
|
||||
|
||||
for seed in range(4):
|
||||
torch.manual_seed(seed)
|
||||
image = pipe(
|
||||
prompt="a girl, highly detailed, absurd res, perfect image",
|
||||
negative_prompt="verybadimagenegative_v1.3",
|
||||
cfg_scale=4.5,
|
||||
num_inference_steps=50, width=1024, height=1024,
|
||||
)
|
||||
image.save(f"image_with_textual_inversion_{seed}.jpg")
|
||||
|
||||
torch.manual_seed(seed)
|
||||
image = pipe(
|
||||
prompt="a girl, highly detailed, absurd res, perfect image",
|
||||
negative_prompt="",
|
||||
cfg_scale=4.5,
|
||||
num_inference_steps=50, width=1024, height=1024,
|
||||
)
|
||||
image.save(f"image_without_textual_inversion_{seed}.jpg")
|
||||
Reference in New Issue
Block a user