Mirror of https://github.com/modelscope/DiffSynth-Studio.git, synced 2026-03-23 09:28:12 +00:00
lora hotload and merge
@@ -107,6 +107,9 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
         self.vram_limit = vram_limit
         self.state = 0
         self.name = name
+        self.lora_A_weights = []
+        self.lora_B_weights = []
+        self.lora_merger = None

     def forward(self, x, *args, **kwargs):
         if self.state == 2:
@@ -120,7 +123,17 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
         else:
             weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
             bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
-        return torch.nn.functional.linear(x, weight, bias)
+        out = torch.nn.functional.linear(x, weight, bias)
+        lora_output = []
+        for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
+            out_lora = x @ lora_A.T @ lora_B.T
+            if self.lora_merger is None:
+                out = out + out_lora
+            lora_output.append(out_lora)
+        if self.lora_merger is not None and len(lora_output) > 0:
+            lora_output = torch.stack(lora_output)
+            out = self.lora_merger(out, lora_output)
+        return out


 def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0, vram_limit=None, name_prefix=""):
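The patch applies LoRA at forward time ("hotload") instead of folding it into the base weight: each hot-loaded pair adds x @ lora_A.T @ lora_B.T on top of the base linear output, while self.weight stays untouched, so adapters can be attached or dropped without reloading the base checkpoint. Because x @ lora_A.T @ lora_B.T equals x @ (lora_B @ lora_A).T, the result matches merging the low-rank update lora_B @ lora_A into the weight. A minimal standalone sketch of that equivalence, with made-up shapes and a plain torch.nn.Linear standing in for AutoWrappedLinear:

import torch

# Hypothetical shapes: rank-8 LoRA on a linear layer with in_features=64, out_features=128.
torch.manual_seed(0)
x = torch.randn(4, 64)
base = torch.nn.Linear(64, 128, bias=False)
lora_A = torch.randn(8, 64)    # (rank, in_features)
lora_B = torch.randn(128, 8)   # (out_features, rank)

# Hotload path, as in the patched forward(): LoRA applied per call, base weight untouched.
out_hotload = base(x) + x @ lora_A.T @ lora_B.T

# Merge path: fold the low-rank update into the weight once, then run a plain linear.
merged_weight = base.weight + lora_B @ lora_A
out_merged = torch.nn.functional.linear(x, merged_weight)

print(torch.allclose(out_hotload, out_merged, atol=1e-4))  # True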
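The lora_merger hook covers the multi-LoRA case: when it is set, the per-LoRA outputs are not summed into the base output directly but collected, stacked along a new leading dimension with torch.stack, and passed to self.lora_merger(out, lora_output) together with the base output. The merger's implementation is not part of this hunk; the sketch below is a hypothetical module matching that call signature, here a simple learnable weighted sum, and is not the repository's actual merger:

import torch

class WeightedLoRAMerger(torch.nn.Module):
    # Hypothetical merger: fuses stacked LoRA outputs into the base output with
    # one learnable scale per LoRA. Only the (base_output, stacked_lora_outputs)
    # interface is taken from the patch above; the fusion rule is an assumption.
    def __init__(self, num_loras: int):
        super().__init__()
        self.scales = torch.nn.Parameter(torch.ones(num_loras))

    def forward(self, base_output, lora_outputs):
        # lora_outputs: (num_loras, *base_output.shape), as produced by torch.stack.
        scales = self.scales.to(lora_outputs.dtype).view(-1, *([1] * base_output.dim()))
        return base_output + (scales * lora_outputs).sum(dim=0)

Assigning such a module to the layer's lora_merger attribute would route every hot-loaded LoRA output through it instead of the plain addition used when lora_merger is None.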