support ltx2.3 inference

This commit is contained in:
mi804
2026-03-06 16:07:17 +08:00
parent c5aaa1da41
commit 73b13f4c86
17 changed files with 1608 additions and 351 deletions

View File

@@ -555,9 +555,6 @@ class PerChannelStatistics(nn.Module):
super().__init__()
self.register_buffer("std-of-means", torch.empty(latent_channels))
self.register_buffer("mean-of-means", torch.empty(latent_channels))
self.register_buffer("mean-of-stds", torch.empty(latent_channels))
self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(latent_channels))
self.register_buffer("channel", torch.empty(latent_channels))
def un_normalize(self, x: torch.Tensor) -> torch.Tensor:
return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(
@@ -1335,27 +1332,49 @@ class LTX2VideoEncoder(nn.Module):
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
latent_log_var: LogVarianceType = LogVarianceType.UNIFORM,
encoder_spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
encoder_version: str = "ltx-2",
):
super().__init__()
encoder_blocks = [['res_x', {
'num_layers': 4
}], ['compress_space_res', {
'multiplier': 2
}], ['res_x', {
'num_layers': 6
}], ['compress_time_res', {
'multiplier': 2
}], ['res_x', {
'num_layers': 6
}], ['compress_all_res', {
'multiplier': 2
}], ['res_x', {
'num_layers': 2
}], ['compress_all_res', {
'multiplier': 2
}], ['res_x', {
'num_layers': 2
}]]
if encoder_version == "ltx-2":
encoder_blocks = [['res_x', {
'num_layers': 4
}], ['compress_space_res', {
'multiplier': 2
}], ['res_x', {
'num_layers': 6
}], ['compress_time_res', {
'multiplier': 2
}], ['res_x', {
'num_layers': 6
}], ['compress_all_res', {
'multiplier': 2
}], ['res_x', {
'num_layers': 2
}], ['compress_all_res', {
'multiplier': 2
}], ['res_x', {
'num_layers': 2
}]]
else:
encoder_blocks = [["res_x", {
"num_layers": 4
}], ["compress_space_res", {
"multiplier": 2
}], ["res_x", {
"num_layers": 6
}], ["compress_time_res", {
"multiplier": 2
}], ["res_x", {
"num_layers": 4
}], ["compress_all_res", {
"multiplier": 2
}], ["res_x", {
"num_layers": 2
}], ["compress_all_res", {
"multiplier": 1
}], ["res_x", {
"num_layers": 2
}]]
self.patch_size = patch_size
self.norm_layer = norm_layer
self.latent_channels = out_channels
@@ -1435,8 +1454,8 @@ class LTX2VideoEncoder(nn.Module):
# Validate frame count
frames_count = sample.shape[2]
if ((frames_count - 1) % 8) != 0:
raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames "
"(e.g., 1, 9, 17, ...). Please check your input.")
frames_to_crop = (frames_count - 1) % 8
sample = sample[:, :, :-frames_to_crop, ...]
# Initial spatial compression: trade spatial resolution for channel depth
# This reduces H,W by patch_size and increases channels, making convolutions more efficient
@@ -1712,17 +1731,21 @@ def _make_decoder_block(
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time":
out_channels = in_channels // block_config.get("multiplier", 1)
block = DepthToSpaceUpsample(
dims=convolution_dimensions,
in_channels=in_channels,
stride=(2, 1, 1),
out_channels_reduction_factor=block_config.get("multiplier", 1),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space":
out_channels = in_channels // block_config.get("multiplier", 1)
block = DepthToSpaceUpsample(
dims=convolution_dimensions,
in_channels=in_channels,
stride=(1, 2, 2),
out_channels_reduction_factor=block_config.get("multiplier", 1),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all":
@@ -1782,6 +1805,8 @@ class LTX2VideoDecoder(nn.Module):
causal: bool = False,
timestep_conditioning: bool = False,
decoder_spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
decoder_version: str = "ltx-2",
base_channels: int = 128,
):
super().__init__()
@@ -1790,28 +1815,49 @@ class LTX2VideoDecoder(nn.Module):
# video inputs by a factor of 8 in the temporal dimension and 32 in
# each spatial dimension (height and width). This parameter determines how
# many video frames and pixels correspond to a single latent cell.
decoder_blocks = [['res_x', {
'num_layers': 5,
'inject_noise': False
}], ['compress_all', {
'residual': True,
'multiplier': 2
}], ['res_x', {
'num_layers': 5,
'inject_noise': False
}], ['compress_all', {
'residual': True,
'multiplier': 2
}], ['res_x', {
'num_layers': 5,
'inject_noise': False
}], ['compress_all', {
'residual': True,
'multiplier': 2
}], ['res_x', {
'num_layers': 5,
'inject_noise': False
}]]
if decoder_version == "ltx-2":
decoder_blocks = [['res_x', {
'num_layers': 5,
'inject_noise': False
}], ['compress_all', {
'residual': True,
'multiplier': 2
}], ['res_x', {
'num_layers': 5,
'inject_noise': False
}], ['compress_all', {
'residual': True,
'multiplier': 2
}], ['res_x', {
'num_layers': 5,
'inject_noise': False
}], ['compress_all', {
'residual': True,
'multiplier': 2
}], ['res_x', {
'num_layers': 5,
'inject_noise': False
}]]
else:
decoder_blocks = [["res_x", {
"num_layers": 4
}], ["compress_space", {
"multiplier": 2
}], ["res_x", {
"num_layers": 6
}], ["compress_time", {
"multiplier": 2
}], ["res_x", {
"num_layers": 4
}], ["compress_all", {
"multiplier": 1
}], ["res_x", {
"num_layers": 2
}], ["compress_all", {
"multiplier": 2
}], ["res_x", {
"num_layers": 2
}]]
self.video_downscale_factors = SpatioTemporalScaleFactors(
time=8,
width=32,
@@ -1833,13 +1879,14 @@ class LTX2VideoDecoder(nn.Module):
# Compute initial feature_channels by going through blocks in reverse
# This determines the channel width at the start of the decoder
feature_channels = in_channels
for block_name, block_params in list(reversed(decoder_blocks)):
block_config = block_params if isinstance(block_params, dict) else {}
if block_name == "res_x_y":
feature_channels = feature_channels * block_config.get("multiplier", 2)
if block_name == "compress_all":
feature_channels = feature_channels * block_config.get("multiplier", 1)
# feature_channels = in_channels
# for block_name, block_params in list(reversed(decoder_blocks)):
# block_config = block_params if isinstance(block_params, dict) else {}
# if block_name == "res_x_y":
# feature_channels = feature_channels * block_config.get("multiplier", 2)
# if block_name == "compress_all":
# feature_channels = feature_channels * block_config.get("multiplier", 1)
feature_channels = base_channels * 8
self.conv_in = make_conv_nd(
dims=convolution_dimensions,