diff --git a/diffsynth/models/flux_dit.py b/diffsynth/models/flux_dit.py
index 6cca984..eb1cde1 100644
--- a/diffsynth/models/flux_dit.py
+++ b/diffsynth/models/flux_dit.py
@@ -464,9 +464,9 @@ class FluxDiTStateDictConverter:
                     name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
                     state_dict_[name_] = param
                 else:
-                    print(name)
+                    pass
             else:
-                print(name)
+                pass
         for name in list(state_dict_.keys()):
             if ".proj_in_besides_attn." in name:
                 name_ = name.replace(".proj_in_besides_attn.", ".linear.")
@@ -570,6 +570,6 @@ class FluxDiTStateDictConverter:
                 rename = f"single_blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
                 state_dict_[rename] = param
             else:
-                print(name)
+                pass
         return state_dict_
\ No newline at end of file
diff --git a/diffsynth/models/lora.py b/diffsynth/models/lora.py
index 0419364..8f0ed70 100644
--- a/diffsynth/models/lora.py
+++ b/diffsynth/models/lora.py
@@ -4,6 +4,7 @@ from .sdxl_unet import SDXLUNet
 from .sd_text_encoder import SDTextEncoder
 from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
 from .sd3_dit import SD3DiT
+from .flux_dit import FluxDiT
 from .hunyuan_dit import HunyuanDiT
 
 
@@ -17,6 +18,13 @@ class LoRAFromCivitai:
 
 
     def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
+        for key in state_dict:
+            if ".lora_up" in key:
+                return self.convert_state_dict_up_down(state_dict, lora_prefix, alpha)
+        return self.convert_state_dict_AB(state_dict, lora_prefix, alpha)
+
+
+    def convert_state_dict_up_down(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
         renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "")
         state_dict_ = {}
         for key in state_dict:
@@ -39,6 +47,29 @@ class LoRAFromCivitai:
         return state_dict_
 
 
+    def convert_state_dict_AB(self, state_dict, lora_prefix="", alpha=1.0, device="cuda", torch_dtype=torch.float16):
+        state_dict_ = {}
+        for key in state_dict:
+            if ".lora_B." not in key:
+                continue
+            if not key.startswith(lora_prefix):
+                continue
+            weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
+            weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
+            if len(weight_up.shape) == 4:
+                weight_up = weight_up.squeeze(3).squeeze(2)
+                weight_down = weight_down.squeeze(3).squeeze(2)
+                lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
+            else:
+                lora_weight = alpha * torch.mm(weight_up, weight_down)
+            keys = key.split(".")
+            keys.pop(keys.index("lora_B"))
+            target_name = ".".join(keys)
+            target_name = target_name[len(lora_prefix):]
+            state_dict_[target_name] = lora_weight.cpu()
+        return state_dict_
+
+
     def load(self, model, state_dict_lora, lora_prefix, alpha=1.0, model_resource=None):
         state_dict_model = model.state_dict()
         state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=alpha)
@@ -134,6 +165,23 @@ class SDXLLoRAFromCivitai(LoRAFromCivitai):
         }
 
 
+class FluxLoRAFromCivitai(LoRAFromCivitai):
+    def __init__(self):
+        super().__init__()
+        self.supported_model_classes = [FluxDiT, FluxDiT]
+        self.lora_prefix = ["lora_unet_", "transformer."]
+        self.renamed_lora_prefix = {}
+        self.special_keys = {
+            "single.blocks": "single_blocks",
+            "double.blocks": "double_blocks",
+            "img.attn": "img_attn",
+            "img.mlp": "img_mlp",
+            "img.mod": "img_mod",
+            "txt.attn": "txt_attn",
+            "txt.mlp": "txt_mlp",
+            "txt.mod": "txt_mod",
+        }
+
 
 class GeneralLoRAFromPeft:
     def __init__(self):
@@ -192,4 +240,8 @@ class GeneralLoRAFromPeft:
                 return "", ""
             except:
                 pass
-        return None
\ No newline at end of file
+        return None
+
+
+def get_lora_loaders():
+    return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), GeneralLoRAFromPeft(), FluxLoRAFromCivitai()]
diff --git a/diffsynth/models/model_manager.py b/diffsynth/models/model_manager.py
index f0e0a5d..2b364fb 100644
--- a/diffsynth/models/model_manager.py
+++ b/diffsynth/models/model_manager.py
@@ -10,7 +10,7 @@ from .sd_text_encoder import SDTextEncoder
 from .sd_unet import SDUNet
 from .sd_vae_encoder import SDVAEEncoder
 from .sd_vae_decoder import SDVAEDecoder
-from .lora import SDLoRAFromCivitai, SDXLLoRAFromCivitai, GeneralLoRAFromPeft
+from .lora import get_lora_loaders
 from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
 from .sdxl_unet import SDXLUNet
 
@@ -403,7 +403,7 @@ class ModelManager:
         if len(state_dict) == 0:
             state_dict = load_state_dict(file_path)
         for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
-            for lora in get_lora_loaders():
+            for lora in get_lora_loaders():
                 match_results = lora.match(model, state_dict)
                 if match_results is not None:
                     print(f"    Adding LoRA to {model_name} ({model_path}).")
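
For reference, a minimal standalone sketch of the merge that the new convert_state_dict_AB performs. The tensor names and shapes below are made up for illustration (a rank-16 LoRA on one FLUX single-block linear layer); the update folded into each base weight is alpha * lora_B @ lora_A.

import torch

alpha = 1.0
lora_prefix = "transformer."
state_dict = {
    # Hypothetical peft-style checkpoint entries: lora_A is the down-projection
    # with shape (rank, in_features), lora_B the up-projection (out_features, rank).
    "transformer.single_blocks.0.linear.lora_A.weight": torch.randn(16, 3072),
    "transformer.single_blocks.0.linear.lora_B.weight": torch.randn(3072, 16),
}

state_dict_ = {}
for key in state_dict:
    if ".lora_B." not in key:
        continue
    weight_up = state_dict[key]
    weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")]
    # Full-rank update added onto the base weight: shape (out_features, in_features).
    lora_weight = alpha * torch.mm(weight_up, weight_down)
    # Drop the ".lora_B" segment and the loader prefix to recover the target parameter name.
    target_name = key.replace(".lora_B.", ".")[len(lora_prefix):]
    state_dict_[target_name] = lora_weight

print({k: tuple(v.shape) for k, v in state_dict_.items()})
# {'single_blocks.0.linear.weight': (3072, 3072)}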