support lora fusion

Artiprocher
2025-07-03 18:49:46 +08:00
parent 9cb887015b
commit 8a9dbbd3ba
5 changed files with 175 additions and 54 deletions


@@ -124,13 +124,19 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
         weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
         bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
         out = torch.nn.functional.linear(x, weight, bias)
-        lora_output = []
-        for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
-            out_lora = x @ lora_A.T @ lora_B.T
-            if self.lora_merger is None:
-                out = out + out_lora
-            lora_output.append(out_lora)
-        if self.lora_merger is not None and len(lora_output) > 0:
+        if len(self.lora_A_weights) == 0:
+            # No LoRA
+            return out
+        elif self.lora_merger is None:
+            # Native LoRA inference
+            for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
+                out = out + x @ lora_A.T @ lora_B.T
+        else:
+            # LoRA fusion
+            lora_output = []
+            for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
+                lora_output.append(x @ lora_A.T @ lora_B.T)
             lora_output = torch.stack(lora_output)
             out = self.lora_merger(out, lora_output)
         return out
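
The diff above leaves the merger itself abstract: self.lora_merger only needs to accept the base linear output together with the stacked per-LoRA outputs and return a merged tensor; when it is None, the per-LoRA outputs are simply added to out. The merger used by this commit lives in the other changed files, which are not shown here. As a minimal sketch of the expected interface, assuming a learned per-LoRA gate (the class name HypotheticalLoRAMerger and its gate parameter are illustrative, not part of the commit):

import torch

class HypotheticalLoRAMerger(torch.nn.Module):
    # Illustrative only: the real merger shipped in this commit may weight
    # the LoRA branches differently (e.g. based on the activations themselves).
    def __init__(self, num_loras):
        super().__init__()
        # One learnable gate per LoRA, initialized to 1 so the starting
        # behavior matches the plain sum used when lora_merger is None.
        self.gate = torch.nn.Parameter(torch.ones(num_loras))

    def forward(self, base_output, lora_outputs):
        # lora_outputs: (num_loras, *base_output.shape), produced by torch.stack above
        weights = self.gate.view(-1, *([1] * base_output.dim()))
        return base_output + (weights * lora_outputs).sum(dim=0)

With a layer configured as layer.lora_merger = HypotheticalLoRAMerger(num_loras=len(layer.lora_A_weights)), the fusion branch above stacks the individual LoRA contributions and hands them to the merger in a single call, instead of accumulating them one by one as in the native path.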