support FLUX

This commit is contained in:
Artiprocher
2024-08-16 20:04:10 +08:00
parent 1116e6dbc7
commit 99e11112a7
20 changed files with 230033 additions and 48 deletions

View File

@@ -1,9 +1,9 @@
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from transformers import T5EncoderModel, T5Config
from .sd_text_encoder import SDTextEncoder
class FLUXTextEncoder1(SDTextEncoder):
class FluxTextEncoder1(SDTextEncoder):
def __init__(self, vocab_size=49408):
super().__init__(vocab_size=vocab_size)
@@ -20,40 +20,12 @@ class FLUXTextEncoder1(SDTextEncoder):
@staticmethod
def state_dict_converter():
return FLUXTextEncoder1StateDictConverter()
return FluxTextEncoder1StateDictConverter()
class FLUXTextEncoder2(T5EncoderModel):
def __init__(self):
config = T5Config(
_name_or_path = ".",
architectures = ["T5EncoderModel"],
classifier_dropout = 0.0,
d_ff = 10240,
d_kv = 64,
d_model = 4096,
decoder_start_token_id = 0,
dense_act_fn = "gelu_new",
dropout_rate = 0.1,
eos_token_id = 1,
feed_forward_proj = "gated-gelu",
initializer_factor = 1.0,
is_encoder_decoder = True,
is_gated_act = True,
layer_norm_epsilon = 1e-06,
model_type = "t5",
num_decoder_layers = 24,
num_heads = 64,
num_layers = 24,
output_past = True,
pad_token_id = 0,
relative_attention_max_distance = 128,
relative_attention_num_buckets = 32,
tie_word_embeddings = False,
torch_dtype = "bfloat16",
transformers_version = "4.43.3",
use_cache = True,
vocab_size = 32128
)
class FluxTextEncoder2(T5EncoderModel):
def __init__(self, config):
super().__init__(config)
self.eval()
@@ -64,10 +36,11 @@ class FLUXTextEncoder2(T5EncoderModel):
@staticmethod
def state_dict_converter():
return FLUXTextEncoder2StateDictConverter()
return FluxTextEncoder2StateDictConverter()
class FLUXTextEncoder1StateDictConverter:
class FluxTextEncoder1StateDictConverter:
def __init__(self):
pass
@@ -106,7 +79,9 @@ class FLUXTextEncoder1StateDictConverter:
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)
class FLUXTextEncoder2StateDictConverter():
class FluxTextEncoder2StateDictConverter():
def __init__(self):
pass