mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support flux lora inference
This commit is contained in:
@@ -464,9 +464,9 @@ class FluxDiTStateDictConverter:
|
||||
name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
|
||||
state_dict_[name_] = param
|
||||
else:
|
||||
print(name)
|
||||
pass
|
||||
else:
|
||||
print(name)
|
||||
pass
|
||||
for name in list(state_dict_.keys()):
|
||||
if ".proj_in_besides_attn." in name:
|
||||
name_ = name.replace(".proj_in_besides_attn.", ".linear.")
|
||||
@@ -570,6 +570,6 @@ class FluxDiTStateDictConverter:
|
||||
rename = f"single_blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
|
||||
state_dict_[rename] = param
|
||||
else:
|
||||
print(name)
|
||||
pass
|
||||
return state_dict_
|
||||
|
||||
@@ -4,6 +4,7 @@ from .sdxl_unet import SDXLUNet
|
||||
from .sd_text_encoder import SDTextEncoder
|
||||
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
||||
from .sd3_dit import SD3DiT
|
||||
from .flux_dit import FluxDiT
|
||||
from .hunyuan_dit import HunyuanDiT
|
||||
|
||||
|
||||
@@ -17,6 +18,13 @@ class LoRAFromCivitai:
|
||||
|
||||
|
||||
def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
|
||||
for key in state_dict:
|
||||
if ".lora_up" in key:
|
||||
return self.convert_state_dict_up_down(state_dict, lora_prefix, alpha)
|
||||
return self.convert_state_dict_AB(state_dict, lora_prefix, alpha)
|
||||
|
||||
|
||||
def convert_state_dict_up_down(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
|
||||
renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "")
|
||||
state_dict_ = {}
|
||||
for key in state_dict:
|
||||
@@ -39,6 +47,29 @@ class LoRAFromCivitai:
|
||||
return state_dict_
|
||||
|
||||
|
||||
def convert_state_dict_AB(self, state_dict, lora_prefix="", alpha=1.0, device="cuda", torch_dtype=torch.float16):
|
||||
state_dict_ = {}
|
||||
for key in state_dict:
|
||||
if ".lora_B." not in key:
|
||||
continue
|
||||
if not key.startswith(lora_prefix):
|
||||
continue
|
||||
weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
|
||||
weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
|
||||
if len(weight_up.shape) == 4:
|
||||
weight_up = weight_up.squeeze(3).squeeze(2)
|
||||
weight_down = weight_down.squeeze(3).squeeze(2)
|
||||
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
lora_weight = alpha * torch.mm(weight_up, weight_down)
|
||||
keys = key.split(".")
|
||||
keys.pop(keys.index("lora_B"))
|
||||
target_name = ".".join(keys)
|
||||
target_name = target_name[len(lora_prefix):]
|
||||
state_dict_[target_name] = lora_weight.cpu()
|
||||
return state_dict_
|
||||
|
||||
|
||||
def load(self, model, state_dict_lora, lora_prefix, alpha=1.0, model_resource=None):
|
||||
state_dict_model = model.state_dict()
|
||||
state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=alpha)
|
||||
@@ -134,6 +165,23 @@ class SDXLLoRAFromCivitai(LoRAFromCivitai):
|
||||
}
|
||||
|
||||
|
||||
class FluxLoRAFromCivitai(LoRAFromCivitai):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.supported_model_classes = [FluxDiT, FluxDiT]
|
||||
self.lora_prefix = ["lora_unet_", "transformer."]
|
||||
self.renamed_lora_prefix = {}
|
||||
self.special_keys = {
|
||||
"single.blocks": "single_blocks",
|
||||
"double.blocks": "double_blocks",
|
||||
"img.attn": "img_attn",
|
||||
"img.mlp": "img_mlp",
|
||||
"img.mod": "img_mod",
|
||||
"txt.attn": "txt_attn",
|
||||
"txt.mlp": "txt_mlp",
|
||||
"txt.mod": "txt_mod",
|
||||
}
|
||||
|
||||
|
||||
class GeneralLoRAFromPeft:
|
||||
def __init__(self):
|
||||
@@ -192,4 +240,8 @@ class GeneralLoRAFromPeft:
|
||||
return "", ""
|
||||
except:
|
||||
pass
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def get_lora_loaders():
|
||||
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), GeneralLoRAFromPeft(), FluxLoRAFromCivitai()]
|
||||
|
||||
@@ -10,7 +10,7 @@ from .sd_text_encoder import SDTextEncoder
|
||||
from .sd_unet import SDUNet
|
||||
from .sd_vae_encoder import SDVAEEncoder
|
||||
from .sd_vae_decoder import SDVAEDecoder
|
||||
from .lora import SDLoRAFromCivitai, SDXLLoRAFromCivitai, GeneralLoRAFromPeft
|
||||
from .lora import get_lora_loaders
|
||||
|
||||
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
||||
from .sdxl_unet import SDXLUNet
|
||||
@@ -403,7 +403,7 @@ class ModelManager:
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
|
||||
for lora in [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), GeneralLoRAFromPeft()]:
|
||||
for lora in get_lora_loaders():
|
||||
match_results = lora.match(model, state_dict)
|
||||
if match_results is not None:
|
||||
print(f" Adding LoRA to {model_name} ({model_path}).")
|
||||
|
||||
Reference in New Issue
Block a user