support loading ltx2.3 stage2lora by statedict (#1348)

* support ltx2.3 stage2lora by statedict

* bug fix

* bug fix
This commit is contained in:
Hong Zhang
2026-03-13 17:19:18 +08:00
committed by GitHub
parent 681df93a85
commit 8c9ddc9274
2 changed files with 5 additions and 6 deletions

View File

@@ -417,7 +417,7 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
def lora_forward(self, x, out):
if self.lora_merger is None:
for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
out = out + x @ lora_A.T @ lora_B.T
out = out + x @ lora_A.T.to(device=x.device, dtype=x.dtype) @ lora_B.T.to(device=x.device, dtype=x.dtype)
else:
lora_output = []
for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):