# Mirror of https://github.com/modelscope/DiffSynth-Studio.git
# Synced 2026-03-19 23:08:13 +00:00
# 32 lines, 1.4 KiB, Python
def FluxTextEncoderClipStateDictConverter(state_dict):
    """Remap a CLIP text-encoder state dict onto this project's parameter names.

    The input key layout presumably comes from a Hugging Face ``transformers``
    CLIP text model checkpoint (keys prefixed with ``text_model.``); verify
    against the checkpoints actually loaded.

    Two families of keys are converted:

    * A fixed set of top-level keys (token embedding, position embedding,
      final layer norm) is renamed one-to-one. The position embedding is
      additionally reshaped to ``(1, shape[0], shape[1])``; this assumes a
      2-D tensor, presumably ``(num_positions, dim)``. TODO: confirm.
    * Per-layer keys of the form
      ``text_model.encoder.layers.<idx>.<module path>.<param>`` become
      ``encoders.<idx>.<renamed module path>.<param>``.

    Any key matching neither family (for example ``position_ids`` buffers)
    is left out of the result. Apart from the reshaped position embedding,
    values are stored as the same objects found in ``state_dict``.

    Args:
        state_dict: Mapping from source parameter names to tensors.

    Returns:
        A new dict keyed by the converted parameter names.

    Raises:
        KeyError: If a per-layer key's module path is not one of the known
            attention / layer-norm / MLP submodules.
    """
    position_key = "text_model.embeddings.position_embedding.weight"
    direct_renames = {
        "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
        position_key: "position_embeds",
        "text_model.final_layer_norm.weight": "final_layer_norm.weight",
        "text_model.final_layer_norm.bias": "final_layer_norm.bias",
    }
    # Module path inside one encoder layer -> module path in the target model.
    layer_module_renames = {
        "self_attn.q_proj": "attn.to_q",
        "self_attn.k_proj": "attn.to_k",
        "self_attn.v_proj": "attn.to_v",
        "self_attn.out_proj": "attn.to_out",
        "layer_norm1": "layer_norm1",
        "layer_norm2": "layer_norm2",
        "mlp.fc1": "fc1",
        "mlp.fc2": "fc2",
    }
    layer_prefix = "text_model.encoder.layers."

    converted = {}
    for src_key in state_dict:
        target_key = direct_renames.get(src_key)
        if target_key is not None:
            value = state_dict[src_key]
            if src_key == position_key:
                # Prepend a singleton leading axis to the position table.
                rows, cols = value.shape[0], value.shape[1]
                value = value.reshape((1, rows, cols))
            converted[target_key] = value
            continue

        if not src_key.startswith(layer_prefix):
            # Neither a renamed top-level key nor an encoder-layer key: drop it.
            continue

        value = state_dict[src_key]
        # Split parts: [text_model, encoder, layers, <idx>, <module...>, <param>]
        parts = src_key.split(".")
        layer_index = parts[3]
        module_path = ".".join(parts[4:-1])
        param_name = parts[-1]
        new_key = ".".join(
            ("encoders", layer_index, layer_module_renames[module_path], param_name)
        )
        converted[new_key] = value
    return converted