add downloader

This commit is contained in:
Artiprocher
2024-06-24 16:45:35 +08:00
parent 00f294454b
commit e9ec2f2706
26 changed files with 430 additions and 42 deletions

View File

@@ -1,14 +1,20 @@
from .utils import Prompter
from transformers import BertModel, T5EncoderModel, BertTokenizer, AutoTokenizer
import warnings
import warnings, os
class HunyuanDiTPrompter(Prompter):
def __init__(
self,
tokenizer_path="configs/hunyuan_dit/tokenizer",
tokenizer_t5_path="configs/hunyuan_dit/tokenizer_t5"
tokenizer_path=None,
tokenizer_t5_path=None
):
if tokenizer_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_path = os.path.join(base_path, "tokenizer_configs/hunyuan_dit/tokenizer")
if tokenizer_t5_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_t5_path = os.path.join(base_path, "tokenizer_configs/hunyuan_dit/tokenizer_t5")
super().__init__()
self.tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
with warnings.catch_warnings():