From 3f9e9cad9d6d8e2fb3b0ed1e4e1da518861f18ac Mon Sep 17 00:00:00 2001
From: yjy415 <2471352175@qq.com>
Date: Tue, 18 Nov 2025 20:37:14 +0800
Subject: [PATCH] fix:flux

---
 diffsynth/pipelines/flux_image.py | 100 +++---------------------------
 1 file changed, 10 insertions(+), 90 deletions(-)

diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py
index 78f0cb1..9ac0373 100644
--- a/diffsynth/pipelines/flux_image.py
+++ b/diffsynth/pipelines/flux_image.py
@@ -106,41 +106,7 @@ class FluxImagePipeline(BasePipeline):
     def enable_lora_magic(self):
         pass
 
-    # def load_lora(self, model, lora_config, alpha=1, hotload=False):
-    #     if isinstance(lora_config, str):
-    #         path = lora_config
-    #     else:
-    #         lora_config.download_if_necessary()
-    #         path = lora_config.path
-
-    #     state_dict = load_state_dict(path, torch_dtype=self.torch_dtype, device="cpu")
-    #     loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device)
-    #     state_dict = loader.convert_state_dict(state_dict)
-    #     loaded_count = 0
-    #     for key in tqdm(state_dict, desc="Applying LoRA"):
-    #         if ".lora_A." in key:
-    #             layer_name = key.split(".lora_A.")[0]
-    #             module = model
-    #             try:
-    #                 parts = layer_name.split(".")
-    #                 for part in parts:
-    #                     if part.isdigit():
-    #                         module = module[int(part)]
-    #                     else:
-    #                         module = getattr(module, part)
-    #             except AttributeError:
-    #                 continue
-
-    #             w_a = state_dict[key].to(device=module.weight.device, dtype=module.weight.dtype)
-    #             w_b_key = key.replace("lora_A", "lora_B")
-    #             if w_b_key not in state_dict: continue
-    #             w_b = state_dict[w_b_key].to(device=module.weight.device, dtype=module.weight.dtype)
-    #             delta_w = torch.mm(w_b, w_a)
-    #             module.weight.data += delta_w * alpha
-    #             loaded_count += 1
-
-
-    def load_lora(self, model, lora_config, alpha=1.0, hotload=False):
+    def load_lora(self, model, lora_config, alpha=1, hotload=False):
         if isinstance(lora_config, str):
             path = lora_config
         else:
@@ -150,74 +116,28 @@ class FluxImagePipeline(BasePipeline):
         state_dict = load_state_dict(path, torch_dtype=self.torch_dtype, device="cpu")
         loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device)
         state_dict = loader.convert_state_dict(state_dict)
-
-        print(f"Merging LoRA weights from {path}...")
         loaded_count = 0
-
-        # [Added] Key-name mapping table, handling naming mismatches between the FW2 loader and the DiT model.
-        # Fixes the naming differences commonly seen in the single blocks.
-        key_mapping = {
-            ".linear1.": ".to_qkv_mlp.",  # common difference 1
-            ".linear2.": ".proj_out.",  # common difference 2
-            ".modulation.lin.": ".norm.linear."  # common difference 3
-        }
-
         for key in tqdm(state_dict, desc="Applying LoRA"):
             if ".lora_A." in key:
                 layer_name = key.split(".lora_A.")[0]
-
-                # [Added] Try to apply the key-name fix.
-                target_layer_name = layer_name
-                for src, dst in key_mapping.items():
-                    if src in target_layer_name:
-                        target_layer_name = target_layer_name.replace(src, dst)
-
-                # Look up the layer in the model.
                 module = model
                 try:
-                    parts = target_layer_name.split(".")
+                    parts = layer_name.split(".")
                     for part in parts:
                         if part.isdigit():
                             module = module[int(part)]
                         else:
                             module = getattr(module, part)
                 except AttributeError:
-                    # If the fixed name still cannot be found, fall back to the original name.
-                    try:
-                        module = model
-                        parts = layer_name.split(".")
-                        for part in parts:
-                            if part.isdigit():
-                                module = module[int(part)]
-                            else:
-                                module = getattr(module, part)
-                    except AttributeError:
-                        # Really not found; skip (optionally print a warning).
-                        # print(f"Warning: Could not find layer for {layer_name}")
-                        continue
-
-                # Fetch the LoRA parameters and compute the weight delta.
-                try:
-                    w_a = state_dict[key].to(device=module.weight.device, dtype=module.weight.dtype)
-                    w_b_key = key.replace("lora_A", "lora_B")
-                    if w_b_key not in state_dict: continue
-                    w_b = state_dict[w_b_key].to(device=module.weight.device, dtype=module.weight.dtype)
-
-                    # Check that the shapes match (important: prevents broadcasting errors from masking problems).
-                    # Linear weight: (out, in). B@A: (out, in)
-                    delta_w = torch.mm(w_b, w_a)
-                    if delta_w.shape != module.weight.shape:
-                        # A shape mismatch usually means the QKV fused/split state is inconsistent;
-                        # either skip or try a transpose (conservatively skip here).
-                        continue
-
-                    module.weight.data += delta_w * alpha
-                    loaded_count += 1
-                except Exception as e:
                     continue
-
-        print(f"Applied LoRA to {loaded_count} layers.")
-
+
+                w_a = state_dict[key].to(device=module.weight.device, dtype=module.weight.dtype)
+                w_b_key = key.replace("lora_A", "lora_B")
+                if w_b_key not in state_dict: continue
+                w_b = state_dict[w_b_key].to(device=module.weight.device, dtype=module.weight.dtype)
+                delta_w = torch.mm(w_b, w_a)
+                module.weight.data += delta_w * alpha
+                loaded_count += 1
     @staticmethod
     def from_pretrained(
         torch_dtype: torch.dtype = torch.bfloat16,
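
For reference, the operation the simplified load_lora applies to each matched layer is the standard LoRA weight fold, W += alpha * (B @ A). Below is a minimal standalone sketch of that arithmetic in plain PyTorch; the dimensions and variable names are hypothetical and independent of the pipeline's own modules.

    import torch

    # Hypothetical dimensions, for illustration only.
    alpha, rank, d_in, d_out = 1.0, 16, 64, 128

    layer = torch.nn.Linear(d_in, d_out, bias=False)  # layer.weight shape: (d_out, d_in)
    w_a = torch.randn(rank, d_in)                     # lora_A weight: (rank, d_in)
    w_b = torch.randn(d_out, rank)                    # lora_B weight: (d_out, rank)

    delta_w = torch.mm(w_b, w_a)                      # (d_out, d_in), same shape as layer.weight
    assert delta_w.shape == layer.weight.shape        # the shape check the removed code performed explicitly
    layer.weight.data += delta_w * alpha              # in-place fold, mirroring module.weight.data += delta_w * alpha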