mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
119 lines
4.3 KiB
Python
119 lines
4.3 KiB
Python
import torch
|
|
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
|
from .sd_text_encoder import SDTextEncoder
|
|
|
|
|
|
class FLUXTextEncoder1(SDTextEncoder):
    """First FLUX text encoder, built on top of ``SDTextEncoder``.

    The embedding table, position embeddings, attention mask, encoder stack
    and final layer norm used below are not defined in this class; they are
    expected to be created by ``SDTextEncoder.__init__``.
    """

    def __init__(self, vocab_size=49408):
        """Build the encoder, forwarding ``vocab_size`` to ``SDTextEncoder``."""
        super().__init__(vocab_size=vocab_size)

    def forward(self, input_ids, clip_skip=2):
        """Encode token ids into per-token and pooled embeddings.

        Args:
            input_ids: Integer token ids. Assumed to be shaped
                ``(batch, sequence)`` -- TODO confirm against the caller.
            clip_skip: Accepted for interface compatibility. It currently has
                no effect on the returned tensors: both outputs are taken
                after the last encoder layer and the final layer norm.

        Returns:
            A tuple ``(embeds, pooled_embeds)`` where ``embeds`` is the
            layer-normalised output of the full encoder stack and
            ``pooled_embeds`` holds, for every batch row, the embedding at the
            position of the largest token id in that row.
        """
        embeds = self.token_embedding(input_ids) + self.position_embeds
        # Move the mask to wherever the embeddings live so the layers can use it.
        attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
        for encoder in self.encoders:
            embeds = encoder(embeds, attn_mask=attn_mask)
        embeds = self.final_layer_norm(embeds)

        # One position per batch row: the index of the maximum token id
        # (presumably the end-of-text token -- verify against the tokenizer).
        batch_index = torch.arange(embeds.shape[0])
        pooled_index = input_ids.to(dtype=torch.int).argmax(dim=-1)
        pooled_embeds = embeds[batch_index, pooled_index]
        return embeds, pooled_embeds

    @staticmethod
    def state_dict_converter():
        """Return the converter that maps external checkpoints to this model."""
        return FLUXTextEncoder1StateDictConverter()
|
|
|
|
class FLUXTextEncoder2(T5EncoderModel):
    """Second FLUX text encoder: a ``T5EncoderModel`` with a fixed config.

    The configuration is hard-coded (24 layers, 64 heads, ``d_model`` 4096,
    vocabulary of 32128 tokens) so the model can be instantiated without a
    config file and then filled from a checkpoint.
    """

    def __init__(self):
        """Instantiate the T5 encoder with the built-in configuration."""
        # T5Config is imported here so the constructor does not depend on it
        # being imported at module level; without it, this call raises
        # NameError.
        from transformers import T5Config

        config = T5Config(
            _name_or_path=".",
            architectures=["T5EncoderModel"],
            classifier_dropout=0.0,
            d_ff=10240,
            d_kv=64,
            d_model=4096,
            decoder_start_token_id=0,
            dense_act_fn="gelu_new",
            dropout_rate=0.1,
            eos_token_id=1,
            feed_forward_proj="gated-gelu",
            initializer_factor=1.0,
            is_encoder_decoder=True,
            is_gated_act=True,
            layer_norm_epsilon=1e-06,
            model_type="t5",
            num_decoder_layers=24,
            num_heads=64,
            num_layers=24,
            output_past=True,
            pad_token_id=0,
            relative_attention_max_distance=128,
            relative_attention_num_buckets=32,
            tie_word_embeddings=False,
            torch_dtype="bfloat16",
            transformers_version="4.43.3",
            use_cache=True,
            vocab_size=32128,
        )
        super().__init__(config)
        # The encoder is only used for inference, so switch to eval mode.
        self.eval()

    def forward(self, input_ids):
        """Encode token ids and return the encoder's last hidden state.

        Args:
            input_ids: Token ids passed to ``T5EncoderModel.forward`` as
                ``input_ids``.

        Returns:
            ``last_hidden_state`` of the underlying T5 encoder output.
        """
        outputs = super().forward(input_ids=input_ids)
        prompt_emb = outputs.last_hidden_state
        return prompt_emb

    @staticmethod
    def state_dict_converter():
        """Return the converter that maps external checkpoints to this model."""
        return FLUXTextEncoder2StateDictConverter()
|
|
|
|
|
|
class FLUXTextEncoder1StateDictConverter:
    """Rename checkpoint keys of a ``text_model.*`` state dict for FLUXTextEncoder1."""

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Translate a diffusers-style text encoder state dict.

        Keys that are neither one of the known top-level parameters nor under
        ``text_model.encoder.layers.`` are dropped. The position embedding
        weight gains a leading axis of size 1. An encoder-layer key whose
        sub-module is not in the rename table raises ``KeyError``.

        Args:
            state_dict: Mapping from parameter name to parameter.

        Returns:
            A new dict with renamed keys, in the order the source keys were
            visited.
        """
        position_key = "text_model.embeddings.position_embedding.weight"
        layer_prefix = "text_model.encoder.layers."
        top_level_names = {
            "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
            position_key: "position_embeds",
            "text_model.final_layer_norm.weight": "final_layer_norm.weight",
            "text_model.final_layer_norm.bias": "final_layer_norm.bias",
        }
        sublayer_names = {
            "self_attn.q_proj": "attn.to_q",
            "self_attn.k_proj": "attn.to_k",
            "self_attn.v_proj": "attn.to_v",
            "self_attn.out_proj": "attn.to_out",
            "layer_norm1": "layer_norm1",
            "layer_norm2": "layer_norm2",
            "mlp.fc1": "fc1",
            "mlp.fc2": "fc2",
        }

        converted = {}
        for key in state_dict:
            if key in top_level_names:
                tensor = state_dict[key]
                if key == position_key:
                    # (positions, dim) -> (1, positions, dim)
                    rows, cols = tensor.shape[0], tensor.shape[1]
                    tensor = tensor.reshape((1, rows, cols))
                converted[top_level_names[key]] = tensor
                continue

            if not key.startswith(layer_prefix):
                continue

            # text_model.encoder.layers.<id>.<sub.module.path>.<weight|bias>
            parts = key.split(".")
            layer_id = parts[3]
            sublayer = ".".join(parts[4:-1])
            suffix = parts[-1]
            new_key = ".".join(["encoders", layer_id, sublayer_names[sublayer], suffix])
            converted[new_key] = state_dict[key]
        return converted

    def from_civitai(self, state_dict):
        """Convert a civitai-style state dict; same mapping as ``from_diffusers``."""
        return self.from_diffusers(state_dict)
|
|
|
|
class FLUXTextEncoder2StateDictConverter():
    """Pass-through converter: FLUXTextEncoder2 checkpoints need no key renaming."""

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Return ``state_dict`` unchanged (the same object, not a copy)."""
        return state_dict

    def from_civitai(self, state_dict):
        """Convert a civitai-style state dict; same as ``from_diffusers``."""
        return self.from_diffusers(state_dict)
|