diff --git a/diffsynth/utils/lora/flux.py b/diffsynth/utils/lora/flux.py
index 502c5fd..2e1d3fd 100644
--- a/diffsynth/utils/lora/flux.py
+++ b/diffsynth/utils/lora/flux.py
@@ -202,3 +202,99 @@ class FluxLoRALoader(GeneralLoRALoader):
             state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
             state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
         return state_dict_
+
+
+class FluxLoRAConverter:
+    def __init__(self):
+        pass
+
+    @staticmethod
+    def align_to_opensource_format(state_dict, alpha=None):
+        prefix_rename_dict = {
+            "single_blocks": "lora_unet_single_blocks",
+            "blocks": "lora_unet_double_blocks",
+        }
+        middle_rename_dict = {
+            "norm.linear": "modulation_lin",
+            "to_qkv_mlp": "linear1",
+            "proj_out": "linear2",
+
+            "norm1_a.linear": "img_mod_lin",
+            "norm1_b.linear": "txt_mod_lin",
+            "attn.a_to_qkv": "img_attn_qkv",
+            "attn.b_to_qkv": "txt_attn_qkv",
+            "attn.a_to_out": "img_attn_proj",
+            "attn.b_to_out": "txt_attn_proj",
+            "ff_a.0": "img_mlp_0",
+            "ff_a.2": "img_mlp_2",
+            "ff_b.0": "txt_mlp_0",
+            "ff_b.2": "txt_mlp_2",
+        }
+        suffix_rename_dict = {
+            "lora_B.weight": "lora_up.weight",
+            "lora_A.weight": "lora_down.weight",
+        }
+        state_dict_ = {}
+        for name, param in state_dict.items():
+            names = name.split(".")
+            if names[-2] != "lora_A" and names[-2] != "lora_B":
+                names.pop(-2)
+            prefix = names[0]
+            middle = ".".join(names[2:-2])
+            suffix = ".".join(names[-2:])
+            block_id = names[1]
+            if middle not in middle_rename_dict:
+                continue
+            rename = prefix_rename_dict[prefix] + "_" + block_id + "_" + middle_rename_dict[middle] + "." + suffix_rename_dict[suffix]
+            state_dict_[rename] = param
+            if rename.endswith("lora_up.weight"):
+                lora_alpha = alpha if alpha is not None else param.shape[-1]
+                state_dict_[rename.replace("lora_up.weight", "alpha")] = torch.tensor((lora_alpha,))[0]
+        return state_dict_
+
+    @staticmethod
+    def align_to_diffsynth_format(state_dict):
+        rename_dict = {
+            "lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight",
+            "lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight",
+            "lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight",
+            "lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight",
+            "lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight",
+            "lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight",
+            "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight",
+            "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight",
+            "lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight",
+            "lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight",
+            "lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight",
+            "lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight",
+            "lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight",
+            "lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight",
+            "lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight",
+            "lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight",
+            "lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight",
+            "lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight",
+            "lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight",
+            "lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight",
+            "lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight",
+            "lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight",
+            "lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight",
+            "lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight",
+            "lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight",
+            "lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight",
+        }
+        def guess_block_id(name):
+            names = name.split("_")
+            for i in names:
+                if i.isdigit():
+                    return i, name.replace(f"_{i}_", "_blockid_")
+            return None, None
+        state_dict_ = {}
+        for name, param in state_dict.items():
+            block_id, source_name = guess_block_id(name)
+            if source_name in rename_dict:
+                target_name = rename_dict[source_name]
+                target_name = target_name.replace(".blockid.", f".{block_id}.")
+                state_dict_[target_name] = param
+            else:
+                state_dict_[name] = param
+        return state_dict_