support loading ltx2.3 stage2lora by statedict (#1348)

* support ltx2.3 stage2lora by statedict

* bug fix

* bug fix
This commit is contained in:
Hong Zhang
2026-03-13 17:19:18 +08:00
committed by GitHub
parent 681df93a85
commit 8c9ddc9274
2 changed files with 5 additions and 6 deletions

View File

@@ -417,7 +417,7 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
def lora_forward(self, x, out):
if self.lora_merger is None:
for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
out = out + x @ lora_A.T @ lora_B.T
out = out + x @ lora_A.T.to(device=x.device, dtype=x.dtype) @ lora_B.T.to(device=x.device, dtype=x.dtype)
else:
lora_output = []
for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):

View File

@@ -138,8 +138,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
# Stage 2
if stage2_lora_config is not None:
stage2_lora_config.download_if_necessary()
pipe.stage2_lora_path = stage2_lora_config.path
pipe.stage2_lora_config = stage2_lora_config
pipe.stage2_lora_strength = stage2_lora_strength
# VRAM Management
@@ -265,8 +264,8 @@ class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit):
if inputs_shared.get("use_two_stage_pipeline", False):
# distill pipeline also uses two-stage, but it does not needs lora
if not inputs_shared.get("use_distilled_pipeline", False):
if not (hasattr(pipe, "stage2_lora_path") and pipe.stage2_lora_path is not None):
raise ValueError("Two-stage pipeline requested, but stage2_lora_path is not set in the pipeline.")
if not (hasattr(pipe, "stage2_lora_config") and pipe.stage2_lora_config is not None):
raise ValueError("Two-stage pipeline requested, but stage2_lora_config is not set in the pipeline.")
if not (hasattr(pipe, "upsampler") and pipe.upsampler is not None):
raise ValueError("Two-stage pipeline requested, but upsampler model is not loaded in the pipeline.")
return inputs_shared, inputs_posi, inputs_nega
@@ -608,7 +607,7 @@ class LTX2AudioVideoUnit_SwitchStage2(PipelineUnit):
if clear_lora_before_state_two:
pipe.clear_lora()
if not use_distilled_pipeline:
pipe.load_lora(pipe.dit, pipe.stage2_lora_path, alpha=pipe.stage2_lora_strength)
pipe.load_lora(pipe.dit, pipe.stage2_lora_config, alpha=pipe.stage2_lora_strength, state_dict=pipe.stage2_lora_config.state_dict)
return stage2_params