mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support ltx2.3 inference
This commit is contained in:
@@ -555,9 +555,6 @@ class PerChannelStatistics(nn.Module):
|
||||
super().__init__()
|
||||
self.register_buffer("std-of-means", torch.empty(latent_channels))
|
||||
self.register_buffer("mean-of-means", torch.empty(latent_channels))
|
||||
self.register_buffer("mean-of-stds", torch.empty(latent_channels))
|
||||
self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(latent_channels))
|
||||
self.register_buffer("channel", torch.empty(latent_channels))
|
||||
|
||||
def un_normalize(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(
|
||||
@@ -1335,27 +1332,49 @@ class LTX2VideoEncoder(nn.Module):
|
||||
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
|
||||
latent_log_var: LogVarianceType = LogVarianceType.UNIFORM,
|
||||
encoder_spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
encoder_version: str = "ltx-2",
|
||||
):
|
||||
super().__init__()
|
||||
encoder_blocks = [['res_x', {
|
||||
'num_layers': 4
|
||||
}], ['compress_space_res', {
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 6
|
||||
}], ['compress_time_res', {
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 6
|
||||
}], ['compress_all_res', {
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 2
|
||||
}], ['compress_all_res', {
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 2
|
||||
}]]
|
||||
if encoder_version == "ltx-2":
|
||||
encoder_blocks = [['res_x', {
|
||||
'num_layers': 4
|
||||
}], ['compress_space_res', {
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 6
|
||||
}], ['compress_time_res', {
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 6
|
||||
}], ['compress_all_res', {
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 2
|
||||
}], ['compress_all_res', {
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 2
|
||||
}]]
|
||||
else:
|
||||
encoder_blocks = [["res_x", {
|
||||
"num_layers": 4
|
||||
}], ["compress_space_res", {
|
||||
"multiplier": 2
|
||||
}], ["res_x", {
|
||||
"num_layers": 6
|
||||
}], ["compress_time_res", {
|
||||
"multiplier": 2
|
||||
}], ["res_x", {
|
||||
"num_layers": 4
|
||||
}], ["compress_all_res", {
|
||||
"multiplier": 2
|
||||
}], ["res_x", {
|
||||
"num_layers": 2
|
||||
}], ["compress_all_res", {
|
||||
"multiplier": 1
|
||||
}], ["res_x", {
|
||||
"num_layers": 2
|
||||
}]]
|
||||
self.patch_size = patch_size
|
||||
self.norm_layer = norm_layer
|
||||
self.latent_channels = out_channels
|
||||
@@ -1435,8 +1454,8 @@ class LTX2VideoEncoder(nn.Module):
|
||||
# Validate frame count
|
||||
frames_count = sample.shape[2]
|
||||
if ((frames_count - 1) % 8) != 0:
|
||||
raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames "
|
||||
"(e.g., 1, 9, 17, ...). Please check your input.")
|
||||
frames_to_crop = (frames_count - 1) % 8
|
||||
sample = sample[:, :, :-frames_to_crop, ...]
|
||||
|
||||
# Initial spatial compression: trade spatial resolution for channel depth
|
||||
# This reduces H,W by patch_size and increases channels, making convolutions more efficient
|
||||
@@ -1712,17 +1731,21 @@ def _make_decoder_block(
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_time":
|
||||
out_channels = in_channels // block_config.get("multiplier", 1)
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
stride=(2, 1, 1),
|
||||
out_channels_reduction_factor=block_config.get("multiplier", 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_space":
|
||||
out_channels = in_channels // block_config.get("multiplier", 1)
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
stride=(1, 2, 2),
|
||||
out_channels_reduction_factor=block_config.get("multiplier", 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_all":
|
||||
@@ -1782,6 +1805,8 @@ class LTX2VideoDecoder(nn.Module):
|
||||
causal: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
decoder_spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
|
||||
decoder_version: str = "ltx-2",
|
||||
base_channels: int = 128,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -1790,28 +1815,49 @@ class LTX2VideoDecoder(nn.Module):
|
||||
# video inputs by a factor of 8 in the temporal dimension and 32 in
|
||||
# each spatial dimension (height and width). This parameter determines how
|
||||
# many video frames and pixels correspond to a single latent cell.
|
||||
decoder_blocks = [['res_x', {
|
||||
'num_layers': 5,
|
||||
'inject_noise': False
|
||||
}], ['compress_all', {
|
||||
'residual': True,
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 5,
|
||||
'inject_noise': False
|
||||
}], ['compress_all', {
|
||||
'residual': True,
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 5,
|
||||
'inject_noise': False
|
||||
}], ['compress_all', {
|
||||
'residual': True,
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 5,
|
||||
'inject_noise': False
|
||||
}]]
|
||||
if decoder_version == "ltx-2":
|
||||
decoder_blocks = [['res_x', {
|
||||
'num_layers': 5,
|
||||
'inject_noise': False
|
||||
}], ['compress_all', {
|
||||
'residual': True,
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 5,
|
||||
'inject_noise': False
|
||||
}], ['compress_all', {
|
||||
'residual': True,
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 5,
|
||||
'inject_noise': False
|
||||
}], ['compress_all', {
|
||||
'residual': True,
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 5,
|
||||
'inject_noise': False
|
||||
}]]
|
||||
else:
|
||||
decoder_blocks = [["res_x", {
|
||||
"num_layers": 4
|
||||
}], ["compress_space", {
|
||||
"multiplier": 2
|
||||
}], ["res_x", {
|
||||
"num_layers": 6
|
||||
}], ["compress_time", {
|
||||
"multiplier": 2
|
||||
}], ["res_x", {
|
||||
"num_layers": 4
|
||||
}], ["compress_all", {
|
||||
"multiplier": 1
|
||||
}], ["res_x", {
|
||||
"num_layers": 2
|
||||
}], ["compress_all", {
|
||||
"multiplier": 2
|
||||
}], ["res_x", {
|
||||
"num_layers": 2
|
||||
}]]
|
||||
self.video_downscale_factors = SpatioTemporalScaleFactors(
|
||||
time=8,
|
||||
width=32,
|
||||
@@ -1833,13 +1879,14 @@ class LTX2VideoDecoder(nn.Module):
|
||||
|
||||
# Compute initial feature_channels by going through blocks in reverse
|
||||
# This determines the channel width at the start of the decoder
|
||||
feature_channels = in_channels
|
||||
for block_name, block_params in list(reversed(decoder_blocks)):
|
||||
block_config = block_params if isinstance(block_params, dict) else {}
|
||||
if block_name == "res_x_y":
|
||||
feature_channels = feature_channels * block_config.get("multiplier", 2)
|
||||
if block_name == "compress_all":
|
||||
feature_channels = feature_channels * block_config.get("multiplier", 1)
|
||||
# feature_channels = in_channels
|
||||
# for block_name, block_params in list(reversed(decoder_blocks)):
|
||||
# block_config = block_params if isinstance(block_params, dict) else {}
|
||||
# if block_name == "res_x_y":
|
||||
# feature_channels = feature_channels * block_config.get("multiplier", 2)
|
||||
# if block_name == "compress_all":
|
||||
# feature_channels = feature_channels * block_config.get("multiplier", 1)
|
||||
feature_channels = base_channels * 8
|
||||
|
||||
self.conv_in = make_conv_nd(
|
||||
dims=convolution_dimensions,
|
||||
|
||||
Reference in New Issue
Block a user