mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:39:43 +00:00
Fix LoRA compatibility issues. (#1320)
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
import torch
|
||||
import torch, warnings
|
||||
|
||||
|
||||
class GeneralLoRALoader:
|
||||
@@ -26,7 +26,11 @@ class GeneralLoRALoader:
|
||||
keys.pop(0)
|
||||
keys.pop(-1)
|
||||
target_name = ".".join(keys)
|
||||
lora_name_dict[target_name] = (key, key.replace(lora_B_key, lora_A_key))
|
||||
# Alpha: Deprecated but retained for compatibility.
|
||||
key_alpha = key.replace(lora_B_key + ".weight", "alpha").replace(lora_B_key + ".default.weight", "alpha")
|
||||
if key_alpha == key or key_alpha not in lora_state_dict:
|
||||
key_alpha = None
|
||||
lora_name_dict[target_name] = (key, key.replace(lora_B_key, lora_A_key), key_alpha)
|
||||
return lora_name_dict
|
||||
|
||||
|
||||
@@ -36,6 +40,10 @@ class GeneralLoRALoader:
|
||||
for name in name_dict:
|
||||
weight_up = state_dict[name_dict[name][0]]
|
||||
weight_down = state_dict[name_dict[name][1]]
|
||||
if name_dict[name][2] is not None:
|
||||
warnings.warn("Alpha detected in the LoRA file. This may be a LoRA model not trained by DiffSynth-Studio. To ensure compatibility, the LoRA weights will be converted to weight * alpha / rank.")
|
||||
alpha = state_dict[name_dict[name][2]] / weight_down.shape[0]
|
||||
weight_down = weight_down * alpha
|
||||
state_dict_[name + f".lora_B{suffix}"] = weight_up
|
||||
state_dict_[name + f".lora_A{suffix}"] = weight_down
|
||||
return state_dict_
|
||||
|
||||
Reference in New Issue
Block a user