Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-21 16:18:13 +00:00)
support lora fusion
@@ -124,13 +124,19 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
         weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
         bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
         out = torch.nn.functional.linear(x, weight, bias)
-        lora_output = []
-        for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
-            out_lora = x @ lora_A.T @ lora_B.T
-            if self.lora_merger is None:
-                out = out + out_lora
-            lora_output.append(out_lora)
-        if self.lora_merger is not None and len(lora_output) > 0:
+        if len(self.lora_A_weights) == 0:
+            # No LoRA
+            return out
+        elif self.lora_merger is None:
+            # Native LoRA inference
+            for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
+                out = out + x @ lora_A.T @ lora_B.T
+        else:
+            # LoRA fusion
+            lora_output = []
+            for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
+                lora_output.append(x @ lora_A.T @ lora_B.T)
+            lora_output = torch.stack(lora_output)
+            out = self.lora_merger(out, lora_output)
         return out
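The fusion branch above only fixes the call convention: the merger receives the base linear output together with the per-LoRA outputs stacked along a new leading dimension. As a rough illustration of a module that would fit that convention, here is a minimal sketch; the class name ExampleLoRAMerger, the learned gate parameter, and the softmax weighting are assumptions made for illustration, not the merger shipped in DiffSynth-Studio.

import torch

class ExampleLoRAMerger(torch.nn.Module):
    # Hypothetical merger: weights each LoRA branch with a learned gate and
    # adds the weighted sum to the base linear output. Only the call signature
    # (base_output, stacked_lora_outputs) is taken from the diff above; the
    # gating logic itself is an illustrative assumption.
    def __init__(self, num_loras):
        super().__init__()
        self.gate = torch.nn.Parameter(torch.zeros(num_loras))

    def forward(self, base_output, lora_outputs):
        # lora_outputs has shape (num_loras, *base_output.shape), matching
        # the torch.stack(...) in the fusion branch of the diff.
        weights = torch.softmax(self.gate, dim=0)
        fused = torch.einsum("n,n...->...", weights, lora_outputs)
        return base_output + fused

# Usage sketch: merge two LoRA outputs into a base activation of shape (4, 8).
merger = ExampleLoRAMerger(num_loras=2)
out = torch.randn(4, 8)
lora_out = torch.stack([torch.randn(4, 8), torch.randn(4, 8)])
print(merger(out, lora_out).shape)  # torch.Size([4, 8])

A module like this could be assigned to self.lora_merger on the wrapped linear layer; with lora_merger left as None, the layer falls back to the native additive LoRA path shown in the diff.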