ExVideo for AnimateDiff

This commit is contained in:
Artiprocher
2024-07-26 14:35:18 +08:00
parent f094cae7e9
commit a076adf592
7 changed files with 520 additions and 48 deletions

View File

@@ -8,9 +8,9 @@ class SDPrompter(Prompter):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True):
def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True, max_length=99999999):
prompt = self.process_prompt(prompt, positive=positive)
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
input_ids = tokenize_long_prompt(self.tokenizer, prompt, max_length=max_length).to(device)
prompt_emb = text_encoder(input_ids, clip_skip=clip_skip)
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))

View File

@@ -3,12 +3,12 @@ from ..models import ModelManager
import os
def tokenize_long_prompt(tokenizer, prompt):
def tokenize_long_prompt(tokenizer, prompt, max_length=99999999):
# Get model_max_length from self.tokenizer
length = tokenizer.model_max_length
# To avoid the warning. set self.tokenizer.model_max_length to +oo.
tokenizer.model_max_length = 99999999
tokenizer.model_max_length = max_length
# Tokenize it!
input_ids = tokenizer(prompt, return_tensors="pt").input_ids