mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 10:18:12 +00:00
Merge pull request #282 from modelscope/lora-patch-2
Support resuming LoRA training from weights saved in the open-source format
This commit is contained in:
@@ -306,6 +306,53 @@ class FluxLoRAConverter:
|
|||||||
state_dict_[rename.replace("lora_up.weight", "alpha")] = torch.tensor((alpha,))[0]
|
state_dict_[rename.replace("lora_up.weight", "alpha")] = torch.tensor((alpha,))[0]
|
||||||
return state_dict_
|
return state_dict_
|
||||||
|
|
||||||
|
@staticmethod
def align_to_diffsynth_format(state_dict):
    """Rename open-source style FLUX LoRA keys to DiffSynth style keys.

    Keys such as ``lora_unet_double_blocks_<N>_img_attn_qkv.lora_down.weight``
    are mapped to ``blocks.<N>.attn.a_to_qkv.lora_A.default.weight`` using the
    template table below, in which the block index is written as the
    placeholder ``blockid``.

    Args:
        state_dict: Mapping from parameter name (str) to parameter value.
            Values are never inspected; they are copied over as-is.

    Returns:
        A new dict. Entries whose name matches a template are stored under
        the translated name; every other entry keeps its original name.
        The input mapping is not modified.
    """
    # Template table: open-source key -> DiffSynth key, block index
    # abstracted as "blockid" on both sides.
    rename_dict = {
        "lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight",
        "lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight",
        "lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight",
        "lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight",
        "lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight",
        "lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight",
        "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight",
        "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight",
        "lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight",
        "lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight",
        "lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight",
        "lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight",
        "lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight",
        "lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight",
        "lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight",
        "lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight",
        "lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight",
        "lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight",
        "lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight",
        "lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight",
        "lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight",
        "lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight",
        "lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight",
        "lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight",
        "lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight",
        "lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight",
    }

    converted = {}
    for key, value in state_dict.items():
        # The first all-digit token between underscores is taken as the
        # block index; names without such a token have no template.
        index = next((tok for tok in key.split("_") if tok.isdigit()), None)
        template = None if index is None else key.replace(f"_{index}_", "_blockid_")

        if template in rename_dict:
            # Put the concrete block index back into the target name.
            new_key = rename_dict[template].replace(".blockid.", f".{index}.")
            converted[new_key] = value
        else:
            # Unknown names (e.g. ".alpha" entries) are passed through unchanged.
            converted[key] = value
    return converted
|
||||||
|
|
||||||
|
|
||||||
def get_lora_loaders():
|
def get_lora_loaders():
|
||||||
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), GeneralLoRAFromPeft()]
|
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), GeneralLoRAFromPeft()]
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ class LightningModelForT2ILoRA(pl.LightningModule):
|
|||||||
self.pipe.denoising_model().train()
|
self.pipe.denoising_model().train()
|
||||||
|
|
||||||
|
|
||||||
def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian", pretrained_lora_path=None):
|
def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian", pretrained_lora_path=None, state_dict_converter=None):
|
||||||
# Add LoRA to UNet
|
# Add LoRA to UNet
|
||||||
self.lora_alpha = lora_alpha
|
self.lora_alpha = lora_alpha
|
||||||
if init_lora_weights == "kaiming":
|
if init_lora_weights == "kaiming":
|
||||||
@@ -55,6 +55,8 @@ class LightningModelForT2ILoRA(pl.LightningModule):
|
|||||||
# Load pretrained LoRA weights
|
# Load pretrained LoRA weights
|
||||||
if pretrained_lora_path is not None:
|
if pretrained_lora_path is not None:
|
||||||
state_dict = load_state_dict(pretrained_lora_path)
|
state_dict = load_state_dict(pretrained_lora_path)
|
||||||
|
if state_dict_converter is not None:
|
||||||
|
state_dict = state_dict_converter(state_dict)
|
||||||
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
||||||
all_keys = [i for i, _ in model.named_parameters()]
|
all_keys = [i for i, _ in model.named_parameters()]
|
||||||
num_updated_keys = len(all_keys) - len(missing_keys)
|
num_updated_keys = len(all_keys) - len(missing_keys)
|
||||||
|
|||||||
@@ -40,7 +40,8 @@ class LightningModel(LightningModelForT2ILoRA):
|
|||||||
lora_alpha=lora_alpha,
|
lora_alpha=lora_alpha,
|
||||||
lora_target_modules=lora_target_modules,
|
lora_target_modules=lora_target_modules,
|
||||||
init_lora_weights=init_lora_weights,
|
init_lora_weights=init_lora_weights,
|
||||||
pretrained_lora_path=pretrained_lora_path
|
pretrained_lora_path=pretrained_lora_path,
|
||||||
|
state_dict_converter=FluxLoRAConverter.align_to_diffsynth_format
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user