mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 10:18:12 +00:00
fix:flux
This commit is contained in:
@@ -106,41 +106,7 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
def enable_lora_magic(self):
|
def enable_lora_magic(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# def load_lora(self, model, lora_config, alpha=1, hotload=False):
|
def load_lora(self, model, lora_config, alpha=1, hotload=False):
|
||||||
# if isinstance(lora_config, str):
|
|
||||||
# path = lora_config
|
|
||||||
# else:
|
|
||||||
# lora_config.download_if_necessary()
|
|
||||||
# path = lora_config.path
|
|
||||||
|
|
||||||
# state_dict = load_state_dict(path, torch_dtype=self.torch_dtype, device="cpu")
|
|
||||||
# loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device)
|
|
||||||
# state_dict = loader.convert_state_dict(state_dict)
|
|
||||||
# loaded_count = 0
|
|
||||||
# for key in tqdm(state_dict, desc="Applying LoRA"):
|
|
||||||
# if ".lora_A." in key:
|
|
||||||
# layer_name = key.split(".lora_A.")[0]
|
|
||||||
# module = model
|
|
||||||
# try:
|
|
||||||
# parts = layer_name.split(".")
|
|
||||||
# for part in parts:
|
|
||||||
# if part.isdigit():
|
|
||||||
# module = module[int(part)]
|
|
||||||
# else:
|
|
||||||
# module = getattr(module, part)
|
|
||||||
# except AttributeError:
|
|
||||||
# continue
|
|
||||||
|
|
||||||
# w_a = state_dict[key].to(device=module.weight.device, dtype=module.weight.dtype)
|
|
||||||
# w_b_key = key.replace("lora_A", "lora_B")
|
|
||||||
# if w_b_key not in state_dict: continue
|
|
||||||
# w_b = state_dict[w_b_key].to(device=module.weight.device, dtype=module.weight.dtype)
|
|
||||||
# delta_w = torch.mm(w_b, w_a)
|
|
||||||
# module.weight.data += delta_w * alpha
|
|
||||||
# loaded_count += 1
|
|
||||||
|
|
||||||
|
|
||||||
def load_lora(self, model, lora_config, alpha=1.0, hotload=False):
|
|
||||||
if isinstance(lora_config, str):
|
if isinstance(lora_config, str):
|
||||||
path = lora_config
|
path = lora_config
|
||||||
else:
|
else:
|
||||||
@@ -150,74 +116,28 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
state_dict = load_state_dict(path, torch_dtype=self.torch_dtype, device="cpu")
|
state_dict = load_state_dict(path, torch_dtype=self.torch_dtype, device="cpu")
|
||||||
loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device)
|
loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device)
|
||||||
state_dict = loader.convert_state_dict(state_dict)
|
state_dict = loader.convert_state_dict(state_dict)
|
||||||
|
|
||||||
print(f"Merging LoRA weights from {path}...")
|
|
||||||
loaded_count = 0
|
loaded_count = 0
|
||||||
|
|
||||||
# [Added] Key-name mapping table: handles naming mismatches between the FW2 loader and the DiT model
|
|
||||||
# Corrects the naming differences commonly seen in Single Blocks
|
|
||||||
key_mapping = {
|
|
||||||
".linear1.": ".to_qkv_mlp.", # 常见差异点 1
|
|
||||||
".linear2.": ".proj_out.", # 常见差异点 2
|
|
||||||
".modulation.lin.": ".norm.linear." # 常见差异点 3
|
|
||||||
}
|
|
||||||
|
|
||||||
for key in tqdm(state_dict, desc="Applying LoRA"):
|
for key in tqdm(state_dict, desc="Applying LoRA"):
|
||||||
if ".lora_A." in key:
|
if ".lora_A." in key:
|
||||||
layer_name = key.split(".lora_A.")[0]
|
layer_name = key.split(".lora_A.")[0]
|
||||||
|
|
||||||
# [Added] Try applying the key-name corrections
|
|
||||||
target_layer_name = layer_name
|
|
||||||
for src, dst in key_mapping.items():
|
|
||||||
if src in target_layer_name:
|
|
||||||
target_layer_name = target_layer_name.replace(src, dst)
|
|
||||||
|
|
||||||
# Look up the layer in the model
|
|
||||||
module = model
|
module = model
|
||||||
try:
|
try:
|
||||||
parts = target_layer_name.split(".")
|
parts = layer_name.split(".")
|
||||||
for part in parts:
|
for part in parts:
|
||||||
if part.isdigit():
|
if part.isdigit():
|
||||||
module = module[int(part)]
|
module = module[int(part)]
|
||||||
else:
|
else:
|
||||||
module = getattr(module, part)
|
module = getattr(module, part)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
# If the corrected name still cannot be found, fall back to the original name
|
|
||||||
try:
|
|
||||||
module = model
|
|
||||||
parts = layer_name.split(".")
|
|
||||||
for part in parts:
|
|
||||||
if part.isdigit():
|
|
||||||
module = module[int(part)]
|
|
||||||
else:
|
|
||||||
module = getattr(module, part)
|
|
||||||
except AttributeError:
|
|
||||||
# Layer genuinely not found: skip it (optionally print a warning)
|
|
||||||
# print(f"Warning: Could not find layer for {layer_name}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Fetch the LoRA parameters and compute the weight delta
|
|
||||||
try:
|
|
||||||
w_a = state_dict[key].to(device=module.weight.device, dtype=module.weight.dtype)
|
|
||||||
w_b_key = key.replace("lora_A", "lora_B")
|
|
||||||
if w_b_key not in state_dict: continue
|
|
||||||
w_b = state_dict[w_b_key].to(device=module.weight.device, dtype=module.weight.dtype)
|
|
||||||
|
|
||||||
# Check that the shapes match (very important: prevents broadcasting from masking errors)
|
|
||||||
# Linear weight: (out, in). B@A: (out, in)
|
|
||||||
delta_w = torch.mm(w_b, w_a)
|
|
||||||
if delta_w.shape != module.weight.shape:
|
|
||||||
# A shape mismatch usually means the QKV fused/split state is inconsistent
|
|
||||||
# Either skip or try a transpose (case-dependent; here we conservatively skip)
|
|
||||||
continue
|
|
||||||
|
|
||||||
module.weight.data += delta_w * alpha
|
|
||||||
loaded_count += 1
|
|
||||||
except Exception as e:
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
print(f"Applied LoRA to {loaded_count} layers.")
|
w_a = state_dict[key].to(device=module.weight.device, dtype=module.weight.dtype)
|
||||||
|
w_b_key = key.replace("lora_A", "lora_B")
|
||||||
|
if w_b_key not in state_dict: continue
|
||||||
|
w_b = state_dict[w_b_key].to(device=module.weight.device, dtype=module.weight.dtype)
|
||||||
|
delta_w = torch.mm(w_b, w_a)
|
||||||
|
module.weight.data += delta_w * alpha
|
||||||
|
loaded_count += 1
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
torch_dtype: torch.dtype = torch.bfloat16,
|
torch_dtype: torch.dtype = torch.bfloat16,
|
||||||
|
|||||||
Reference in New Issue
Block a user