diff --git a/diffsynth/diffusion/base_pipeline.py b/diffsynth/diffusion/base_pipeline.py index d4731fd..bf74365 100644 --- a/diffsynth/diffusion/base_pipeline.py +++ b/diffsynth/diffusion/base_pipeline.py @@ -317,7 +317,14 @@ class BasePipeline(torch.nn.Module): if inputs_shared.get("positive_only_lora", None) is not None: self.clear_lora(verbose=0) noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others) - noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) + if isinstance(noise_pred_posi, tuple): + noise_pred = tuple( + n_nega + cfg_scale * (n_posi - n_nega) + for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega) + ) + # Separate handling for dif + else: + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) else: noise_pred = noise_pred_posi return noise_pred diff --git a/diffsynth/models/ltx2_dit.py b/diffsynth/models/ltx2_dit.py index 5f9f856..2e3c958 100644 --- a/diffsynth/models/ltx2_dit.py +++ b/diffsynth/models/ltx2_dit.py @@ -5,7 +5,6 @@ from enum import Enum from typing import Optional, Tuple, Callable import numpy as np import torch -from torch._prims_common import DeviceLikeType from einops import rearrange from .ltx2_common import rms_norm, Modality from ..core.attention.attention import attention_forward @@ -201,7 +200,7 @@ class BatchedPerturbationConfig: perturbations: list[PerturbationConfig] def mask( - self, perturbation_type: PerturbationType, block: int, device: DeviceLikeType, dtype: torch.dtype + self, perturbation_type: PerturbationType, block: int, device, dtype: torch.dtype ) -> torch.Tensor: mask = torch.ones((len(self.perturbations),), device=device, dtype=dtype) for batch_idx, perturbation in enumerate(self.perturbations): diff --git a/diffsynth/pipelines/ltx2_audio_video.py b/diffsynth/pipelines/ltx2_audio_video.py index 25f6ecd..9ed48aa 100644 --- a/diffsynth/pipelines/ltx2_audio_video.py +++ b/diffsynth/pipelines/ltx2_audio_video.py @@ -101,26 +101,6 @@ class LTX2AudioVideoPipeline(BasePipeline): pipe.vram_management_enabled = pipe.check_vram_management_state() return pipe - def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others): - if inputs_shared.get("positive_only_lora", None) is not None: - self.clear_lora(verbose=0) - self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0) - noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others) - if cfg_scale != 1.0: - if inputs_shared.get("positive_only_lora", None) is not None: - self.clear_lora(verbose=0) - noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others) - if isinstance(noise_pred_posi, tuple): - noise_pred = tuple( - n_nega + cfg_scale * (n_posi - n_nega) - for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega) - ) - else: - noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) - else: - noise_pred = noise_pred_posi - return noise_pred - def stage2_denoise(self, inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd=tqdm): if inputs_shared["use_two_stage_pipeline"]: latent = self.video_vae_encoder.per_channel_statistics.un_normalize(inputs_shared["video_latents"]) @@ -257,7 +237,11 @@ class LTX2AudioVideoPipeline(BasePipeline): class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit): def __init__(self): - super().__init__(take_over=True) + super().__init__( + take_over=True, + input_params=("use_distilled_pipeline", "use_two_stage_pipeline"), + output_params=("use_two_stage_pipeline", "cfg_scale") + ) def process(self, pipe: LTX2AudioVideoPipeline, inputs_shared, inputs_posi, inputs_nega): if inputs_shared.get("use_distilled_pipeline", False):