mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support loading ltx2.3 stage2lora by statedict (#1348)
* support ltx2.3 stage2lora by statedict * bug fix * bug fix
This commit is contained in:
@@ -417,7 +417,7 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
|
||||
def lora_forward(self, x, out):
|
||||
if self.lora_merger is None:
|
||||
for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
|
||||
out = out + x @ lora_A.T @ lora_B.T
|
||||
out = out + x @ lora_A.T.to(device=x.device, dtype=x.dtype) @ lora_B.T.to(device=x.device, dtype=x.dtype)
|
||||
else:
|
||||
lora_output = []
|
||||
for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
|
||||
|
||||
@@ -138,8 +138,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
|
||||
# Stage 2
|
||||
if stage2_lora_config is not None:
|
||||
stage2_lora_config.download_if_necessary()
|
||||
pipe.stage2_lora_path = stage2_lora_config.path
|
||||
pipe.stage2_lora_config = stage2_lora_config
|
||||
pipe.stage2_lora_strength = stage2_lora_strength
|
||||
|
||||
# VRAM Management
|
||||
@@ -265,8 +264,8 @@ class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit):
|
||||
if inputs_shared.get("use_two_stage_pipeline", False):
|
||||
# distill pipeline also uses two-stage, but it does not needs lora
|
||||
if not inputs_shared.get("use_distilled_pipeline", False):
|
||||
if not (hasattr(pipe, "stage2_lora_path") and pipe.stage2_lora_path is not None):
|
||||
raise ValueError("Two-stage pipeline requested, but stage2_lora_path is not set in the pipeline.")
|
||||
if not (hasattr(pipe, "stage2_lora_config") and pipe.stage2_lora_config is not None):
|
||||
raise ValueError("Two-stage pipeline requested, but stage2_lora_config is not set in the pipeline.")
|
||||
if not (hasattr(pipe, "upsampler") and pipe.upsampler is not None):
|
||||
raise ValueError("Two-stage pipeline requested, but upsampler model is not loaded in the pipeline.")
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
@@ -608,7 +607,7 @@ class LTX2AudioVideoUnit_SwitchStage2(PipelineUnit):
|
||||
if clear_lora_before_state_two:
|
||||
pipe.clear_lora()
|
||||
if not use_distilled_pipeline:
|
||||
pipe.load_lora(pipe.dit, pipe.stage2_lora_path, alpha=pipe.stage2_lora_strength)
|
||||
pipe.load_lora(pipe.dit, pipe.stage2_lora_config, alpha=pipe.stage2_lora_strength, state_dict=pipe.stage2_lora_config.state_dict)
|
||||
return stage2_params
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user