mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
final fix for ltx-2
This commit is contained in:
@@ -317,7 +317,14 @@ class BasePipeline(torch.nn.Module):
|
||||
if inputs_shared.get("positive_only_lora", None) is not None:
|
||||
self.clear_lora(verbose=0)
|
||||
noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
if isinstance(noise_pred_posi, tuple):
|
||||
noise_pred = tuple(
|
||||
n_nega + cfg_scale * (n_posi - n_nega)
|
||||
for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega)
|
||||
)
|
||||
# Separate handling for dif
|
||||
else:
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
return noise_pred
|
||||
|
||||
@@ -5,7 +5,6 @@ from enum import Enum
|
||||
from typing import Optional, Tuple, Callable
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch._prims_common import DeviceLikeType
|
||||
from einops import rearrange
|
||||
from .ltx2_common import rms_norm, Modality
|
||||
from ..core.attention.attention import attention_forward
|
||||
@@ -201,7 +200,7 @@ class BatchedPerturbationConfig:
|
||||
perturbations: list[PerturbationConfig]
|
||||
|
||||
def mask(
|
||||
self, perturbation_type: PerturbationType, block: int, device: DeviceLikeType, dtype: torch.dtype
|
||||
self, perturbation_type: PerturbationType, block: int, device, dtype: torch.dtype
|
||||
) -> torch.Tensor:
|
||||
mask = torch.ones((len(self.perturbations),), device=device, dtype=dtype)
|
||||
for batch_idx, perturbation in enumerate(self.perturbations):
|
||||
|
||||
@@ -101,26 +101,6 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||
return pipe
|
||||
|
||||
def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
|
||||
if inputs_shared.get("positive_only_lora", None) is not None:
|
||||
self.clear_lora(verbose=0)
|
||||
self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0)
|
||||
noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
|
||||
if cfg_scale != 1.0:
|
||||
if inputs_shared.get("positive_only_lora", None) is not None:
|
||||
self.clear_lora(verbose=0)
|
||||
noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
|
||||
if isinstance(noise_pred_posi, tuple):
|
||||
noise_pred = tuple(
|
||||
n_nega + cfg_scale * (n_posi - n_nega)
|
||||
for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega)
|
||||
)
|
||||
else:
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
return noise_pred
|
||||
|
||||
def stage2_denoise(self, inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd=tqdm):
|
||||
if inputs_shared["use_two_stage_pipeline"]:
|
||||
latent = self.video_vae_encoder.per_channel_statistics.un_normalize(inputs_shared["video_latents"])
|
||||
@@ -257,7 +237,11 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
|
||||
class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(take_over=True)
|
||||
super().__init__(
|
||||
take_over=True,
|
||||
input_params=("use_distilled_pipeline", "use_two_stage_pipeline"),
|
||||
output_params=("use_two_stage_pipeline", "cfg_scale")
|
||||
)
|
||||
|
||||
def process(self, pipe: LTX2AudioVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
|
||||
if inputs_shared.get("use_distilled_pipeline", False):
|
||||
|
||||
Reference in New Issue
Block a user