mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 15:48:20 +00:00
support video-to-video-translation
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
from transformers import CLIPTokenizer
|
||||
from ..models import SDTextEncoder, SDXLTextEncoder, SDXLTextEncoder2
|
||||
from ..models import SDTextEncoder, SDXLTextEncoder, SDXLTextEncoder2, load_state_dict
|
||||
import torch, os
|
||||
from safetensors import safe_open
|
||||
|
||||
|
||||
def tokenize_long_prompt(tokenizer, prompt):
|
||||
@@ -36,49 +35,40 @@ def tokenize_long_prompt(tokenizer, prompt):
|
||||
return input_ids
|
||||
|
||||
|
||||
def load_textual_inversion(prompt, textual_inversion_dir="models/textual_inversion"):
    """Strip textual-inversion keywords from a prompt and load their embeddings.

    Every ``<keyword>.safetensors`` file in ``textual_inversion_dir`` whose
    keyword occurs in ``prompt`` is opened; the keyword is removed from the
    prompt and each tensor stored in the file is collected according to the
    size of its last dimension (768 or 1280). Tensors with any other last
    dimension are ignored.

    Args:
        prompt: The text prompt to scan for keywords.
        textual_inversion_dir: Directory holding the ``.safetensors`` files.
            A missing directory is treated as "no textual inversions".

    Returns:
        A tuple ``(prompt, embeddings_768, embeddings_1280)`` where ``prompt``
        has the matched keywords removed and each embeddings entry is the
        float32 tensors of that width concatenated along dim 0, or an empty
        ``(0, width)`` tensor when none were found.
    """
    # TODO: This module is not enabled now.
    suffix = ".safetensors"
    embeddings_768 = []
    embeddings_1280 = []

    if os.path.isdir(textual_inversion_dir):
        # os.listdir returns entries in arbitrary order; sort so the order of
        # the concatenated embeddings is reproducible across runs/platforms.
        file_names = sorted(os.listdir(textual_inversion_dir))
    else:
        file_names = []

    for file_name in file_names:
        if not file_name.endswith(suffix):
            continue
        keyword = file_name[:-len(suffix)]
        # An empty keyword (a file named exactly ".safetensors") would match
        # every prompt, so it is skipped.
        if not keyword or keyword not in prompt:
            continue
        prompt = prompt.replace(keyword, "")
        file_path = os.path.join(textual_inversion_dir, file_name)
        with safe_open(file_path, framework="pt", device="cpu") as f:
            for k in f.keys():
                embedding = f.get_tensor(k).to(torch.float32)
                if embedding.shape[-1] == 768:
                    embeddings_768.append(embedding)
                elif embedding.shape[-1] == 1280:
                    embeddings_1280.append(embedding)

    if embeddings_768:
        embeddings_768 = torch.concat(embeddings_768, dim=0)
    else:
        embeddings_768 = torch.zeros((0, 768))

    if embeddings_1280:
        embeddings_1280 = torch.concat(embeddings_1280, dim=0)
    else:
        embeddings_1280 = torch.zeros((0, 1280))

    return prompt, embeddings_768, embeddings_1280
|
||||
def search_for_embeddings(state_dict):
    """Collect every tensor found in a (possibly nested) state dict.

    Values are visited in the dict's iteration order. A ``torch.Tensor``
    value is collected as-is; a ``dict`` value is searched recursively and
    its tensors are spliced in at that position. Values of any other type
    are ignored.

    Args:
        state_dict: A dict whose values may be tensors, nested dicts, or
            anything else.

    Returns:
        A flat list of the tensors found (the same objects, not copies).
    """
    embeddings = []
    for value in state_dict.values():
        if isinstance(value, torch.Tensor):
            embeddings.append(value)
        elif isinstance(value, dict):
            embeddings.extend(search_for_embeddings(value))
    return embeddings
|
||||
|
||||
|
||||
class SDPrompter:
    """Turns a text prompt into a prompt embedding with a CLIP tokenizer.

    Textual-inversion keywords registered through ``load_textual_inversion``
    are expanded into their token strings before the prompt is tokenized.
    """

    def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"):
        # We use the tokenizer implemented by transformers
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
        # Maps a textual-inversion keyword to the text that replaces it.
        self.keyword_dict = {}

    def _replace_keywords(self, prompt):
        """Return ``prompt`` with every registered keyword expanded.

        All keywords are substituted in a single left-to-right pass, trying
        longer keywords first. Replacing them one after another would let a
        later keyword match inside text inserted for an earlier one (or let a
        short keyword clobber a longer keyword that contains it).
        """
        import re

        # Empty keywords would match at every position, so they are skipped.
        keywords = sorted((k for k in self.keyword_dict if k), key=len, reverse=True)
        if not keywords:
            return prompt
        pattern = "|".join(re.escape(keyword) for keyword in keywords)
        return re.sub(pattern, lambda match: self.keyword_dict[match.group(0)], prompt)

    def encode_prompt(self, text_encoder: "SDTextEncoder", prompt, clip_skip=1, device="cuda"):
        """Encode ``prompt`` into a single-batch prompt embedding.

        Args:
            text_encoder: Callable taking the token ids and ``clip_skip``.
            prompt: The text prompt; registered keywords are expanded first.
            clip_skip: Forwarded unchanged to ``text_encoder``.
            device: Device the token ids are moved to before encoding.

        Returns:
            The encoder output with its first two dimensions merged, i.e.
            reshaped to ``(1, shape[0] * shape[1], -1)``.
        """
        prompt = self._replace_keywords(prompt)
        input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
        prompt_emb = text_encoder(input_ids, clip_skip=clip_skip)
        # Merge the leading two dimensions into one sequence axis.
        prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))

        return prompt_emb

    def load_textual_inversion(self, textual_inversion_dict):
        """Register textual-inversion keywords and add their tokens.

        Args:
            textual_inversion_dict: Maps each keyword to a 2-tuple whose first
                element is the list of token strings for that keyword; the
                second element is not used here.

        Any previously registered keywords are discarded. Each keyword is
        mapped to its tokens joined by spaces, padded with a space on both
        sides, and all tokens are added to the tokenizer.
        """
        self.keyword_dict = {}
        additional_tokens = []
        for keyword, (tokens, _) in textual_inversion_dict.items():
            additional_tokens.extend(tokens)
            self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
        self.tokenizer.add_tokens(additional_tokens)
|
||||
|
||||
|
||||
class SDXLPrompter:
|
||||
|
||||
Reference in New Issue
Block a user