diff --git a/diffsynth/lora/flux_lora.py b/diffsynth/lora/flux_lora.py
index cc9d725..c6e5115 100644
--- a/diffsynth/lora/flux_lora.py
+++ b/diffsynth/lora/flux_lora.py
@@ -10,9 +10,51 @@ class FluxLoRALoader(GeneralLoRALoader):
-    def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
-        super().load(model, state_dict_lora, alpha)
-
-    def convert_state_dict(self, state_dict):
-        # TODO: support other lora format
-        rename_dict = {
+    def __init__(self, device="cpu", torch_dtype=torch.float32):
+        super().__init__(device=device, torch_dtype=torch_dtype)
+
+        # Map diffusers-format LoRA keys onto this repo's FluxDiT parameter names.
+        self.diffusers_rename_dict = {
+            "transformer.single_transformer_blocks.blockid.attn.to_k.lora_A.weight": "single_blocks.blockid.a_to_k.lora_A.default.weight",
+            "transformer.single_transformer_blocks.blockid.attn.to_k.lora_B.weight": "single_blocks.blockid.a_to_k.lora_B.default.weight",
+            "transformer.single_transformer_blocks.blockid.attn.to_q.lora_A.weight": "single_blocks.blockid.a_to_q.lora_A.default.weight",
+            "transformer.single_transformer_blocks.blockid.attn.to_q.lora_B.weight": "single_blocks.blockid.a_to_q.lora_B.default.weight",
+            "transformer.single_transformer_blocks.blockid.attn.to_v.lora_A.weight": "single_blocks.blockid.a_to_v.lora_A.default.weight",
+            "transformer.single_transformer_blocks.blockid.attn.to_v.lora_B.weight": "single_blocks.blockid.a_to_v.lora_B.default.weight",
+            "transformer.single_transformer_blocks.blockid.norm.linear.lora_A.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight",
+            "transformer.single_transformer_blocks.blockid.norm.linear.lora_B.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight",
+            "transformer.single_transformer_blocks.blockid.proj_mlp.lora_A.weight": "single_blocks.blockid.proj_in_besides_attn.lora_A.default.weight",
+            "transformer.single_transformer_blocks.blockid.proj_mlp.lora_B.weight": "single_blocks.blockid.proj_in_besides_attn.lora_B.default.weight",
+            "transformer.single_transformer_blocks.blockid.proj_out.lora_A.weight": "single_blocks.blockid.proj_out.lora_A.default.weight",
+            "transformer.single_transformer_blocks.blockid.proj_out.lora_B.weight": "single_blocks.blockid.proj_out.lora_B.default.weight",
+            "transformer.transformer_blocks.blockid.attn.add_k_proj.lora_A.weight": "blocks.blockid.attn.b_to_k.lora_A.default.weight",
+            "transformer.transformer_blocks.blockid.attn.add_k_proj.lora_B.weight": "blocks.blockid.attn.b_to_k.lora_B.default.weight",
+            "transformer.transformer_blocks.blockid.attn.add_q_proj.lora_A.weight": "blocks.blockid.attn.b_to_q.lora_A.default.weight",
+            "transformer.transformer_blocks.blockid.attn.add_q_proj.lora_B.weight": "blocks.blockid.attn.b_to_q.lora_B.default.weight",
+            "transformer.transformer_blocks.blockid.attn.add_v_proj.lora_A.weight": "blocks.blockid.attn.b_to_v.lora_A.default.weight",
+            "transformer.transformer_blocks.blockid.attn.add_v_proj.lora_B.weight": "blocks.blockid.attn.b_to_v.lora_B.default.weight",
+            "transformer.transformer_blocks.blockid.attn.to_add_out.lora_A.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight",
+            "transformer.transformer_blocks.blockid.attn.to_add_out.lora_B.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight",
+            "transformer.transformer_blocks.blockid.attn.to_k.lora_A.weight": "blocks.blockid.attn.a_to_k.lora_A.default.weight",
+            "transformer.transformer_blocks.blockid.attn.to_k.lora_B.weight": "blocks.blockid.attn.a_to_k.lora_B.default.weight",
+            "transformer.transformer_blocks.blockid.attn.to_out.0.lora_A.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight",
+            "transformer.transformer_blocks.blockid.attn.to_out.0.lora_B.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight",
"transformer.transformer_blocks.blockid.attn.to_q.lora_A.weight":"blocks.blockid.attn.a_to_q.lora_A.default.weight", + "transformer.transformer_blocks.blockid.attn.to_q.lora_B.weight":"blocks.blockid.attn.a_to_q.lora_B.default.weight", + "transformer.transformer_blocks.blockid.attn.to_v.lora_A.weight":"blocks.blockid.attn.a_to_v.lora_A.default.weight", + "transformer.transformer_blocks.blockid.attn.to_v.lora_B.weight":"blocks.blockid.attn.a_to_v.lora_B.default.weight", + "transformer.transformer_blocks.blockid.ff.net.0.proj.lora_A.weight":"blocks.blockid.ff_a.0.lora_A.default.weight", + "transformer.transformer_blocks.blockid.ff.net.0.proj.lora_B.weight":"blocks.blockid.ff_a.0.lora_B.default.weight", + "transformer.transformer_blocks.blockid.ff.net.2.lora_A.weight":"blocks.blockid.ff_a.2.lora_A.default.weight", + "transformer.transformer_blocks.blockid.ff.net.2.lora_B.weight":"blocks.blockid.ff_a.2.lora_B.default.weight", + "transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_A.weight":"blocks.blockid.ff_b.0.lora_A.default.weight", + "transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_B.weight":"blocks.blockid.ff_b.0.lora_B.default.weight", + "transformer.transformer_blocks.blockid.ff_context.net.2.lora_A.weight":"blocks.blockid.ff_b.2.lora_A.default.weight", + "transformer.transformer_blocks.blockid.ff_context.net.2.lora_B.weight":"blocks.blockid.ff_b.2.lora_B.default.weight", + "transformer.transformer_blocks.blockid.norm1.linear.lora_A.weight":"blocks.blockid.norm1_a.linear.lora_A.default.weight", + "transformer.transformer_blocks.blockid.norm1.linear.lora_B.weight":"blocks.blockid.norm1_a.linear.lora_B.default.weight", + "transformer.transformer_blocks.blockid.norm1_context.linear.lora_A.weight":"blocks.blockid.norm1_b.linear.lora_A.default.weight", + "transformer.transformer_blocks.blockid.norm1_context.linear.lora_B.weight":"blocks.blockid.norm1_b.linear.lora_B.default.weight", + } + + self.civitai_rename_dict = { "lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight", "lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight", "lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight", @@ -40,25 +82,55 @@ class FluxLoRALoader(GeneralLoRALoader): "lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight", "lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight", } - def guess_block_id(name): - names = name.split("_") - for i in names: - if i.isdigit(): - return i, name.replace(f"_{i}_", "_blockid_") + + def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0): + super().load(model, state_dict_lora, alpha) + + + def convert_state_dict(self,state_dict): + + def guess_block_id(name,model_resource): + if model_resource == 'civitai': + names = name.split("_") + for i in names: + if i.isdigit(): + return i, name.replace(f"_{i}_", "_blockid_") + if model_resource == 'diffusers': + names = name.split(".") + for i in names: + if i.isdigit(): + return i, name.replace(f"transformer_blocks.{i}.", "transformer_blocks.blockid.") return None, None + + def guess_resource(state_dict): + for k in state_dict: + if "lora_unet_" in k: + return 'civitai' + elif "transformer." 
+                    return 'diffusers'
+            return None
+
+        model_resource = guess_resource(state_dict)
+        rename_dict = self.diffusers_rename_dict if model_resource == 'diffusers' else self.civitai_rename_dict
 
         def guess_alpha(state_dict):
-            for name, param in state_dict.items():
-                if ".alpha" in name:
-                    name_ = name.replace(".alpha", ".lora_down.weight")
-                    if name_ in state_dict:
-                        lora_alpha = param.item() / state_dict[name_].shape[0]
-                        lora_alpha = math.sqrt(lora_alpha)
-                        return lora_alpha
-            return 1
+            for name, param in state_dict.items():
+                if ".alpha" in name:
+                    # Try both civitai-style (lora_down) and diffusers-style (lora_A) keys.
+                    for suffix in [".lora_down.weight", ".lora_A.weight"]:
+                        name_ = name.replace(".alpha", suffix)
+                        if name_ in state_dict:
+                            lora_alpha = param.item() / state_dict[name_].shape[0]
+                            lora_alpha = math.sqrt(lora_alpha)
+                            return lora_alpha
+            return 1
 
         alpha = guess_alpha(state_dict)
         state_dict_ = {}
         for name, param in state_dict.items():
-            block_id, source_name = guess_block_id(name)
+            block_id, source_name = guess_block_id(name, model_resource)
             if alpha != 1:
                 param *= alpha
             if source_name in rename_dict:
@@ -67,6 +139,72 @@ class FluxLoRALoader(GeneralLoRALoader):
                 state_dict_[target_name] = param
             else:
                 state_dict_[name] = param
+
+        if model_resource == 'diffusers':
+            # Diffusers checkpoints keep separate q/k/v (and proj_mlp) LoRA weights;
+            # fuse them to match the combined projections of this DiT implementation.
+            for name in list(state_dict_.keys()):
+                if "single_blocks." in name and ".a_to_q." in name:
+                    mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
+                    if mlp is None:
+                        # No MLP LoRA in this block: zero-pad. The MLP branch of the fused
+                        # linear is 4x hidden wide, so lora_B is padded with 4x rows while
+                        # lora_A only contributes its own rank rows.
+                        dim = 1 if 'lora_A' in name else 4
+                        mlp = torch.zeros(dim * state_dict_[name].shape[0],
+                                          *state_dict_[name].shape[1:],
+                                          dtype=state_dict_[name].dtype)
+                    else:
+                        state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
+                    name_ = name.replace(".a_to_q.", ".to_qkv_mlp.")
+                    if 'lora_B' in name:
+                        # lora_B goes on a block diagonal so each output slice only reads
+                        # the rank columns of its own projection.
+                        d, r = state_dict_[name].shape
+                        param = torch.zeros((3*d + mlp.shape[0], 3*r + mlp.shape[1]), dtype=state_dict_[name].dtype, device=state_dict_[name].device)
+                        param[:d, :r] = state_dict_.pop(name)
+                        param[d:2*d, r:2*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_k."))
+                        param[2*d:3*d, 2*r:3*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_v."))
+                        param[3*d:, 3*r:] = mlp
+                    else:
+                        # lora_A is stacked along the rank dimension.
+                        param = torch.concat([
+                            state_dict_.pop(name),
+                            state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
+                            state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
+                            mlp,
+                        ], dim=0)
+                    state_dict_[name_] = param
+            for name in list(state_dict_.keys()):
+                for component in ["a", "b"]:
+                    if f".{component}_to_q." in name:
+                        name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
+                        if 'lora_B' in name:
+                            d, r = state_dict_[name].shape
+                            param = torch.zeros((3*d, 3*r), dtype=state_dict_[name].dtype, device=state_dict_[name].device)
+                            param[:d, :r] = state_dict_[name]
+                            param[d:2*d, r:2*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")]
+                            param[2*d:3*d, 2*r:3*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")]
+                        else:
+                            param = torch.concat([
+                                state_dict_[name],
+                                state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
+                                state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
+                            ], dim=0)
+                        state_dict_[name_] = param
+                        state_dict_.pop(name)
+                        state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
+                        state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
         return state_dict_
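
A note on the q/k/v fusion above: stacking the lora_A factors along the rank
dimension while placing the lora_B factors on a block diagonal reproduces each
projection's low-rank update B_i @ A_i exactly; naively concatenating lora_B as
well would let every output slice read all three ranks. A minimal standalone
sketch of that identity (not part of the patch; the shapes are made up):

    import torch

    d, r, d_in = 8, 2, 16                         # hypothetical out dim, rank, in dim
    A = [torch.randn(r, d_in) for _ in range(3)]  # lora_A for q, k, v
    B = [torch.randn(d, r) for _ in range(3)]     # lora_B for q, k, v

    A_fused = torch.cat(A, dim=0)                 # (3r, d_in): stacked along rank
    B_fused = torch.zeros(3 * d, 3 * r)           # (3d, 3r): block diagonal
    for i in range(3):
        B_fused[i*d:(i+1)*d, i*r:(i+1)*r] = B[i]

    # The fused update equals the per-projection updates stacked along the output dim.
    expected = torch.cat([B[i] @ A[i] for i in range(3)], dim=0)
    assert torch.allclose(B_fused @ A_fused, expected, atol=1e-5)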
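On guess_alpha: Kohya-style checkpoints store a per-module ".alpha" scalar, and
the effective LoRA scale is alpha / rank, where rank is the first dimension of
the matching lora_down / lora_A weight. The patch multiplies every tensor by
sqrt(alpha / rank), so the product B @ A picks up the full alpha / rank scale
exactly once. A quick check of that identity (the values are made up):

    import math
    import torch

    rank, alpha_value = 16, 8.0        # hypothetical rank and ".alpha" value
    s = math.sqrt(alpha_value / rank)  # what guess_alpha returns
    A = torch.randn(rank, 32)          # lora_down / lora_A
    B = torch.randn(64, rank)          # lora_up / lora_B
    assert torch.allclose((B * s) @ (A * s), (alpha_value / rank) * (B @ A), atol=1e-4)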
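For context, the loader would presumably be driven along these lines (a
hypothetical sketch: the pipeline side is not part of this diff, and flux_dit
stands in for whatever FluxDiT module the caller holds):

    from safetensors.torch import load_file

    loader = FluxLoRALoader()
    state_dict_lora = load_file("lora.safetensors")               # civitai- or diffusers-style keys
    state_dict_lora = loader.convert_state_dict(state_dict_lora)  # unified key format
    loader.load(flux_dit, state_dict_lora, alpha=1.0)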