mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-21 16:18:13 +00:00
support FLUX
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
||||
from transformers import T5EncoderModel, T5Config
|
||||
from .sd_text_encoder import SDTextEncoder
|
||||
|
||||
|
||||
class FLUXTextEncoder1(SDTextEncoder):
|
||||
class FluxTextEncoder1(SDTextEncoder):
|
||||
def __init__(self, vocab_size=49408):
|
||||
super().__init__(vocab_size=vocab_size)
|
||||
|
||||
@@ -20,40 +20,12 @@ class FLUXTextEncoder1(SDTextEncoder):
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FLUXTextEncoder1StateDictConverter()
|
||||
return FluxTextEncoder1StateDictConverter()
|
||||
|
||||
class FLUXTextEncoder2(T5EncoderModel):
|
||||
def __init__(self):
|
||||
config = T5Config(
|
||||
_name_or_path = ".",
|
||||
architectures = ["T5EncoderModel"],
|
||||
classifier_dropout = 0.0,
|
||||
d_ff = 10240,
|
||||
d_kv = 64,
|
||||
d_model = 4096,
|
||||
decoder_start_token_id = 0,
|
||||
dense_act_fn = "gelu_new",
|
||||
dropout_rate = 0.1,
|
||||
eos_token_id = 1,
|
||||
feed_forward_proj = "gated-gelu",
|
||||
initializer_factor = 1.0,
|
||||
is_encoder_decoder = True,
|
||||
is_gated_act = True,
|
||||
layer_norm_epsilon = 1e-06,
|
||||
model_type = "t5",
|
||||
num_decoder_layers = 24,
|
||||
num_heads = 64,
|
||||
num_layers = 24,
|
||||
output_past = True,
|
||||
pad_token_id = 0,
|
||||
relative_attention_max_distance = 128,
|
||||
relative_attention_num_buckets = 32,
|
||||
tie_word_embeddings = False,
|
||||
torch_dtype = "bfloat16",
|
||||
transformers_version = "4.43.3",
|
||||
use_cache = True,
|
||||
vocab_size = 32128
|
||||
)
|
||||
|
||||
|
||||
class FluxTextEncoder2(T5EncoderModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.eval()
|
||||
|
||||
@@ -64,10 +36,11 @@ class FLUXTextEncoder2(T5EncoderModel):
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FLUXTextEncoder2StateDictConverter()
|
||||
return FluxTextEncoder2StateDictConverter()
|
||||
|
||||
|
||||
class FLUXTextEncoder1StateDictConverter:
|
||||
|
||||
class FluxTextEncoder1StateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@@ -106,7 +79,9 @@ class FLUXTextEncoder1StateDictConverter:
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
|
||||
class FLUXTextEncoder2StateDictConverter():
|
||||
|
||||
|
||||
class FluxTextEncoder2StateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
Reference in New Issue
Block a user