ltx2.3 bugfix & ic lora (#1336)

* LTX-2.3 IC-LoRA inference & training

* temp commit

* fix first frame train-inference consistency

* minor fix
This commit is contained in:
Hong Zhang
2026-03-09 16:33:19 +08:00
committed by GitHub
parent f7d23c6551
commit 7bc5611fb8
12 changed files with 469 additions and 118 deletions

View File

@@ -1336,45 +1336,30 @@ class LTX2VideoEncoder(nn.Module):
):
super().__init__()
if encoder_version == "ltx-2":
encoder_blocks = [['res_x', {
'num_layers': 4
}], ['compress_space_res', {
'multiplier': 2
}], ['res_x', {
'num_layers': 6
}], ['compress_time_res', {
'multiplier': 2
}], ['res_x', {
'num_layers': 6
}], ['compress_all_res', {
'multiplier': 2
}], ['res_x', {
'num_layers': 2
}], ['compress_all_res', {
'multiplier': 2
}], ['res_x', {
'num_layers': 2
}]]
encoder_blocks = [
['res_x', {'num_layers': 4}],
['compress_space_res', {'multiplier': 2}],
['res_x', {'num_layers': 6}],
['compress_time_res', {'multiplier': 2}],
['res_x', {'num_layers': 6}],
['compress_all_res', {'multiplier': 2}],
['res_x', {'num_layers': 2}],
['compress_all_res', {'multiplier': 2}],
['res_x', {'num_layers': 2}]
]
else:
encoder_blocks = [["res_x", {
"num_layers": 4
}], ["compress_space_res", {
"multiplier": 2
}], ["res_x", {
"num_layers": 6
}], ["compress_time_res", {
"multiplier": 2
}], ["res_x", {
"num_layers": 4
}], ["compress_all_res", {
"multiplier": 2
}], ["res_x", {
"num_layers": 2
}], ["compress_all_res", {
"multiplier": 1
}], ["res_x", {
"num_layers": 2
}]]
# LTX-2.3
encoder_blocks = [
["res_x", {"num_layers": 4}],
["compress_space_res", {"multiplier": 2}],
["res_x", {"num_layers": 6}],
["compress_time_res", {"multiplier": 2}],
["res_x", {"num_layers": 4}],
["compress_all_res", {"multiplier": 2}],
["res_x", {"num_layers": 2}],
["compress_all_res", {"multiplier": 1}],
["res_x", {"num_layers": 2}]
]
self.patch_size = patch_size
self.norm_layer = norm_layer
self.latent_channels = out_channels
@@ -1816,48 +1801,28 @@ class LTX2VideoDecoder(nn.Module):
# each spatial dimension (height and width). This parameter determines how
# many video frames and pixels correspond to a single latent cell.
if decoder_version == "ltx-2":
decoder_blocks = [['res_x', {
'num_layers': 5,
'inject_noise': False
}], ['compress_all', {
'residual': True,
'multiplier': 2
}], ['res_x', {
'num_layers': 5,
'inject_noise': False
}], ['compress_all', {
'residual': True,
'multiplier': 2
}], ['res_x', {
'num_layers': 5,
'inject_noise': False
}], ['compress_all', {
'residual': True,
'multiplier': 2
}], ['res_x', {
'num_layers': 5,
'inject_noise': False
}]]
decoder_blocks = [
['res_x', {'num_layers': 5, 'inject_noise': False}],
['compress_all', {'residual': True, 'multiplier': 2}],
['res_x', {'num_layers': 5, 'inject_noise': False}],
['compress_all', {'residual': True, 'multiplier': 2}],
['res_x', {'num_layers': 5, 'inject_noise': False}],
['compress_all', {'residual': True, 'multiplier': 2}],
['res_x', {'num_layers': 5, 'inject_noise': False}]
]
else:
decoder_blocks = [["res_x", {
"num_layers": 4
}], ["compress_space", {
"multiplier": 2
}], ["res_x", {
"num_layers": 6
}], ["compress_time", {
"multiplier": 2
}], ["res_x", {
"num_layers": 4
}], ["compress_all", {
"multiplier": 1
}], ["res_x", {
"num_layers": 2
}], ["compress_all", {
"multiplier": 2
}], ["res_x", {
"num_layers": 2
}]]
# LTX-2.3
decoder_blocks = [
["res_x", {"num_layers": 4}],
["compress_space", {"multiplier": 2}],
["res_x", {"num_layers": 6}],
["compress_time", {"multiplier": 2}],
["res_x", {"num_layers": 4}],
["compress_all", {"multiplier": 1}],
["res_x", {"num_layers": 2}],
["compress_all", {"multiplier": 2}],
["res_x", {"num_layers": 2}]
]
self.video_downscale_factors = SpatioTemporalScaleFactors(
time=8,
width=32,
@@ -1877,15 +1842,8 @@ class LTX2VideoDecoder(nn.Module):
self.decode_noise_scale = 0.025
self.decode_timestep = 0.05
# Compute initial feature_channels by going through blocks in reverse
# This determines the channel width at the start of the decoder
# feature_channels = in_channels
# for block_name, block_params in list(reversed(decoder_blocks)):
# block_config = block_params if isinstance(block_params, dict) else {}
# if block_name == "res_x_y":
# feature_channels = feature_channels * block_config.get("multiplier", 2)
# if block_name == "compress_all":
# feature_channels = feature_channels * block_config.get("multiplier", 1)
# The LTX VAE decoder architecture uses 3 upsampler blocks, each with a multiplier equal to 2.
# Hence the total feature_channels is multiplied by 8 (2^3).
feature_channels = base_channels * 8
self.conv_in = make_conv_nd(