mirror of https://github.com/modelscope/DiffSynth-Studio.git

add: LoRA Encoder
@@ -468,80 +468,13 @@ class LoRAEmbedder(torch.nn.Module):
                "type": suffix,
            })
        return lora_patterns

    def get_lora_param_pair(self, lora, name, dim, device, dtype):
        # Direct lookup using PEFT-style keys "<name>.lora_A.default.weight".
        key_A = name + ".lora_A.default.weight"
        key_B = name + ".lora_B.default.weight"
        if key_A in lora and key_B in lora:
            return lora[key_A], lora[key_B]

if "to_qkv" in name:
|
||||
base_name = name.replace("to_qkv", "")
|
||||
suffixes = ["to_q", "to_k", "to_v"]
|
||||
|
||||
found_As = []
|
||||
found_Bs = []
|
||||
|
||||
all_found = True
|
||||
for suffix in suffixes:
|
||||
sub_name = base_name + suffix
|
||||
k_A = sub_name + ".lora_A.default.weight"
|
||||
k_B = sub_name + ".lora_B.default.weight"
|
||||
|
||||
if k_A in lora and k_B in lora:
|
||||
found_As.append(lora[k_A])
|
||||
found_Bs.append(lora[k_B])
|
||||
else:
|
||||
all_found = False
|
||||
break
|
||||
if all_found:
|
||||
pass
|
||||
|
||||
        # Nothing matched: fall back to zero matrices so every layer still
        # yields a tensor of the expected shape. Infer the rank (and the
        # actual device/dtype) from any lora_A tensor that is present.
        rank = 16
        for k, v in lora.items():
            if "lora_A" in k:
                rank = v.shape[0]
                device = v.device
                dtype = v.dtype
                break
        lora_A = torch.zeros((rank, dim[0]), device=device, dtype=dtype)
        lora_B = torch.zeros((dim[1], rank), device=device, dtype=dtype)
        return lora_A, lora_B

    def forward(self, lora):
        lora_emb = []
        # Probe device/dtype from the first tensor in the state dict.
        device = None
        dtype = None
        for v in lora.values():
            device = v.device
            dtype = v.dtype
            break

        for lora_pattern in self.lora_patterns:
            name, layer_type = lora_pattern["name"], lora_pattern["type"]
            dim = lora_pattern["dim"]
            lora_A, lora_B = self.get_lora_param_pair(lora, name, dim, device, dtype)

if "to_qkv" in name and (lora_A is None or (torch.equal(lora_A, torch.zeros_like(lora_A)))):
|
||||
base_name = name.replace("to_qkv", "")
|
||||
try:
|
||||
q_name = base_name + "to_q"
|
||||
k_name = base_name + "to_k"
|
||||
v_name = base_name + "to_v"
|
||||
|
||||
real_A = lora[q_name + ".lora_A.default.weight"]
|
||||
B_q = lora[q_name + ".lora_B.default.weight"]
|
||||
B_k = lora[k_name + ".lora_B.default.weight"]
|
||||
B_v = lora[v_name + ".lora_B.default.weight"]
|
||||
real_B = torch.cat([B_q, B_k, B_v], dim=0)
|
||||
|
||||
lora_A, lora_B = real_A, real_B
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
            # Non-PEFT checkpoints may store the pair under the shorter
            # "<name>.lora_A.weight" keys; prefer those only when present,
            # instead of unconditionally overwriting the pair found above.
            if name + ".lora_A.weight" in lora and name + ".lora_B.weight" in lora:
                lora_A = lora[name + ".lora_A.weight"]
                lora_B = lora[name + ".lora_B.weight"]
            lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B)
            lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out)
            lora_emb.append(lora_out)
        return lora_emb
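For reference, a minimal sketch of the key convention this encoder parses, assuming a hypothetical rank-16 LoRA over separate to_q/to_k/to_v projections (the layer names and tensor shapes below are illustrative assumptions, not taken from the repository):

import torch

# Hypothetical PEFT-style LoRA state dict: separate q/k/v projections,
# rank 16, with 64-dim inputs and outputs per projection.
rank, dim_in, dim_out = 16, 64, 64
lora = {}
for proj in ["to_q", "to_k", "to_v"]:
    lora[f"blocks.0.attn.{proj}.lora_A.default.weight"] = torch.randn(rank, dim_in)
    lora[f"blocks.0.attn.{proj}.lora_B.default.weight"] = torch.randn(dim_out, rank)

# Reassemble a fused "to_qkv" pair the way the encoder does: take to_q's
# lora_A and stack the three lora_B matrices along the output dimension.
base = "blocks.0.attn."
fused_A = lora[base + "to_q.lora_A.default.weight"]
fused_B = torch.cat(
    [lora[base + f"{p}.lora_B.default.weight"] for p in ["to_q", "to_k", "to_v"]],
    dim=0,
)
print(fused_A.shape, fused_B.shape)  # torch.Size([16, 64]) torch.Size([192, 16])

Note that reusing to_q's lora_A for the fused layer is exact only when the three projections share a down-projection; otherwise it is an approximation the encoder appears to accept.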