support video-to-video-translation

This commit is contained in:
Artiprocher
2023-12-21 17:11:58 +08:00
parent f7f4c1038e
commit c1453281df
20 changed files with 1659 additions and 427 deletions

View File

@@ -1,7 +1,6 @@
from transformers import CLIPTokenizer
from ..models import SDTextEncoder, SDXLTextEncoder, SDXLTextEncoder2
from ..models import SDTextEncoder, SDXLTextEncoder, SDXLTextEncoder2, load_state_dict
import torch, os
from safetensors import safe_open
def tokenize_long_prompt(tokenizer, prompt):
@@ -36,49 +35,40 @@ def tokenize_long_prompt(tokenizer, prompt):
return input_ids
def load_textual_inversion(prompt):
# TODO: This module is not enabled now.
textual_inversion_files = os.listdir("models/textual_inversion")
embeddings_768 = []
embeddings_1280 = []
for file_name in textual_inversion_files:
if not file_name.endswith(".safetensors"):
continue
keyword = file_name[:-len(".safetensors")]
if keyword in prompt:
prompt = prompt.replace(keyword, "")
with safe_open(f"models/textual_inversion/{file_name}", framework="pt", device="cpu") as f:
for k in f.keys():
embedding = f.get_tensor(k).to(torch.float32)
if embedding.shape[-1] == 768:
embeddings_768.append(embedding)
elif embedding.shape[-1] == 1280:
embeddings_1280.append(embedding)
if len(embeddings_768)==0:
embeddings_768 = torch.zeros((0, 768))
else:
embeddings_768 = torch.concat(embeddings_768, dim=0)
if len(embeddings_1280)==0:
embeddings_1280 = torch.zeros((0, 1280))
else:
embeddings_1280 = torch.concat(embeddings_1280, dim=0)
return prompt, embeddings_768, embeddings_1280
def search_for_embeddings(state_dict):
embeddings = []
for k in state_dict:
if isinstance(state_dict[k], torch.Tensor):
embeddings.append(state_dict[k])
elif isinstance(state_dict[k], dict):
embeddings += search_for_embeddings(state_dict[k])
return embeddings
class SDPrompter:
def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"):
# We use the tokenizer implemented by transformers
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
self.keyword_dict = {}
def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda"):
for keyword in self.keyword_dict:
if keyword in prompt:
prompt = prompt.replace(keyword, self.keyword_dict[keyword])
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
prompt_emb = text_encoder(input_ids, clip_skip=clip_skip)
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
return prompt_emb
def load_textual_inversion(self, textual_inversion_dict):
self.keyword_dict = {}
additional_tokens = []
for keyword in textual_inversion_dict:
tokens, _ = textual_inversion_dict[keyword]
additional_tokens += tokens
self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
self.tokenizer.add_tokens(additional_tokens)
class SDXLPrompter: