final fix for ltx-2

This commit is contained in:
mi804
2026-02-03 10:39:35 +08:00
parent 2a7ac73eb5
commit 25a9e75030
3 changed files with 14 additions and 24 deletions

View File

@@ -5,7 +5,6 @@ from enum import Enum
from typing import Optional, Tuple, Callable
import numpy as np
import torch
from torch._prims_common import DeviceLikeType
from einops import rearrange
from .ltx2_common import rms_norm, Modality
from ..core.attention.attention import attention_forward
@@ -201,7 +200,7 @@ class BatchedPerturbationConfig:
perturbations: list[PerturbationConfig]
def mask(
self, perturbation_type: PerturbationType, block: int, device: DeviceLikeType, dtype: torch.dtype
self, perturbation_type: PerturbationType, block: int, device, dtype: torch.dtype
) -> torch.Tensor:
mask = torch.ones((len(self.perturbations),), device=device, dtype=dtype)
for batch_idx, perturbation in enumerate(self.perturbations):