mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 10:18:12 +00:00
flux compatibility patch
This commit is contained in:
@@ -317,6 +317,13 @@ flux_series = [
|
|||||||
"model_class": "diffsynth.models.flux_dit.FluxDiT",
|
"model_class": "diffsynth.models.flux_dit.FluxDiT",
|
||||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
# Supported due to historical reasons.
|
||||||
|
"model_hash": "605c56eab23e9e2af863ad8f0813a25d",
|
||||||
|
"model_name": "flux_dit",
|
||||||
|
"model_class": "diffsynth.models.flux_dit.FluxDiT",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverterFromDiffusers",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors")
|
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors")
|
||||||
"model_hash": "94eefa3dac9cec93cb1ebaf1747d7b78",
|
"model_hash": "94eefa3dac9cec93cb1ebaf1747d7b78",
|
||||||
|
|||||||
@@ -89,4 +89,107 @@ def FluxDiTStateDictConverter(state_dict):
|
|||||||
state_dict_[rename] = state_dict[original_name]
|
state_dict_[rename] = state_dict[original_name]
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
return state_dict_
|
||||||
|
|
||||||
|
|
||||||
|
def FluxDiTStateDictConverterFromDiffusers(state_dict):
    """Rename and fuse a Diffusers-style FLUX transformer state dict.

    Keys are processed in three passes:

    1. Every key ending in ``.weight`` or ``.bias`` is renamed using the
       lookup tables below (``transformer_blocks.N.*`` -> ``blocks.N.*``,
       ``single_transformer_blocks.N.*`` -> ``single_blocks.N.*``, plus a set
       of top-level modules). Keys with any other suffix, and keys that match
       no table entry, are silently dropped.
    2. In each single block, ``a_to_q``/``a_to_k``/``a_to_v`` and
       ``proj_in_besides_attn`` are concatenated along dim 0 into a single
       ``to_qkv_mlp`` tensor. When ``proj_in_besides_attn`` is absent, a zero
       block with four times as many rows as ``a_to_q`` is used instead.
    3. In the remaining blocks, ``{a,b}_to_q/k/v`` are concatenated along
       dim 0 into ``{a,b}_to_qkv``.

    Args:
        state_dict: Mapping from parameter name to tensor. Only iteration
            over keys and ``state_dict[name]`` lookups are used, so any
            mapping-like object supporting those two operations works.

    Returns:
        A new ``dict`` with the renamed and fused tensors. The input mapping
        is not modified.

    Raises:
        KeyError: If a ``*_to_q`` entry is present without the matching
            ``*_to_k`` / ``*_to_v`` entries.
    """
    global_rename_dict = {
        "context_embedder": "context_embedder",
        "x_embedder": "x_embedder",
        "time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
        "time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
        "time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
        "time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
        "time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
        "time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
        "norm_out.linear": "final_norm_out.linear",
        "proj_out": "final_proj_out",
    }
    rename_dict = {
        "proj_out": "proj_out",
        "norm1.linear": "norm1_a.linear",
        "norm1_context.linear": "norm1_b.linear",
        "attn.to_q": "attn.a_to_q",
        "attn.to_k": "attn.a_to_k",
        "attn.to_v": "attn.a_to_v",
        "attn.to_out.0": "attn.a_to_out",
        "attn.add_q_proj": "attn.b_to_q",
        "attn.add_k_proj": "attn.b_to_k",
        "attn.add_v_proj": "attn.b_to_v",
        "attn.to_add_out": "attn.b_to_out",
        "ff.net.0.proj": "ff_a.0",
        "ff.net.2": "ff_a.2",
        "ff_context.net.0.proj": "ff_b.0",
        "ff_context.net.2": "ff_b.2",
        "attn.norm_q": "attn.norm_q_a",
        "attn.norm_k": "attn.norm_k_a",
        "attn.norm_added_q": "attn.norm_q_b",
        "attn.norm_added_k": "attn.norm_k_b",
    }
    rename_dict_single = {
        "attn.to_q": "a_to_q",
        "attn.to_k": "a_to_k",
        "attn.to_v": "a_to_v",
        "attn.norm_q": "norm_q_a",
        "attn.norm_k": "norm_k_a",
        "norm.linear": "norm.linear",
        "proj_mlp": "proj_in_besides_attn",
        "proj_out": "proj_out",
    }

    # Pass 1: rename individual parameters.
    state_dict_ = {}
    # Iterate keys and index explicitly (rather than .items()) so that
    # mapping-like inputs exposing only __iter__/__getitem__ keep working.
    for name in state_dict:
        if name.endswith(".weight"):
            suffix = ".weight"
        elif name.endswith(".bias"):
            suffix = ".bias"
        else:
            # Anything that is not a weight or a bias is dropped.
            continue
        param = state_dict[name]
        prefix = name[:-len(suffix)]

        if prefix in global_rename_dict:
            state_dict_[global_rename_dict[prefix] + suffix] = param
            continue

        if prefix.startswith("transformer_blocks."):
            block_root, table = "blocks", rename_dict
        elif prefix.startswith("single_transformer_blocks."):
            block_root, table = "single_blocks", rename_dict_single
        else:
            continue

        # prefix looks like "<root>.<index>.<module path>".
        names = prefix.split(".")
        middle = ".".join(names[2:])
        if middle in table:
            name_ = ".".join([block_root, names[1], table[middle], suffix[1:]])
            state_dict_[name_] = param

    # Pass 2: fuse q/k/v and the MLP input projection of each single block.
    for name in list(state_dict_):
        if "single_blocks." not in name or ".a_to_q." not in name:
            continue
        q = state_dict_.pop(name)
        k = state_dict_.pop(name.replace(".a_to_q.", ".a_to_k."))
        v = state_dict_.pop(name.replace(".a_to_q.", ".a_to_v."))
        mlp = state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
        if mlp is None:
            # Pad with zeros on the same device as q; a default-device tensor
            # could not be concatenated with tensors living elsewhere.
            mlp = torch.zeros(
                4 * q.shape[0],
                *q.shape[1:],
                dtype=q.dtype,
                device=q.device,
            )
        state_dict_[name.replace(".a_to_q.", ".to_qkv_mlp.")] = torch.concat(
            [q, k, v, mlp], dim=0
        )

    # Pass 3: fuse q/k/v for both attention streams of the remaining blocks.
    # All single-block "a_to_q" keys were consumed above, so only the
    # "blocks.*" entries can match here.
    for name in list(state_dict_):
        for component in ("a", "b"):
            q_tag = f".{component}_to_q."
            if q_tag not in name:
                continue
            k_name = name.replace(q_tag, f".{component}_to_k.")
            v_name = name.replace(q_tag, f".{component}_to_v.")
            fused = torch.concat(
                [state_dict_[name], state_dict_[k_name], state_dict_[v_name]],
                dim=0,
            )
            state_dict_[name.replace(q_tag, f".{component}_to_qkv.")] = fused
            for consumed in (name, k_name, v_name):
                state_dict_.pop(consumed)
    return state_dict_
|
||||||
Reference in New Issue
Block a user