mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 10:18:12 +00:00
final fix for ltx-2
This commit is contained in:
@@ -317,7 +317,14 @@ class BasePipeline(torch.nn.Module):
|
|||||||
if inputs_shared.get("positive_only_lora", None) is not None:
|
if inputs_shared.get("positive_only_lora", None) is not None:
|
||||||
self.clear_lora(verbose=0)
|
self.clear_lora(verbose=0)
|
||||||
noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
|
noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
|
||||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
if isinstance(noise_pred_posi, tuple):
|
||||||
|
noise_pred = tuple(
|
||||||
|
n_nega + cfg_scale * (n_posi - n_nega)
|
||||||
|
for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega)
|
||||||
|
)
|
||||||
|
# Separate handling for dif
|
||||||
|
else:
|
||||||
|
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||||
else:
|
else:
|
||||||
noise_pred = noise_pred_posi
|
noise_pred = noise_pred_posi
|
||||||
return noise_pred
|
return noise_pred
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from enum import Enum
|
|||||||
from typing import Optional, Tuple, Callable
|
from typing import Optional, Tuple, Callable
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch._prims_common import DeviceLikeType
|
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from .ltx2_common import rms_norm, Modality
|
from .ltx2_common import rms_norm, Modality
|
||||||
from ..core.attention.attention import attention_forward
|
from ..core.attention.attention import attention_forward
|
||||||
@@ -201,7 +200,7 @@ class BatchedPerturbationConfig:
|
|||||||
perturbations: list[PerturbationConfig]
|
perturbations: list[PerturbationConfig]
|
||||||
|
|
||||||
def mask(
|
def mask(
|
||||||
self, perturbation_type: PerturbationType, block: int, device: DeviceLikeType, dtype: torch.dtype
|
self, perturbation_type: PerturbationType, block: int, device, dtype: torch.dtype
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
mask = torch.ones((len(self.perturbations),), device=device, dtype=dtype)
|
mask = torch.ones((len(self.perturbations),), device=device, dtype=dtype)
|
||||||
for batch_idx, perturbation in enumerate(self.perturbations):
|
for batch_idx, perturbation in enumerate(self.perturbations):
|
||||||
|
|||||||
@@ -101,26 +101,6 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
|||||||
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||||
return pipe
|
return pipe
|
||||||
|
|
||||||
def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
|
|
||||||
if inputs_shared.get("positive_only_lora", None) is not None:
|
|
||||||
self.clear_lora(verbose=0)
|
|
||||||
self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0)
|
|
||||||
noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
|
|
||||||
if cfg_scale != 1.0:
|
|
||||||
if inputs_shared.get("positive_only_lora", None) is not None:
|
|
||||||
self.clear_lora(verbose=0)
|
|
||||||
noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
|
|
||||||
if isinstance(noise_pred_posi, tuple):
|
|
||||||
noise_pred = tuple(
|
|
||||||
n_nega + cfg_scale * (n_posi - n_nega)
|
|
||||||
for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
|
||||||
else:
|
|
||||||
noise_pred = noise_pred_posi
|
|
||||||
return noise_pred
|
|
||||||
|
|
||||||
def stage2_denoise(self, inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd=tqdm):
|
def stage2_denoise(self, inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd=tqdm):
|
||||||
if inputs_shared["use_two_stage_pipeline"]:
|
if inputs_shared["use_two_stage_pipeline"]:
|
||||||
latent = self.video_vae_encoder.per_channel_statistics.un_normalize(inputs_shared["video_latents"])
|
latent = self.video_vae_encoder.per_channel_statistics.un_normalize(inputs_shared["video_latents"])
|
||||||
@@ -257,7 +237,11 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
|||||||
|
|
||||||
class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit):
|
class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(take_over=True)
|
super().__init__(
|
||||||
|
take_over=True,
|
||||||
|
input_params=("use_distilled_pipeline", "use_two_stage_pipeline"),
|
||||||
|
output_params=("use_two_stage_pipeline", "cfg_scale")
|
||||||
|
)
|
||||||
|
|
||||||
def process(self, pipe: LTX2AudioVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
|
def process(self, pipe: LTX2AudioVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
|
||||||
if inputs_shared.get("use_distilled_pipeline", False):
|
if inputs_shared.get("use_distilled_pipeline", False):
|
||||||
|
|||||||
Reference in New Issue
Block a user