mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:32:27 +00:00
110 lines
3.4 KiB
Python
110 lines
3.4 KiB
Python
from .base_prompter import BasePrompter
|
|
from ..models.wan_video_text_encoder import WanTextEncoder
|
|
from transformers import AutoTokenizer
|
|
import os, torch
|
|
import ftfy
|
|
import html
|
|
import string
|
|
import regex as re
|
|
|
|
|
|
def basic_clean(text):
    """Repair text with ``ftfy.fix_text``, unescape HTML entities twice, and strip the ends.

    Args:
        text: Raw input string.

    Returns:
        The repaired, unescaped string without leading/trailing whitespace.
    """
    repaired = ftfy.fix_text(text)
    # Unescape two times so doubly-escaped entities (e.g. "&amp;amp;") are fully decoded.
    for _ in range(2):
        repaired = html.unescape(repaired)
    return repaired.strip()
|
|
|
|
|
|
def whitespace_clean(text):
    """Collapse every run of whitespace into a single space and trim both ends.

    Args:
        text: Input string.

    Returns:
        The normalized string.
    """
    return re.sub(r'\s+', ' ', text).strip()
|
|
|
|
|
|
def canonicalize(text, keep_punctuation_exact_string=None):
    """Normalize text: underscores to spaces, punctuation removed, lowercased, whitespace collapsed.

    Args:
        text: Input string.
        keep_punctuation_exact_string: Optional substring that is preserved verbatim;
            punctuation is stripped only from the pieces between its occurrences.

    Returns:
        The canonicalized string, trimmed at both ends.
    """
    # One deletion table for all ASCII punctuation, shared by every piece.
    strip_punct = str.maketrans('', '', string.punctuation)
    text = text.replace('_', ' ')

    if not keep_punctuation_exact_string:
        text = text.translate(strip_punct)
    else:
        pieces = text.split(keep_punctuation_exact_string)
        text = keep_punctuation_exact_string.join(piece.translate(strip_punct) for piece in pieces)

    collapsed = re.sub(r'\s+', ' ', text.lower())
    return collapsed.strip()
|
|
|
|
|
|
class HuggingfaceTokenizer:
    """Wrapper around ``AutoTokenizer`` with optional text cleaning and fixed-length padding."""

    def __init__(self, name, seq_len=None, clean=None, **kwargs):
        """Load the tokenizer.

        Args:
            name: Model name or path forwarded to ``AutoTokenizer.from_pretrained``.
            seq_len: When not ``None``, every call pads to and truncates at this
                many tokens.
            clean: Cleaning mode applied before tokenization: ``None``,
                ``'whitespace'``, ``'lower'`` or ``'canonicalize'``.
            **kwargs: Extra keyword arguments for ``AutoTokenizer.from_pretrained``.
        """
        assert clean in (None, 'whitespace', 'lower', 'canonicalize')
        self.name, self.seq_len, self.clean = name, seq_len, clean

        self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
        self.vocab_size = self.tokenizer.vocab_size

    def __call__(self, sequence, **kwargs):
        """Tokenize one string or a batch of strings.

        Args:
            sequence: A single string (wrapped into a one-element batch) or a
                sequence of strings.
            **kwargs: Forwarded to the underlying tokenizer and overriding the
                defaults built here. The special key ``return_mask`` (default
                ``False``) is consumed locally and not forwarded.

        Returns:
            ``input_ids``, or ``(input_ids, attention_mask)`` when ``return_mask``
            is truthy.
        """
        return_mask = kwargs.pop('return_mask', False)

        # Defaults first, so caller-supplied kwargs take precedence.
        call_kwargs = {'return_tensors': 'pt'}
        if self.seq_len is not None:
            call_kwargs['padding'] = 'max_length'
            call_kwargs['truncation'] = True
            call_kwargs['max_length'] = self.seq_len
        call_kwargs.update(kwargs)

        batch = [sequence] if isinstance(sequence, str) else sequence
        if self.clean:
            batch = [self._clean(item) for item in batch]
        encoded = self.tokenizer(batch, **call_kwargs)

        if not return_mask:
            return encoded.input_ids
        return encoded.input_ids, encoded.attention_mask

    def _clean(self, text):
        """Apply the configured cleaning mode to ``text``; unknown modes return it unchanged."""
        mode = self.clean
        if mode not in ('whitespace', 'lower', 'canonicalize'):
            return text
        text = basic_clean(text)
        if mode == 'canonicalize':
            return canonicalize(text)
        text = whitespace_clean(text)
        return text.lower() if mode == 'lower' else text
|
|
|
|
|
|
class WanPrompter(BasePrompter):
    """Prompter that tokenizes prompts and encodes them with a Wan text encoder.

    ``encode_prompt`` uses ``self.tokenizer`` (set by ``fetch_tokenizer``) and
    ``self.text_encoder`` (set by ``fetch_models``); both must be attached first.
    """

    def __init__(self, tokenizer_path=None, text_len=512):
        """Create the prompter.

        Args:
            tokenizer_path: Name or path passed to ``HuggingfaceTokenizer``; when
                ``None``, no tokenizer is loaded here.
            text_len: Token length every prompt is padded/truncated to.
        """
        super().__init__()
        self.text_len = text_len
        self.text_encoder = None
        self.fetch_tokenizer(tokenizer_path)

    def fetch_tokenizer(self, tokenizer_path=None):
        """Load a whitespace-cleaning tokenizer fixed to ``self.text_len`` tokens.

        Does nothing when ``tokenizer_path`` is ``None``.
        """
        if tokenizer_path is not None:
            self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=self.text_len, clean='whitespace')

    def fetch_models(self, text_encoder: WanTextEncoder = None):
        """Attach the text encoder used by ``encode_prompt``."""
        self.text_encoder = text_encoder

    def encode_prompt(self, prompt, positive=True, device="cuda"):
        """Tokenize and encode ``prompt``, zeroing embeddings past each prompt's real length.

        Args:
            prompt: Prompt text, first passed through ``self.process_prompt``
                (defined in ``BasePrompter``); presumably a string or a list of
                strings — confirm against ``BasePrompter``.
            positive: Forwarded to ``process_prompt``.
            device: Device the token ids and attention mask are moved to.

        Returns:
            The text encoder output with padding positions set to zero.
        """
        prompt = self.process_prompt(prompt, positive=positive)

        ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True)
        ids = ids.to(device)
        mask = mask.to(device)
        # Count of non-padding tokens in each batch row.
        seq_lens = mask.gt(0).sum(dim=1).long()
        prompt_emb = self.text_encoder(ids, mask)
        # Zero each row from its own length onward. Slicing `[:, v:]` here would
        # apply every row's length to the whole batch, truncating all rows to the
        # shortest prompt. .tolist() yields plain ints for the slice bounds.
        # assumes prompt_emb is (batch, seq, ...) aligned with ids — TODO confirm
        for i, v in enumerate(seq_lens.tolist()):
            prompt_emb[i, v:] = 0
        return prompt_emb
|