from .general import GeneralLoRALoader
import math
import torch


class FluxLoRALoader(GeneralLoRALoader):
    """LoRA loader that maps Diffusers / Civitai LoRA key names onto FLUX model keys.

    ``convert_state_dict`` detects the naming scheme of an incoming LoRA state dict,
    renames its keys via ``diffusers_rename_dict`` / ``civitai_rename_dict``, folds the
    LoRA alpha into the weights, and merges separate q/k/v (and, for Diffusers single
    blocks, the MLP input projection) LoRA matrices into the fused layers the target
    model uses. Fusing itself is delegated to ``GeneralLoRALoader``.

    Naming in the target keys: the ``a`` component is the image stream
    (Diffusers ``to_*`` / Civitai ``img_*``), the ``b`` component is the text/context
    stream (Diffusers ``add_*_proj`` / ``ff_context`` / Civitai ``txt_*``).
    """

    def __init__(self, device="cpu", torch_dtype=torch.float32):
        super().__init__(device=device, torch_dtype=torch_dtype)
        # Diffusers key template -> target key template. "blockid" stands in for the
        # numeric block index and is substituted back in convert_state_dict.
        self.diffusers_rename_dict = {
            "transformer.single_transformer_blocks.blockid.attn.to_k.lora_A.weight":"single_blocks.blockid.a_to_k.lora_A.weight",
            "transformer.single_transformer_blocks.blockid.attn.to_k.lora_B.weight":"single_blocks.blockid.a_to_k.lora_B.weight",
            "transformer.single_transformer_blocks.blockid.attn.to_q.lora_A.weight":"single_blocks.blockid.a_to_q.lora_A.weight",
            "transformer.single_transformer_blocks.blockid.attn.to_q.lora_B.weight":"single_blocks.blockid.a_to_q.lora_B.weight",
            "transformer.single_transformer_blocks.blockid.attn.to_v.lora_A.weight":"single_blocks.blockid.a_to_v.lora_A.weight",
            "transformer.single_transformer_blocks.blockid.attn.to_v.lora_B.weight":"single_blocks.blockid.a_to_v.lora_B.weight",
            "transformer.single_transformer_blocks.blockid.norm.linear.lora_A.weight":"single_blocks.blockid.norm.linear.lora_A.weight",
            "transformer.single_transformer_blocks.blockid.norm.linear.lora_B.weight":"single_blocks.blockid.norm.linear.lora_B.weight",
            "transformer.single_transformer_blocks.blockid.proj_mlp.lora_A.weight":"single_blocks.blockid.proj_in_besides_attn.lora_A.weight",
            "transformer.single_transformer_blocks.blockid.proj_mlp.lora_B.weight":"single_blocks.blockid.proj_in_besides_attn.lora_B.weight",
            "transformer.single_transformer_blocks.blockid.proj_out.lora_A.weight":"single_blocks.blockid.proj_out.lora_A.weight",
            "transformer.single_transformer_blocks.blockid.proj_out.lora_B.weight":"single_blocks.blockid.proj_out.lora_B.weight",
            "transformer.transformer_blocks.blockid.attn.add_k_proj.lora_A.weight":"blocks.blockid.attn.b_to_k.lora_A.weight",
            "transformer.transformer_blocks.blockid.attn.add_k_proj.lora_B.weight":"blocks.blockid.attn.b_to_k.lora_B.weight",
            "transformer.transformer_blocks.blockid.attn.add_q_proj.lora_A.weight":"blocks.blockid.attn.b_to_q.lora_A.weight",
            "transformer.transformer_blocks.blockid.attn.add_q_proj.lora_B.weight":"blocks.blockid.attn.b_to_q.lora_B.weight",
            "transformer.transformer_blocks.blockid.attn.add_v_proj.lora_A.weight":"blocks.blockid.attn.b_to_v.lora_A.weight",
            "transformer.transformer_blocks.blockid.attn.add_v_proj.lora_B.weight":"blocks.blockid.attn.b_to_v.lora_B.weight",
            "transformer.transformer_blocks.blockid.attn.to_add_out.lora_A.weight":"blocks.blockid.attn.b_to_out.lora_A.weight",
            "transformer.transformer_blocks.blockid.attn.to_add_out.lora_B.weight":"blocks.blockid.attn.b_to_out.lora_B.weight",
            "transformer.transformer_blocks.blockid.attn.to_k.lora_A.weight":"blocks.blockid.attn.a_to_k.lora_A.weight",
            "transformer.transformer_blocks.blockid.attn.to_k.lora_B.weight":"blocks.blockid.attn.a_to_k.lora_B.weight",
            "transformer.transformer_blocks.blockid.attn.to_out.0.lora_A.weight":"blocks.blockid.attn.a_to_out.lora_A.weight",
            "transformer.transformer_blocks.blockid.attn.to_out.0.lora_B.weight":"blocks.blockid.attn.a_to_out.lora_B.weight",
            "transformer.transformer_blocks.blockid.attn.to_q.lora_A.weight":"blocks.blockid.attn.a_to_q.lora_A.weight",
            "transformer.transformer_blocks.blockid.attn.to_q.lora_B.weight":"blocks.blockid.attn.a_to_q.lora_B.weight",
            "transformer.transformer_blocks.blockid.attn.to_v.lora_A.weight":"blocks.blockid.attn.a_to_v.lora_A.weight",
            "transformer.transformer_blocks.blockid.attn.to_v.lora_B.weight":"blocks.blockid.attn.a_to_v.lora_B.weight",
            "transformer.transformer_blocks.blockid.ff.net.0.proj.lora_A.weight":"blocks.blockid.ff_a.0.lora_A.weight",
            "transformer.transformer_blocks.blockid.ff.net.0.proj.lora_B.weight":"blocks.blockid.ff_a.0.lora_B.weight",
            "transformer.transformer_blocks.blockid.ff.net.2.lora_A.weight":"blocks.blockid.ff_a.2.lora_A.weight",
            "transformer.transformer_blocks.blockid.ff.net.2.lora_B.weight":"blocks.blockid.ff_a.2.lora_B.weight",
            "transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_A.weight":"blocks.blockid.ff_b.0.lora_A.weight",
            "transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_B.weight":"blocks.blockid.ff_b.0.lora_B.weight",
            "transformer.transformer_blocks.blockid.ff_context.net.2.lora_A.weight":"blocks.blockid.ff_b.2.lora_A.weight",
            "transformer.transformer_blocks.blockid.ff_context.net.2.lora_B.weight":"blocks.blockid.ff_b.2.lora_B.weight",
            "transformer.transformer_blocks.blockid.norm1.linear.lora_A.weight":"blocks.blockid.norm1_a.linear.lora_A.weight",
            "transformer.transformer_blocks.blockid.norm1.linear.lora_B.weight":"blocks.blockid.norm1_a.linear.lora_B.weight",
            "transformer.transformer_blocks.blockid.norm1_context.linear.lora_A.weight":"blocks.blockid.norm1_b.linear.lora_A.weight",
            "transformer.transformer_blocks.blockid.norm1_context.linear.lora_B.weight":"blocks.blockid.norm1_b.linear.lora_B.weight",
        }
        # Civitai ("lora_unet_*") key template -> target key template; lora_down maps
        # to lora_A and lora_up maps to lora_B.
        self.civitai_rename_dict = {
            "lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.weight",
            "lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.weight",
            "lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.weight",
            "lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.weight",
            "lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.weight",
            "lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.weight",
            "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.weight",
            "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.weight",
            "lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.weight",
            "lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.weight",
            "lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.weight",
            "lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.weight",
            "lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.weight",
            "lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.weight",
            "lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.weight",
            "lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.weight",
            "lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.weight",
            "lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.weight",
            "lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.weight",
            "lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.weight",
            "lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.weight",
            "lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.weight",
            "lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.weight",
            "lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.weight",
            "lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.weight",
            "lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.weight",
        }

    def fuse_lora_to_base_model(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
        """Fuse an already-converted LoRA state dict into ``model``.

        Thin pass-through to ``GeneralLoRALoader.fuse_lora_to_base_model``.
        """
        super().fuse_lora_to_base_model(model, state_dict_lora, alpha)

    @staticmethod
    def _guess_resource(state_dict):
        """Return ``'civitai'`` or ``'diffusers'`` based on the first recognizable key, else None."""
        for key in state_dict:
            if "lora_unet_" in key:
                return "civitai"
            if key.startswith("transformer."):
                return "diffusers"
            # Unrecognized key (e.g. a text-encoder LoRA entry): keep scanning.
        return None

    @staticmethod
    def _guess_block_id(name, model_resource):
        """Extract the block index from ``name`` and return ``(block_id, templated_name)``.

        The templated name has the index replaced by the literal ``blockid`` so it can
        be looked up in the rename dicts. Returns ``(None, None)`` when ``name`` has no
        purely numeric component.
        """
        if model_resource == "civitai":
            for part in name.split("_"):
                if part.isdigit():
                    return part, name.replace(f"_{part}_", "_blockid_")
        elif model_resource == "diffusers":
            for part in name.split("."):
                if part.isdigit():
                    # Also matches "single_transformer_blocks.<id>." as a substring.
                    return part, name.replace(f"transformer_blocks.{part}.", "transformer_blocks.blockid.")
        return None, None

    @staticmethod
    def _guess_alpha(state_dict):
        """Return the per-tensor scale ``sqrt(alpha / shape[0] of the matching down/A matrix)``.

        Only the first ``.alpha`` entry that has a matching ``.lora_down.weight`` or
        ``.lora_A.weight`` is used; returns 1 when none is found. Because both the
        down/A and up/B tensors are multiplied by this value, their product ends up
        scaled by ``alpha / shape[0]`` (shape[0] is presumably the LoRA rank).
        """
        for name, param in state_dict.items():
            if ".alpha" not in name:
                continue
            for suffix in (".lora_down.weight", ".lora_A.weight"):
                down_name = name.replace(".alpha", suffix)
                if down_name in state_dict:
                    return math.sqrt(param.item() / state_dict[down_name].shape[0])
        return 1

    @staticmethod
    def _block_diagonal(blocks):
        """Place 2-D tensors along the diagonal of a zero matrix (dtype/device of the first)."""
        first = blocks[0]
        rows = sum(block.shape[0] for block in blocks)
        cols = sum(block.shape[1] for block in blocks)
        out = torch.zeros((rows, cols), dtype=first.dtype, device=first.device)
        row = col = 0
        for block in blocks:
            out[row:row + block.shape[0], col:col + block.shape[1]] = block
            row += block.shape[0]
            col += block.shape[1]
        return out

    @classmethod
    def _fuse_single_block_qkv_mlp(cls, state_dict_):
        """In place: merge single-block a_to_q/k/v + proj_in_besides_attn into ``to_qkv_mlp``.

        ``lora_A`` tensors are stacked along dim 0. ``lora_B`` tensors are laid out
        block-diagonally, so each sub-projection keeps its own rank slice.
        """
        for name in list(state_dict_):
            if "single_blocks." not in name or ".a_to_q." not in name:
                continue
            q = state_dict_[name]
            mlp = state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
            if mlp is None:
                # No LoRA for the MLP input projection: pad with zeros so the fused
                # tensor still spans the whole to_qkv_mlp layer.
                # NOTE(review): the factor 4 for non-A tensors presumably matches the
                # MLP projection's output width relative to q — confirm.
                factor = 1 if "lora_A" in name else 4
                # Create the padding on the same device as q; otherwise torch.concat
                # fails when the LoRA tensors live on a GPU.
                mlp = torch.zeros(factor * q.shape[0], *q.shape[1:], dtype=q.dtype, device=q.device)
            parts = [
                state_dict_.pop(name),
                state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
                state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
                mlp,
            ]
            if "lora_B" in name and "lora_A" not in name:
                fused = cls._block_diagonal(parts)
            else:
                fused = torch.concat(parts, dim=0)
            state_dict_[name.replace(".a_to_q.", ".to_qkv_mlp.")] = fused

    @classmethod
    def _fuse_double_block_qkv(cls, state_dict_):
        """In place: merge ``{a,b}_to_q/k/v`` LoRA tensors into ``{a,b}_to_qkv``.

        Uses the same layout as ``_fuse_single_block_qkv_mlp``: ``lora_A`` tensors are
        stacked along dim 0, and ``lora_B`` tensors are laid out block-diagonally.
        """
        for name in list(state_dict_):
            for component in ("a", "b"):
                q_tag = f".{component}_to_q."
                if q_tag not in name:
                    continue
                part_names = [name.replace(q_tag, f".{component}_to_{p}.") for p in "qkv"]
                parts = [state_dict_[part_name] for part_name in part_names]
                if "lora_B" in name and "lora_A" not in name:
                    fused = cls._block_diagonal(parts)
                else:
                    fused = torch.concat(parts, dim=0)
                state_dict_[name.replace(q_tag, f".{component}_to_qkv.")] = fused
                for part_name in part_names:
                    state_dict_.pop(part_name)

    def convert_state_dict(self, state_dict):
        """Convert a Diffusers- or Civitai-format FLUX LoRA state dict to target key names.

        Args:
            state_dict: mapping of LoRA key -> tensor. It is not modified.

        Returns:
            A new dict with renamed keys, alpha folded into every tensor, and q/k/v
            (and single-block MLP input) LoRAs fused. If the format cannot be
            detected, the input dict itself is returned unchanged. Keys that match no
            rename template are kept under their original name.
        """
        model_resource = self._guess_resource(state_dict)
        if model_resource is None:
            return state_dict
        if model_resource == "diffusers":
            rename_dict = self.diffusers_rename_dict
        else:
            rename_dict = self.civitai_rename_dict
        alpha = self._guess_alpha(state_dict)

        state_dict_ = {}
        for name, param in state_dict.items():
            block_id, source_name = self._guess_block_id(name, model_resource)
            if alpha != 1:
                # Out-of-place: an in-place `*=` would rescale the caller's tensors,
                # compounding the scale if the same dict is converted again.
                param = param * alpha
            if source_name in rename_dict:
                target_name = rename_dict[source_name].replace(".blockid.", f".{block_id}.")
                state_dict_[target_name] = param
            else:
                state_dict_[name] = param

        if model_resource == "diffusers":
            self._fuse_single_block_qkv_mlp(state_dict_)
        self._fuse_double_block_qkv(state_dict_)
        return state_dict_