From 2d23c897c20c60a877b834b689f80e786b5d7495 Mon Sep 17 00:00:00 2001 From: yjy415 <2471352175@qq.com> Date: Tue, 18 Nov 2025 21:29:35 +0800 Subject: [PATCH] add: LoRA Encoder --- diffsynth/models/flux_lora_encoder.py | 71 +------------------ diffsynth/pipelines/flux_image.py | 32 --------- .../model_inference/FLUX.1-dev-LoRA-Fusion.py | 2 - 3 files changed, 2 insertions(+), 103 deletions(-) diff --git a/diffsynth/models/flux_lora_encoder.py b/diffsynth/models/flux_lora_encoder.py index 2be5dbb..13589b0 100644 --- a/diffsynth/models/flux_lora_encoder.py +++ b/diffsynth/models/flux_lora_encoder.py @@ -468,80 +468,13 @@ class LoRAEmbedder(torch.nn.Module): "type": suffix, }) return lora_patterns - - def get_lora_param_pair(self, lora, name, dim, device, dtype): - key_A = name + ".lora_A.default.weight" - key_B = name + ".lora_B.default.weight" - if key_A in lora and key_B in lora: - return lora[key_A], lora[key_B] - if "to_qkv" in name: - base_name = name.replace("to_qkv", "") - suffixes = ["to_q", "to_k", "to_v"] - - found_As = [] - found_Bs = [] - - all_found = True - for suffix in suffixes: - sub_name = base_name + suffix - k_A = sub_name + ".lora_A.default.weight" - k_B = sub_name + ".lora_B.default.weight" - - if k_A in lora and k_B in lora: - found_As.append(lora[k_A]) - found_Bs.append(lora[k_B]) - else: - all_found = False - break - if all_found: - pass - - rank = 16 - for k, v in lora.items(): - if "lora_A" in k: - rank = v.shape[0] - device = v.device - dtype = v.dtype - break - - lora_A = torch.zeros((rank, dim[0]), device=device, dtype=dtype) - lora_B = torch.zeros((dim[1], rank), device=device, dtype=dtype) - - return lora_A, lora_B - def forward(self, lora): lora_emb = [] - device = None - dtype = None - for v in lora.values(): - device = v.device - dtype = v.dtype - break - for lora_pattern in self.lora_patterns: name, layer_type = lora_pattern["name"], lora_pattern["type"] - dim = lora_pattern["dim"] - - lora_A, lora_B = self.get_lora_param_pair(lora, name, dim, device, dtype) - - if "to_qkv" in name and (lora_A is None or (torch.equal(lora_A, torch.zeros_like(lora_A)))): - base_name = name.replace("to_qkv", "") - try: - q_name = base_name + "to_q" - k_name = base_name + "to_k" - v_name = base_name + "to_v" - - real_A = lora[q_name + ".lora_A.default.weight"] - B_q = lora[q_name + ".lora_B.default.weight"] - B_k = lora[k_name + ".lora_B.default.weight"] - B_v = lora[v_name + ".lora_B.default.weight"] - real_B = torch.cat([B_q, B_k, B_v], dim=0) - - lora_A, lora_B = real_A, real_B - except KeyError: - pass - + lora_A = lora[name + ".lora_A.weight"] + lora_B = lora[name + ".lora_B.weight"] lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B) lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out) lora_emb.append(lora_out) diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 9ac0373..dcd9e8e 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -106,38 +106,6 @@ class FluxImagePipeline(BasePipeline): def enable_lora_magic(self): pass - def load_lora(self, model, lora_config, alpha=1, hotload=False): - if isinstance(lora_config, str): - path = lora_config - else: - lora_config.download_if_necessary() - path = lora_config.path - - state_dict = load_state_dict(path, torch_dtype=self.torch_dtype, device="cpu") - loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device) - state_dict = loader.convert_state_dict(state_dict) - loaded_count = 0 - for key in tqdm(state_dict, desc="Applying LoRA"): - if ".lora_A." in key: - layer_name = key.split(".lora_A.")[0] - module = model - try: - parts = layer_name.split(".") - for part in parts: - if part.isdigit(): - module = module[int(part)] - else: - module = getattr(module, part) - except AttributeError: - continue - - w_a = state_dict[key].to(device=module.weight.device, dtype=module.weight.dtype) - w_b_key = key.replace("lora_A", "lora_B") - if w_b_key not in state_dict: continue - w_b = state_dict[w_b_key].to(device=module.weight.device, dtype=module.weight.dtype) - delta_w = torch.mm(w_b, w_a) - module.weight.data += delta_w * alpha - loaded_count += 1 @staticmethod def from_pretrained( torch_dtype: torch.dtype = torch.bfloat16, diff --git a/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py b/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py index 5339230..9d0b189 100644 --- a/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py +++ b/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py @@ -18,12 +18,10 @@ pipe.enable_lora_magic() pipe.load_lora( pipe.dit, ModelConfig(model_id="cancel13/cxsk", origin_file_pattern="30.safetensors"), - hotload=True, ) pipe.load_lora( pipe.dit, ModelConfig(model_id="DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1", origin_file_pattern="merged_lora.safetensors"), - hotload=True, ) image = pipe(prompt="a cat", seed=0) image.save("image_fused.jpg")