add downloader

This commit is contained in:
Artiprocher
2024-06-24 16:45:35 +08:00
parent 00f294454b
commit e9ec2f2706
26 changed files with 430 additions and 42 deletions

View File

@@ -1,14 +1,20 @@
from .utils import Prompter
from transformers import BertModel, T5EncoderModel, BertTokenizer, AutoTokenizer
import warnings
import warnings, os
class HunyuanDiTPrompter(Prompter):
def __init__(
self,
tokenizer_path="configs/hunyuan_dit/tokenizer",
tokenizer_t5_path="configs/hunyuan_dit/tokenizer_t5"
tokenizer_path=None,
tokenizer_t5_path=None
):
if tokenizer_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_path = os.path.join(base_path, "tokenizer_configs/hunyuan_dit/tokenizer")
if tokenizer_t5_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_t5_path = os.path.join(base_path, "tokenizer_configs/hunyuan_dit/tokenizer_t5")
super().__init__()
self.tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
with warnings.catch_warnings():

View File

@@ -1,10 +1,14 @@
from .utils import Prompter, tokenize_long_prompt
from transformers import CLIPTokenizer
from ..models import SDTextEncoder
import os
class SDPrompter(Prompter):
def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"):
def __init__(self, tokenizer_path=None):
if tokenizer_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion/tokenizer")
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)

View File

@@ -1,15 +1,21 @@
from .utils import Prompter, tokenize_long_prompt
from transformers import CLIPTokenizer
from ..models import SDXLTextEncoder, SDXLTextEncoder2
import torch
import torch, os
class SDXLPrompter(Prompter):
def __init__(
self,
tokenizer_path="configs/stable_diffusion/tokenizer",
tokenizer_2_path="configs/stable_diffusion_xl/tokenizer_2"
tokenizer_path=None,
tokenizer_2_path=None
):
if tokenizer_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion/tokenizer")
if tokenizer_2_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_xl/tokenizer_2")
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)

View File

@@ -36,7 +36,7 @@ def tokenize_long_prompt(tokenizer, prompt):
class BeautifulPrompt:
def __init__(self, tokenizer_path="configs/beautiful_prompt/tokenizer", model=None):
def __init__(self, tokenizer_path=None, model=None):
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
self.model = model
self.template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:'
@@ -62,7 +62,7 @@ class BeautifulPrompt:
class Translator:
def __init__(self, tokenizer_path="configs/translator/tokenizer", model=None):
def __init__(self, tokenizer_path=None, model=None):
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
self.model = model