mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
44 lines
1.4 KiB
Python
44 lines
1.4 KiB
Python
import torch
|
|
from transformers import T5EncoderModel, T5Config
|
|
|
|
|
|
class FluxTextEncoderT5(T5EncoderModel):
|
|
def __init__(self):
|
|
config = T5Config(**{
|
|
"architectures": [
|
|
"T5EncoderModel"
|
|
],
|
|
"classifier_dropout": 0.0,
|
|
"d_ff": 10240,
|
|
"d_kv": 64,
|
|
"d_model": 4096,
|
|
"decoder_start_token_id": 0,
|
|
"dense_act_fn": "gelu_new",
|
|
"dropout_rate": 0.1,
|
|
"dtype": "bfloat16",
|
|
"eos_token_id": 1,
|
|
"feed_forward_proj": "gated-gelu",
|
|
"initializer_factor": 1.0,
|
|
"is_encoder_decoder": True,
|
|
"is_gated_act": True,
|
|
"layer_norm_epsilon": 1e-06,
|
|
"model_type": "t5",
|
|
"num_decoder_layers": 24,
|
|
"num_heads": 64,
|
|
"num_layers": 24,
|
|
"output_past": True,
|
|
"pad_token_id": 0,
|
|
"relative_attention_max_distance": 128,
|
|
"relative_attention_num_buckets": 32,
|
|
"tie_word_embeddings": False,
|
|
"transformers_version": "4.57.1",
|
|
"use_cache": True,
|
|
"vocab_size": 32128
|
|
})
|
|
super().__init__(config)
|
|
|
|
def forward(self, input_ids):
|
|
outputs = super().forward(input_ids=input_ids)
|
|
prompt_emb = outputs.last_hidden_state
|
|
return prompt_emb
|