Mirror of https://github.com/modelscope/DiffSynth-Studio.git
Synced 2026-03-24 18:28:10 +00:00

Merge pull request #186 from modelscope/flux-lora

support flux lora inference
@@ -464,9 +464,9 @@ class FluxDiTStateDictConverter:
                     name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
                     state_dict_[name_] = param
                 else:
-                    print(name)
+                    pass
             else:
-                print(name)
+                pass
         for name in list(state_dict_.keys()):
             if ".proj_in_besides_attn." in name:
                 name_ = name.replace(".proj_in_besides_attn.", ".linear.")
@@ -570,6 +570,6 @@ class FluxDiTStateDictConverter:
                 rename = f"single_blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
                 state_dict_[rename] = param
             else:
-                print(name)
+                pass
         return state_dict_
 
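Note on the two FluxDiTStateDictConverter hunks above: unmatched keys were previously echoed with print(name) during conversion and are now skipped silently. For context, the converter's final pass rewrites fused-projection keys; the sketch below is a plausible completion of that loop (the diff is cut off right after name_ is built), with a hypothetical key and dummy value.

# Illustration only: a plausible completion of the rename pass shown above.
# The key and value are hypothetical; the reassignment is an assumption.
state_dict_ = {"single_blocks.0.proj_in_besides_attn.weight": "dummy"}
for name in list(state_dict_.keys()):
    if ".proj_in_besides_attn." in name:
        name_ = name.replace(".proj_in_besides_attn.", ".linear.")
        state_dict_[name_] = state_dict_.pop(name)
print(state_dict_)  # {'single_blocks.0.linear.weight': 'dummy'}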
@@ -4,6 +4,7 @@ from .sdxl_unet import SDXLUNet
 from .sd_text_encoder import SDTextEncoder
 from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
 from .sd3_dit import SD3DiT
+from .flux_dit import FluxDiT
 from .hunyuan_dit import HunyuanDiT
 
 
@@ -17,6 +18,13 @@ class LoRAFromCivitai:
 
 
     def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
+        for key in state_dict:
+            if ".lora_up" in key:
+                return self.convert_state_dict_up_down(state_dict, lora_prefix, alpha)
+        return self.convert_state_dict_AB(state_dict, lora_prefix, alpha)
+
+
+    def convert_state_dict_up_down(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
         renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "")
         state_dict_ = {}
         for key in state_dict:
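Note: convert_state_dict is now a dispatcher. A checkpoint containing any ".lora_up" key follows the existing Kohya/Civitai up/down path; everything else falls through to the new PEFT-style A/B path. A minimal, self-contained sketch of the dispatch rule (the key names are hypothetical examples of each convention):

# Minimal sketch of the dispatch rule; keys are hypothetical examples.
def pick_converter(state_dict):
    for key in state_dict:
        if ".lora_up" in key:                  # Kohya/Civitai convention
            return "convert_state_dict_up_down"
    return "convert_state_dict_AB"             # PEFT convention (.lora_A/.lora_B)

kohya = {"lora_unet_double_blocks_0_img_attn_qkv.lora_up.weight": None}
peft = {"transformer.blocks.0.attn.to_q.lora_B.weight": None}
assert pick_converter(kohya) == "convert_state_dict_up_down"
assert pick_converter(peft) == "convert_state_dict_AB"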
@@ -39,6 +47,29 @@ class LoRAFromCivitai:
         return state_dict_
 
 
+    def convert_state_dict_AB(self, state_dict, lora_prefix="", alpha=1.0, device="cuda", torch_dtype=torch.float16):
+        state_dict_ = {}
+        for key in state_dict:
+            if ".lora_B." not in key:
+                continue
+            if not key.startswith(lora_prefix):
+                continue
+            weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
+            weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
+            if len(weight_up.shape) == 4:
+                weight_up = weight_up.squeeze(3).squeeze(2)
+                weight_down = weight_down.squeeze(3).squeeze(2)
+                lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
+            else:
+                lora_weight = alpha * torch.mm(weight_up, weight_down)
+            keys = key.split(".")
+            keys.pop(keys.index("lora_B"))
+            target_name = ".".join(keys)
+            target_name = target_name[len(lora_prefix):]
+            state_dict_[target_name] = lora_weight.cpu()
+        return state_dict_
+
+
     def load(self, model, state_dict_lora, lora_prefix, alpha=1.0, model_resource=None):
         state_dict_model = model.state_dict()
         state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=alpha)
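Note: for each ".lora_B." key, convert_state_dict_AB pairs it with the matching ".lora_A." tensor and folds the pair into a single additive update, alpha * B @ A, stored under the original parameter name with the "lora_B" segment dropped. A toy run of that arithmetic on a 2-D linear weight (the 4-D branch in the hunk handles 1x1 convs by squeezing before the matmul and unsqueezing after); the key name and shapes are hypothetical:

import torch

# Toy reproduction of the merge; key name and shapes are hypothetical.
alpha = 1.0
state_dict = {
    "transformer.blocks.0.to_q.lora_A.weight": torch.randn(16, 64),  # rank x in_features
    "transformer.blocks.0.to_q.lora_B.weight": torch.randn(64, 16),  # out_features x rank
}
key = "transformer.blocks.0.to_q.lora_B.weight"
weight_up = state_dict[key]
weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")]
lora_weight = alpha * torch.mm(weight_up, weight_down)  # (64, 64) additive delta

keys = key.split(".")
keys.pop(keys.index("lora_B"))  # drop the "lora_B" segment
target_name = ".".join(keys)
print(target_name, tuple(lora_weight.shape))
# transformer.blocks.0.to_q.weight (64, 64)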
@@ -134,6 +165,23 @@ class SDXLLoRAFromCivitai(LoRAFromCivitai):
         }
 
 
+class FluxLoRAFromCivitai(LoRAFromCivitai):
+    def __init__(self):
+        super().__init__()
+        self.supported_model_classes = [FluxDiT, FluxDiT]
+        self.lora_prefix = ["lora_unet_", "transformer."]
+        self.renamed_lora_prefix = {}
+        self.special_keys = {
+            "single.blocks": "single_blocks",
+            "double.blocks": "double_blocks",
+            "img.attn": "img_attn",
+            "img.mlp": "img_mlp",
+            "img.mod": "img_mod",
+            "txt.attn": "txt_attn",
+            "txt.mlp": "txt_mlp",
+            "txt.mod": "txt_mod",
+        }
+
+
 class GeneralLoRAFromPeft:
     def __init__(self):
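Note: FluxLoRAFromCivitai registers two candidate prefixes ("lora_unet_" for Kohya-style files, "transformer." for diffusers/PEFT-style ones), with one FluxDiT entry per prefix. The special_keys table undoes the underscore flattening in Kohya key names: once a key has been split and rejoined with dots, known fragments map back to the FLUX module names. A rough sketch of that substitution, assuming the base class applies the table via str.replace (the input key is hypothetical):

# Rough sketch of special_keys substitution; the input key is hypothetical
# and the replace-based application is an assumption about the base class.
special_keys = {
    "single.blocks": "single_blocks", "double.blocks": "double_blocks",
    "img.attn": "img_attn", "img.mlp": "img_mlp", "img.mod": "img_mod",
    "txt.attn": "txt_attn", "txt.mlp": "txt_mlp", "txt.mod": "txt_mod",
}
key = "double.blocks.7.img.attn.qkv"
for fragment, renamed in special_keys.items():
    key = key.replace(fragment, renamed)
print(key)  # double_blocks.7.img_attn.qkv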
@@ -193,3 +241,7 @@ class GeneralLoRAFromPeft:
         except:
             pass
         return None
+
+
+def get_lora_loaders():
+    return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), GeneralLoRAFromPeft(), FluxLoRAFromCivitai()]
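Note: get_lora_loaders centralizes the loader registry so callers (see the ModelManager hunk below) no longer hard-code the list; supporting a new model family becomes a one-line registry change. A hedged usage sketch (the import path is an assumption):

# Hedged sketch; the diffsynth.models.lora import path is an assumption.
from diffsynth.models.lora import get_lora_loaders

for loader in get_lora_loaders():
    print(type(loader).__name__)
# SDLoRAFromCivitai, SDXLLoRAFromCivitai, GeneralLoRAFromPeft, FluxLoRAFromCivitai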
@@ -10,7 +10,7 @@ from .sd_text_encoder import SDTextEncoder
 from .sd_unet import SDUNet
 from .sd_vae_encoder import SDVAEEncoder
 from .sd_vae_decoder import SDVAEDecoder
-from .lora import SDLoRAFromCivitai, SDXLLoRAFromCivitai, GeneralLoRAFromPeft
+from .lora import get_lora_loaders
 
 from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
 from .sdxl_unet import SDXLUNet
@@ -403,7 +403,7 @@ class ModelManager:
         if len(state_dict) == 0:
             state_dict = load_state_dict(file_path)
         for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
-            for lora in [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), GeneralLoRAFromPeft()]:
+            for lora in get_lora_loaders():
                 match_results = lora.match(model, state_dict)
                 if match_results is not None:
                     print(f" Adding LoRA to {model_name} ({model_path}).")
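Note: with the registry wired into ModelManager, FLUX LoRAs load through the same entry point as SD and SDXL ones. A hedged end-to-end sketch, assuming the patched method is ModelManager.load_lora; the file paths are placeholders:

# Hedged end-to-end sketch; paths are placeholders and the load_lora
# entry point is assumed to be the method patched above.
from diffsynth import ModelManager

model_manager = ModelManager()
model_manager.load_models(["models/FLUX/flux1-dev.safetensors"])
model_manager.load_lora("models/lora/my_flux_lora.safetensors")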