final fix for ltx-2

This commit is contained in:
mi804
2026-02-03 10:39:35 +08:00
parent 2a7ac73eb5
commit 25a9e75030
3 changed files with 14 additions and 24 deletions

View File

@@ -317,7 +317,14 @@ class BasePipeline(torch.nn.Module):
if inputs_shared.get("positive_only_lora", None) is not None:
self.clear_lora(verbose=0)
noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
if isinstance(noise_pred_posi, tuple):
noise_pred = tuple(
n_nega + cfg_scale * (n_posi - n_nega)
for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega)
)
# Separate handling for different output types (tuple of predictions vs. a single prediction)
else:
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
return noise_pred

View File

@@ -5,7 +5,6 @@ from enum import Enum
from typing import Optional, Tuple, Callable
import numpy as np
import torch
from torch._prims_common import DeviceLikeType
from einops import rearrange
from .ltx2_common import rms_norm, Modality
from ..core.attention.attention import attention_forward
@@ -201,7 +200,7 @@ class BatchedPerturbationConfig:
perturbations: list[PerturbationConfig]
def mask(
self, perturbation_type: PerturbationType, block: int, device: DeviceLikeType, dtype: torch.dtype
self, perturbation_type: PerturbationType, block: int, device, dtype: torch.dtype
) -> torch.Tensor:
mask = torch.ones((len(self.perturbations),), device=device, dtype=dtype)
for batch_idx, perturbation in enumerate(self.perturbations):

View File

@@ -101,26 +101,6 @@ class LTX2AudioVideoPipeline(BasePipeline):
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
if inputs_shared.get("positive_only_lora", None) is not None:
self.clear_lora(verbose=0)
self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0)
noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
if cfg_scale != 1.0:
if inputs_shared.get("positive_only_lora", None) is not None:
self.clear_lora(verbose=0)
noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
if isinstance(noise_pred_posi, tuple):
noise_pred = tuple(
n_nega + cfg_scale * (n_posi - n_nega)
for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega)
)
else:
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
return noise_pred
def stage2_denoise(self, inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd=tqdm):
if inputs_shared["use_two_stage_pipeline"]:
latent = self.video_vae_encoder.per_channel_statistics.un_normalize(inputs_shared["video_latents"])
@@ -257,7 +237,11 @@ class LTX2AudioVideoPipeline(BasePipeline):
class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit):
def __init__(self):
super().__init__(take_over=True)
super().__init__(
take_over=True,
input_params=("use_distilled_pipeline", "use_two_stage_pipeline"),
output_params=("use_two_stage_pipeline", "cfg_scale")
)
def process(self, pipe: LTX2AudioVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
if inputs_shared.get("use_distilled_pipeline", False):