lora merger

This commit is contained in:
Artiprocher
2025-04-21 15:48:25 +08:00
parent 04260801a2
commit 44da204dbd
7 changed files with 516 additions and 30 deletions

View File

@@ -71,15 +71,16 @@ class AutoWrappedLinear(torch.nn.Linear):
return torch.nn.functional.linear(x, weight, bias)
class AutoLoRALinear(torch.nn.Linear):
def __init__(self, name='', in_features=1, out_features=2, bias = True, device=None, dtype=None):
def __init__(self, name='', in_features=1, out_features=2, bias=True, device=None, dtype=None):
super().__init__(in_features, out_features, bias, device, dtype)
self.name = name
def forward(self, x, lora_state_dicts=[], lora_alpahs=[1.0,1.0], **kwargs):
def forward(self, x, lora_state_dicts=[], lora_alpahs=[1.0,1.0], lora_patcher=None, **kwargs):
out = torch.nn.functional.linear(x, self.weight, self.bias)
lora_a_name = f'{self.name}.lora_A.weight'
lora_b_name = f'{self.name}.lora_B.weight'
lora_a_name = f'{self.name}.lora_A.default.weight'
lora_b_name = f'{self.name}.lora_B.default.weight'
lora_output = []
for i, lora_state_dict in enumerate(lora_state_dicts):
if lora_state_dict is None:
break
@@ -87,7 +88,10 @@ class AutoLoRALinear(torch.nn.Linear):
lora_A = lora_state_dict[lora_a_name].to(dtype=self.weight.dtype,device=self.weight.device)
lora_B = lora_state_dict[lora_b_name].to(dtype=self.weight.dtype,device=self.weight.device)
out_lora = x @ lora_A.T @ lora_B.T
out = out + out_lora * lora_alpahs[i]
lora_output.append(out_lora)
if len(lora_output) > 0:
lora_output = torch.stack(lora_output)
out = lora_patcher(out, lora_output, self.name)
return out
def enable_auto_lora(model:torch.nn.Module, module_map: dict, name_prefix=''):