mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 10:18:12 +00:00
support SD3 textual inversion
This commit is contained in:
@@ -567,10 +567,22 @@ class ModelManager:
|
|||||||
if component == "sd3_text_encoder_3":
|
if component == "sd3_text_encoder_3":
|
||||||
if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight" not in state_dict:
|
if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight" not in state_dict:
|
||||||
continue
|
continue
|
||||||
self.model[component] = component_dict[component]()
|
elif component == "sd3_text_encoder_1":
|
||||||
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
# Add additional token embeddings to text encoder
|
||||||
self.model[component].to(self.torch_dtype).to(self.device)
|
token_embeddings = [state_dict["text_encoders.clip_l.transformer.text_model.embeddings.token_embedding.weight"]]
|
||||||
self.model_path[component] = file_path
|
for keyword in self.textual_inversion_dict:
|
||||||
|
_, embeddings = self.textual_inversion_dict[keyword]
|
||||||
|
token_embeddings.append(embeddings.to(dtype=token_embeddings[0].dtype))
|
||||||
|
token_embeddings = torch.concat(token_embeddings, dim=0)
|
||||||
|
state_dict["text_encoders.clip_l.transformer.text_model.embeddings.token_embedding.weight"] = token_embeddings
|
||||||
|
self.model[component] = component_dict[component](vocab_size=token_embeddings.shape[0])
|
||||||
|
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
||||||
|
self.model[component].to(self.torch_dtype).to(self.device)
|
||||||
|
else:
|
||||||
|
self.model[component] = component_dict[component]()
|
||||||
|
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
||||||
|
self.model[component].to(self.torch_dtype).to(self.device)
|
||||||
|
self.model_path[component] = file_path
|
||||||
|
|
||||||
def load_stable_diffusion_3_t5(self, state_dict, file_path=""):
|
def load_stable_diffusion_3_t5(self, state_dict, file_path=""):
|
||||||
component = "sd3_text_encoder_3"
|
component = "sd3_text_encoder_3"
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ from .sdxl_text_encoder import SDXLTextEncoder2, SDXLTextEncoder2StateDictConver
|
|||||||
|
|
||||||
|
|
||||||
class SD3TextEncoder1(SDTextEncoder):
|
class SD3TextEncoder1(SDTextEncoder):
|
||||||
def __init__(self):
|
def __init__(self, vocab_size=49408):
|
||||||
super().__init__()
|
super().__init__(vocab_size=vocab_size)
|
||||||
|
|
||||||
def forward(self, input_ids, clip_skip=2):
|
def forward(self, input_ids, clip_skip=2):
|
||||||
embeds = self.token_embedding(input_ids) + self.position_embeds
|
embeds = self.token_embedding(input_ids) + self.position_embeds
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ class SD3Prompter(Prompter):
|
|||||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||||
tokenizer_3_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_3/tokenizer_3")
|
tokenizer_3_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_3/tokenizer_3")
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
|
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_1_path)
|
||||||
self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
|
self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
|
||||||
self.tokenizer_3 = T5TokenizerFast.from_pretrained(tokenizer_3_path)
|
self.tokenizer_3 = T5TokenizerFast.from_pretrained(tokenizer_3_path)
|
||||||
|
|
||||||
@@ -61,17 +61,17 @@ class SD3Prompter(Prompter):
|
|||||||
positive=True,
|
positive=True,
|
||||||
device="cuda"
|
device="cuda"
|
||||||
):
|
):
|
||||||
prompt = self.process_prompt(prompt, positive=positive)
|
prompt, pure_prompt = self.process_prompt(prompt, positive=positive, require_pure_prompt=True)
|
||||||
|
|
||||||
# CLIP
|
# CLIP
|
||||||
pooled_prompt_emb_1, prompt_emb_1 = self.encode_prompt_using_clip(prompt, text_encoder_1, self.tokenizer_1, 77, device)
|
pooled_prompt_emb_1, prompt_emb_1 = self.encode_prompt_using_clip(prompt, text_encoder_1, self.tokenizer, 77, device)
|
||||||
pooled_prompt_emb_2, prompt_emb_2 = self.encode_prompt_using_clip(prompt, text_encoder_2, self.tokenizer_2, 77, device)
|
pooled_prompt_emb_2, prompt_emb_2 = self.encode_prompt_using_clip(pure_prompt, text_encoder_2, self.tokenizer_2, 77, device)
|
||||||
|
|
||||||
# T5
|
# T5
|
||||||
if text_encoder_3 is None:
|
if text_encoder_3 is None:
|
||||||
prompt_emb_3 = torch.zeros((1, 256, 4096), dtype=prompt_emb_1.dtype, device=device)
|
prompt_emb_3 = torch.zeros((1, 256, 4096), dtype=prompt_emb_1.dtype, device=device)
|
||||||
else:
|
else:
|
||||||
prompt_emb_3 = self.encode_prompt_using_t5(prompt, text_encoder_3, self.tokenizer_3, 256, device)
|
prompt_emb_3 = self.encode_prompt_using_t5(pure_prompt, text_encoder_3, self.tokenizer_3, 256, device)
|
||||||
prompt_emb_3 = prompt_emb_3.to(prompt_emb_1.dtype) # float32 -> float16
|
prompt_emb_3 = prompt_emb_3.to(prompt_emb_1.dtype) # float32 -> float16
|
||||||
|
|
||||||
# Merge
|
# Merge
|
||||||
|
|||||||
@@ -111,14 +111,27 @@ but make sure there is a correlation between the input and output.\n\
|
|||||||
if "beautiful_prompt" in model_manager.model:
|
if "beautiful_prompt" in model_manager.model:
|
||||||
self.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
|
self.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
|
||||||
|
|
||||||
def process_prompt(self, prompt, positive=True):
|
def add_textual_inversion_tokens(self, prompt):
|
||||||
for keyword in self.keyword_dict:
|
for keyword in self.keyword_dict:
|
||||||
if keyword in prompt:
|
if keyword in prompt:
|
||||||
prompt = prompt.replace(keyword, self.keyword_dict[keyword])
|
prompt = prompt.replace(keyword, self.keyword_dict[keyword])
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
def del_textual_inversion_tokens(self, prompt):
|
||||||
|
for keyword in self.keyword_dict:
|
||||||
|
if keyword in prompt:
|
||||||
|
prompt = prompt.replace(keyword, "")
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
def process_prompt(self, prompt, positive=True, require_pure_prompt=False):
|
||||||
|
prompt, pure_prompt = self.add_textual_inversion_tokens(prompt), self.del_textual_inversion_tokens(prompt)
|
||||||
if positive and self.translator is not None:
|
if positive and self.translator is not None:
|
||||||
prompt = self.translator(prompt)
|
prompt = self.translator(prompt)
|
||||||
print(f"Your prompt is translated: \"{prompt}\"")
|
print(f"Your prompt is translated: \"{prompt}\"")
|
||||||
if positive and self.beautiful_prompt is not None:
|
if positive and self.beautiful_prompt is not None:
|
||||||
prompt = self.beautiful_prompt(prompt)
|
prompt = self.beautiful_prompt(prompt)
|
||||||
print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
|
print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
|
||||||
return prompt
|
if require_pure_prompt:
|
||||||
|
return prompt, pure_prompt
|
||||||
|
else:
|
||||||
|
return prompt
|
||||||
|
|||||||
@@ -0,0 +1,32 @@
|
|||||||
|
from diffsynth import ModelManager, SD3ImagePipeline, download_models, load_state_dict
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
# Download models (automatically)
|
||||||
|
# `models/stable_diffusion_3/sd3_medium_incl_clips.safetensors`: [link](https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/sd3_medium_incl_clips.safetensors)
|
||||||
|
# `models/textual_inversion/verybadimagenegative_v1.3.pt`: [link](https://civitai.com/api/download/models/25820?type=Model&format=PickleTensor&size=full&fp=fp16)
|
||||||
|
download_models(["StableDiffusion3_without_T5", "TextualInversion_VeryBadImageNegative_v1.3"])
|
||||||
|
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
|
||||||
|
model_manager.load_textual_inversions("models/textual_inversion")
|
||||||
|
model_manager.load_models(["models/stable_diffusion_3/sd3_medium_incl_clips.safetensors"])
|
||||||
|
pipe = SD3ImagePipeline.from_model_manager(model_manager)
|
||||||
|
|
||||||
|
|
||||||
|
for seed in range(4):
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
image = pipe(
|
||||||
|
prompt="a girl, highly detailed, absurd res, perfect image",
|
||||||
|
negative_prompt="verybadimagenegative_v1.3",
|
||||||
|
cfg_scale=4.5,
|
||||||
|
num_inference_steps=50, width=1024, height=1024,
|
||||||
|
)
|
||||||
|
image.save(f"image_with_textual_inversion_{seed}.jpg")
|
||||||
|
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
image = pipe(
|
||||||
|
prompt="a girl, highly detailed, absurd res, perfect image",
|
||||||
|
negative_prompt="",
|
||||||
|
cfg_scale=4.5,
|
||||||
|
num_inference_steps=50, width=1024, height=1024,
|
||||||
|
)
|
||||||
|
image.save(f"image_without_textual_inversion_{seed}.jpg")
|
||||||
Reference in New Issue
Block a user