LoRA hotload and merge

lzw478614@alibaba-inc.com
2025-07-02 13:32:24 +08:00
parent d9c812818d
commit 9cb887015b
3 changed files with 114 additions and 3 deletions


@@ -107,6 +107,9 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
         self.vram_limit = vram_limit
         self.state = 0
         self.name = name
+        self.lora_A_weights = []
+        self.lora_B_weights = []
+        self.lora_merger = None
 
     def forward(self, x, *args, **kwargs):
         if self.state == 2:
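The first hunk adds per-layer bookkeeping for hot-loaded LoRAs: parallel lists of A and B factor matrices (one pair per loaded LoRA) and an optional merger callable. A minimal sketch of how a caller might hot-load and unload a LoRA pair into such a wrapped layer is below; the helper names and shape checks are assumptions for illustration, not part of this commit:

import torch

def load_lora_into_linear(module, lora_A: torch.Tensor, lora_B: torch.Tensor):
    # Hypothetical helper. forward() computes x @ lora_A.T @ lora_B.T,
    # so lora_A must be (rank, in_features) and lora_B (out_features, rank).
    assert lora_A.shape[0] == lora_B.shape[1]      # ranks must match
    assert lora_A.shape[1] == module.in_features   # input dimension
    assert lora_B.shape[0] == module.out_features  # output dimension
    module.lora_A_weights.append(lora_A)
    module.lora_B_weights.append(lora_B)

def unload_loras(module):
    # Hot-unload: with empty lists, forward() degenerates to the plain
    # linear output, so the original behavior is restored in place.
    module.lora_A_weights.clear()
    module.lora_B_weights.clear()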
@@ -120,7 +123,17 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
         else:
             weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
             bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
-        return torch.nn.functional.linear(x, weight, bias)
+        out = torch.nn.functional.linear(x, weight, bias)
+        lora_output = []
+        for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
+            out_lora = x @ lora_A.T @ lora_B.T
+            if self.lora_merger is None:
+                out = out + out_lora
+            lora_output.append(out_lora)
+        if self.lora_merger is not None and len(lora_output) > 0:
+            lora_output = torch.stack(lora_output)
+            out = self.lora_merger(out, lora_output)
+        return out
 
 
 def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0, vram_limit=None, name_prefix=""):
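In the new forward path, each hot-loaded LoRA contributes x @ lora_A.T @ lora_B.T, i.e. the standard low-rank update B*A*x. With no merger set, the contributions are summed directly into the base output (y = W*x + b + sum_i B_i*A_i*x); otherwise the per-LoRA outputs are stacked along a new leading dimension and passed to self.lora_merger together with the base output. Below is a sketch of one callable satisfying that contract, assuming a simple per-LoRA weighted sum; the commit does not show the actual merger implementation:

import torch

class WeightedLoRAMerger(torch.nn.Module):
    # Hypothetical merger: receives the base linear output and a stacked
    # tensor of per-LoRA outputs, shape (num_loras, ..., out_features),
    # and returns the merged result.
    def __init__(self, num_loras: int):
        super().__init__()
        # One learnable (or manually set) scale per hot-loaded LoRA.
        self.weights = torch.nn.Parameter(torch.ones(num_loras))

    def forward(self, base_out: torch.Tensor, lora_outputs: torch.Tensor) -> torch.Tensor:
        # Broadcast each per-LoRA scale over the remaining dimensions,
        # then sum the scaled LoRA outputs into the base output.
        scales = self.weights.view(-1, *([1] * (lora_outputs.dim() - 1)))
        return base_out + (scales * lora_outputs).sum(dim=0)

Setting module.lora_merger = WeightedLoRAMerger(num_loras=len(module.lora_A_weights)) would route merging through the stacked path; setting it back to None restores the plain summation behavior.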