mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 10:18:12 +00:00
support loading ltx2.3 stage2lora by statedict (#1348)
* support ltx2.3 stage2lora by statedict * bug fix * bug fix
This commit is contained in:
@@ -417,7 +417,7 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
|
|||||||
def lora_forward(self, x, out):
|
def lora_forward(self, x, out):
|
||||||
if self.lora_merger is None:
|
if self.lora_merger is None:
|
||||||
for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
|
for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
|
||||||
out = out + x @ lora_A.T @ lora_B.T
|
out = out + x @ lora_A.T.to(device=x.device, dtype=x.dtype) @ lora_B.T.to(device=x.device, dtype=x.dtype)
|
||||||
else:
|
else:
|
||||||
lora_output = []
|
lora_output = []
|
||||||
for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
|
for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
|
||||||
|
|||||||
@@ -138,8 +138,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
|||||||
|
|
||||||
# Stage 2
|
# Stage 2
|
||||||
if stage2_lora_config is not None:
|
if stage2_lora_config is not None:
|
||||||
stage2_lora_config.download_if_necessary()
|
pipe.stage2_lora_config = stage2_lora_config
|
||||||
pipe.stage2_lora_path = stage2_lora_config.path
|
|
||||||
pipe.stage2_lora_strength = stage2_lora_strength
|
pipe.stage2_lora_strength = stage2_lora_strength
|
||||||
|
|
||||||
# VRAM Management
|
# VRAM Management
|
||||||
@@ -265,8 +264,8 @@ class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit):
|
|||||||
if inputs_shared.get("use_two_stage_pipeline", False):
|
if inputs_shared.get("use_two_stage_pipeline", False):
|
||||||
# distill pipeline also uses two-stage, but it does not needs lora
|
# distill pipeline also uses two-stage, but it does not needs lora
|
||||||
if not inputs_shared.get("use_distilled_pipeline", False):
|
if not inputs_shared.get("use_distilled_pipeline", False):
|
||||||
if not (hasattr(pipe, "stage2_lora_path") and pipe.stage2_lora_path is not None):
|
if not (hasattr(pipe, "stage2_lora_config") and pipe.stage2_lora_config is not None):
|
||||||
raise ValueError("Two-stage pipeline requested, but stage2_lora_path is not set in the pipeline.")
|
raise ValueError("Two-stage pipeline requested, but stage2_lora_config is not set in the pipeline.")
|
||||||
if not (hasattr(pipe, "upsampler") and pipe.upsampler is not None):
|
if not (hasattr(pipe, "upsampler") and pipe.upsampler is not None):
|
||||||
raise ValueError("Two-stage pipeline requested, but upsampler model is not loaded in the pipeline.")
|
raise ValueError("Two-stage pipeline requested, but upsampler model is not loaded in the pipeline.")
|
||||||
return inputs_shared, inputs_posi, inputs_nega
|
return inputs_shared, inputs_posi, inputs_nega
|
||||||
@@ -608,7 +607,7 @@ class LTX2AudioVideoUnit_SwitchStage2(PipelineUnit):
|
|||||||
if clear_lora_before_state_two:
|
if clear_lora_before_state_two:
|
||||||
pipe.clear_lora()
|
pipe.clear_lora()
|
||||||
if not use_distilled_pipeline:
|
if not use_distilled_pipeline:
|
||||||
pipe.load_lora(pipe.dit, pipe.stage2_lora_path, alpha=pipe.stage2_lora_strength)
|
pipe.load_lora(pipe.dit, pipe.stage2_lora_config, alpha=pipe.stage2_lora_strength, state_dict=pipe.stage2_lora_config.state_dict)
|
||||||
return stage2_params
|
return stage2_params
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user