From 8c9ddc92749328a7e73b6d5f1d90bcd4259fdb60 Mon Sep 17 00:00:00 2001
From: Hong Zhang <41229682+mi804@users.noreply.github.com>
Date: Fri, 13 Mar 2026 17:19:18 +0800
Subject: [PATCH] support loading ltx2.3 stage2lora by statedict (#1348)

* support ltx2.3 stage2lora by statedict

* bug fix

* bug fix
---
 diffsynth/core/vram/layers.py           | 2 +-
 diffsynth/pipelines/ltx2_audio_video.py | 9 ++++-----
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/diffsynth/core/vram/layers.py b/diffsynth/core/vram/layers.py
index 0f99b0d..7afb360 100644
--- a/diffsynth/core/vram/layers.py
+++ b/diffsynth/core/vram/layers.py
@@ -417,7 +417,7 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
     def lora_forward(self, x, out):
         if self.lora_merger is None:
             for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
-                out = out + x @ lora_A.T @ lora_B.T
+                out = out + x @ lora_A.T.to(device=x.device, dtype=x.dtype) @ lora_B.T.to(device=x.device, dtype=x.dtype)
         else:
             lora_output = []
             for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
diff --git a/diffsynth/pipelines/ltx2_audio_video.py b/diffsynth/pipelines/ltx2_audio_video.py
index 5f78c29..5ef1738 100644
--- a/diffsynth/pipelines/ltx2_audio_video.py
+++ b/diffsynth/pipelines/ltx2_audio_video.py
@@ -138,8 +138,7 @@ class LTX2AudioVideoPipeline(BasePipeline):

         # Stage 2
         if stage2_lora_config is not None:
-            stage2_lora_config.download_if_necessary()
-            pipe.stage2_lora_path = stage2_lora_config.path
+            pipe.stage2_lora_config = stage2_lora_config
             pipe.stage2_lora_strength = stage2_lora_strength

         # VRAM Management
@@ -265,8 +264,8 @@ class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit):
         if inputs_shared.get("use_two_stage_pipeline", False):
             # distill pipeline also uses two-stage, but it does not needs lora
             if not inputs_shared.get("use_distilled_pipeline", False):
-                if not (hasattr(pipe, "stage2_lora_path") and pipe.stage2_lora_path is not None):
-                    raise ValueError("Two-stage pipeline requested, but stage2_lora_path is not set in the pipeline.")
+                if not (hasattr(pipe, "stage2_lora_config") and pipe.stage2_lora_config is not None):
+                    raise ValueError("Two-stage pipeline requested, but stage2_lora_config is not set in the pipeline.")
             if not (hasattr(pipe, "upsampler") and pipe.upsampler is not None):
                 raise ValueError("Two-stage pipeline requested, but upsampler model is not loaded in the pipeline.")
         return inputs_shared, inputs_posi, inputs_nega
@@ -608,7 +607,7 @@ class LTX2AudioVideoUnit_SwitchStage2(PipelineUnit):
         if clear_lora_before_state_two:
             pipe.clear_lora()
         if not use_distilled_pipeline:
-            pipe.load_lora(pipe.dit, pipe.stage2_lora_path, alpha=pipe.stage2_lora_strength)
+            pipe.load_lora(pipe.dit, pipe.stage2_lora_config, alpha=pipe.stage2_lora_strength, state_dict=pipe.stage2_lora_config.state_dict)
         return stage2_params