ltx2.3 bugfix & ic lora (#1336)

* ltx2.3 ic lora inference & train

* temp commit

* fix first frame train-inference consistency

* minor fix
This commit is contained in:
Hong Zhang
2026-03-09 16:33:19 +08:00
committed by GitHub
parent f7d23c6551
commit 7bc5611fb8
12 changed files with 469 additions and 118 deletions

View File

@@ -1336,45 +1336,30 @@ class LTX2VideoEncoder(nn.Module):
):
super().__init__()
if encoder_version == "ltx-2":
encoder_blocks = [['res_x', {
'num_layers': 4
}], ['compress_space_res', {
'multiplier': 2
}], ['res_x', {
'num_layers': 6
}], ['compress_time_res', {
'multiplier': 2
}], ['res_x', {
'num_layers': 6
}], ['compress_all_res', {
'multiplier': 2
}], ['res_x', {
'num_layers': 2
}], ['compress_all_res', {
'multiplier': 2
}], ['res_x', {
'num_layers': 2
}]]
encoder_blocks = [
['res_x', {'num_layers': 4}],
['compress_space_res', {'multiplier': 2}],
['res_x', {'num_layers': 6}],
['compress_time_res', {'multiplier': 2}],
['res_x', {'num_layers': 6}],
['compress_all_res', {'multiplier': 2}],
['res_x', {'num_layers': 2}],
['compress_all_res', {'multiplier': 2}],
['res_x', {'num_layers': 2}]
]
else:
encoder_blocks = [["res_x", {
"num_layers": 4
}], ["compress_space_res", {
"multiplier": 2
}], ["res_x", {
"num_layers": 6
}], ["compress_time_res", {
"multiplier": 2
}], ["res_x", {
"num_layers": 4
}], ["compress_all_res", {
"multiplier": 2
}], ["res_x", {
"num_layers": 2
}], ["compress_all_res", {
"multiplier": 1
}], ["res_x", {
"num_layers": 2
}]]
# LTX-2.3
encoder_blocks = [
["res_x", {"num_layers": 4}],
["compress_space_res", {"multiplier": 2}],
["res_x", {"num_layers": 6}],
["compress_time_res", {"multiplier": 2}],
["res_x", {"num_layers": 4}],
["compress_all_res", {"multiplier": 2}],
["res_x", {"num_layers": 2}],
["compress_all_res", {"multiplier": 1}],
["res_x", {"num_layers": 2}]
]
self.patch_size = patch_size
self.norm_layer = norm_layer
self.latent_channels = out_channels
@@ -1816,48 +1801,28 @@ class LTX2VideoDecoder(nn.Module):
# each spatial dimension (height and width). This parameter determines how
# many video frames and pixels correspond to a single latent cell.
if decoder_version == "ltx-2":
decoder_blocks = [['res_x', {
'num_layers': 5,
'inject_noise': False
}], ['compress_all', {
'residual': True,
'multiplier': 2
}], ['res_x', {
'num_layers': 5,
'inject_noise': False
}], ['compress_all', {
'residual': True,
'multiplier': 2
}], ['res_x', {
'num_layers': 5,
'inject_noise': False
}], ['compress_all', {
'residual': True,
'multiplier': 2
}], ['res_x', {
'num_layers': 5,
'inject_noise': False
}]]
decoder_blocks = [
['res_x', {'num_layers': 5, 'inject_noise': False}],
['compress_all', {'residual': True, 'multiplier': 2}],
['res_x', {'num_layers': 5, 'inject_noise': False}],
['compress_all', {'residual': True, 'multiplier': 2}],
['res_x', {'num_layers': 5, 'inject_noise': False}],
['compress_all', {'residual': True, 'multiplier': 2}],
['res_x', {'num_layers': 5, 'inject_noise': False}]
]
else:
decoder_blocks = [["res_x", {
"num_layers": 4
}], ["compress_space", {
"multiplier": 2
}], ["res_x", {
"num_layers": 6
}], ["compress_time", {
"multiplier": 2
}], ["res_x", {
"num_layers": 4
}], ["compress_all", {
"multiplier": 1
}], ["res_x", {
"num_layers": 2
}], ["compress_all", {
"multiplier": 2
}], ["res_x", {
"num_layers": 2
}]]
# LTX-2.3
decoder_blocks = [
["res_x", {"num_layers": 4}],
["compress_space", {"multiplier": 2}],
["res_x", {"num_layers": 6}],
["compress_time", {"multiplier": 2}],
["res_x", {"num_layers": 4}],
["compress_all", {"multiplier": 1}],
["res_x", {"num_layers": 2}],
["compress_all", {"multiplier": 2}],
["res_x", {"num_layers": 2}]
]
self.video_downscale_factors = SpatioTemporalScaleFactors(
time=8,
width=32,
@@ -1877,15 +1842,8 @@ class LTX2VideoDecoder(nn.Module):
self.decode_noise_scale = 0.025
self.decode_timestep = 0.05
# Compute initial feature_channels by going through blocks in reverse
# This determines the channel width at the start of the decoder
# feature_channels = in_channels
# for block_name, block_params in list(reversed(decoder_blocks)):
# block_config = block_params if isinstance(block_params, dict) else {}
# if block_name == "res_x_y":
# feature_channels = feature_channels * block_config.get("multiplier", 2)
# if block_name == "compress_all":
# feature_channels = feature_channels * block_config.get("multiplier", 1)
# The LTX VAE decoder architecture uses 3 upsampler blocks, each with a multiplier of 2.
# Hence the total feature_channels is multiplied by 8 (2^3).
feature_channels = base_channels * 8
self.conv_in = make_conv_nd(

View File

@@ -108,18 +108,16 @@ class LTX2AudioVideoPipeline(BasePipeline):
if inputs_shared["use_two_stage_pipeline"]:
if inputs_shared.get("clear_lora_before_state_two", False):
self.clear_lora()
latent = self.video_vae_encoder.per_channel_statistics.un_normalize(inputs_shared["video_latents"])
latents = self.video_vae_encoder.per_channel_statistics.un_normalize(inputs_shared["video_latents"])
self.load_models_to_device('upsampler',)
latent = self.upsampler(latent)
latent = self.video_vae_encoder.per_channel_statistics.normalize(latent)
latents = self.upsampler(latents)
latents = self.video_vae_encoder.per_channel_statistics.normalize(latents)
self.scheduler.set_timesteps(special_case="stage2")
inputs_shared.update({k.replace("stage2_", ""): v for k, v in inputs_shared.items() if k.startswith("stage2_")})
denoise_mask_video = 1.0
# input image
if inputs_shared.get("input_images", None) is not None:
latent, denoise_mask_video, initial_latents = self.apply_input_images_to_latents(
latent, inputs_shared.pop("input_latents"), inputs_shared["input_images_indexes"],
inputs_shared["input_images_strength"], latent.clone())
initial_latents, denoise_mask_video = self.apply_input_images_to_latents(latents, initial_latents=latents, **inputs_shared.get("stage2_input_latents_apply_kwargs", {}))
inputs_shared.update({"input_latents_video": initial_latents, "denoise_mask_video": denoise_mask_video})
# remove in-context video control in stage 2
inputs_shared.pop("in_context_video_latents", None)
@@ -127,7 +125,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
# initialize latents for stage 2
inputs_shared["video_latents"] = self.scheduler.sigmas[0] * denoise_mask_video * inputs_shared[
"video_noise"] + (1 - self.scheduler.sigmas[0] * denoise_mask_video) * latent
"video_noise"] + (1 - self.scheduler.sigmas[0] * denoise_mask_video) * latents
inputs_shared["audio_latents"] = self.scheduler.sigmas[0] * inputs_shared["audio_noise"] + (
1 - self.scheduler.sigmas[0]) * inputs_shared["audio_latents"]
@@ -157,7 +155,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
denoising_strength: float = 1.0,
# Image-to-video
input_images: Optional[list[Image.Image]] = None,
input_images_indexes: Optional[list[int]] = None,
input_images_indexes: Optional[list[int]] = [0],
input_images_strength: Optional[float] = 1.0,
# In-Context Video Control
in_context_videos: Optional[list[list[Image.Image]]] = None,
@@ -238,17 +236,16 @@ class LTX2AudioVideoPipeline(BasePipeline):
decoded_audio = self.audio_vocoder(decoded_audio).squeeze(0).float()
return video, decoded_audio
def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength, initial_latents=None):
def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength=1.0, initial_latents=None, denoise_mask_video=None):
b, _, f, h, w = latents.shape
denoise_mask = torch.ones((b, 1, f, h, w), dtype=latents.dtype, device=latents.device)
denoise_mask = torch.ones((b, 1, f, h, w), dtype=latents.dtype, device=latents.device) if denoise_mask_video is None else denoise_mask_video
initial_latents = torch.zeros_like(latents) if initial_latents is None else initial_latents
for idx, input_latent in zip(input_indexes, input_latents):
idx = min(max(1 + (idx-1) // 8, 0), f - 1)
input_latent = input_latent.to(dtype=latents.dtype, device=latents.device)
initial_latents[:, :, idx:idx + input_latent.shape[2], :, :] = input_latent
denoise_mask[:, :, idx:idx + input_latent.shape[2], :, :] = 1.0 - input_strength
latents = latents * denoise_mask + initial_latents * (1.0 - denoise_mask)
return latents, denoise_mask, initial_latents
return initial_latents, denoise_mask
class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit):
@@ -414,7 +411,7 @@ class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "use_two_stage_pipeline"),
output_params=("video_latents", "denoise_mask_video", "input_latents_video", "stage2_input_latents"),
output_params=("denoise_mask_video", "input_latents_video", "stage2_input_latents_apply_kwargs"),
onload_model_names=("video_vae_encoder")
)
@@ -423,29 +420,39 @@ class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
image = torch.Tensor(np.array(image, dtype=np.float32)).to(dtype=pipe.torch_dtype, device=pipe.device)
image = image / 127.5 - 1.0
image = repeat(image, f"H W C -> B C F H W", B=1, F=1)
latent = pipe.video_vae_encoder.encode(image, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(pipe.device)
return latent
latents = pipe.video_vae_encoder.encode(image, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(pipe.device)
return latents
def get_frame_conditions(self, pipe: LTX2AudioVideoPipeline, input_images, input_images_indexes, input_images_strength, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels, video_latents=None, skip_apply=False):
frame_conditions = {}
for img, index in zip(input_images, input_images_indexes):
latents = self.get_image_latent(pipe, img, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels)
# first_frame
if index == 0:
if skip_apply:
frame_conditions = {"input_latents": [latents], "input_indexes": [0], "input_strength": input_images_strength}
else:
input_latents_video, denoise_mask_video = pipe.apply_input_images_to_latents(video_latents, [latents], [0], input_images_strength)
frame_conditions.update({"input_latents_video": input_latents_video, "denoise_mask_video": denoise_mask_video})
return frame_conditions
def process(self, pipe: LTX2AudioVideoPipeline, input_images, input_images_indexes, input_images_strength, video_latents, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels, use_two_stage_pipeline=False):
if input_images is None or len(input_images) == 0:
return {"video_latents": video_latents}
return {}
else:
if len(input_images_indexes) != len(set(input_images_indexes)):
raise ValueError("Input images must have unique indexes.")
pipe.load_models_to_device(self.onload_model_names)
output_dicts = {}
stage1_height = height // 2 if use_two_stage_pipeline else height
stage1_width = width // 2 if use_two_stage_pipeline else width
stage1_latents = [
self.get_image_latent(pipe, img, stage1_height, stage1_width, tiled, tile_size_in_pixels,
tile_overlap_in_pixels) for img in input_images
]
video_latents, denoise_mask_video, initial_latents = pipe.apply_input_images_to_latents(video_latents, stage1_latents, input_images_indexes, input_images_strength)
output_dicts.update({"video_latents": video_latents, "denoise_mask_video": denoise_mask_video, "input_latents_video": initial_latents})
stage_1_frame_conditions = self.get_frame_conditions(pipe, input_images, input_images_indexes, input_images_strength, stage1_height, stage1_width,
tiled, tile_size_in_pixels, tile_overlap_in_pixels, video_latents)
output_dicts.update(stage_1_frame_conditions)
if use_two_stage_pipeline:
stage2_latents = [
self.get_image_latent(pipe, img, height, width, tiled, tile_size_in_pixels,
tile_overlap_in_pixels) for img in input_images
]
output_dicts.update({"stage2_input_latents": stage2_latents})
stage2_input_latents_apply_kwargs = self.get_frame_conditions(pipe, input_images, input_images_indexes, input_images_strength, height, width,
tiled, tile_size_in_pixels, tile_overlap_in_pixels, skip_apply=True)
output_dicts.update({"stage2_input_latents_apply_kwargs": stage2_input_latents_apply_kwargs})
return output_dicts
@@ -508,6 +515,7 @@ def model_fn_ltx2(
audio_positions=None,
audio_patchifier=None,
timestep=None,
input_latents_video=None,
denoise_mask_video=None,
in_context_video_latents=None,
in_context_video_positions=None,
@@ -523,7 +531,9 @@ def model_fn_ltx2(
seq_len_video = video_latents.shape[1]
video_timesteps = timestep.repeat(1, video_latents.shape[1], 1)
if denoise_mask_video is not None:
video_timesteps = video_patchifier.patchify(denoise_mask_video) * video_timesteps
denoise_mask_video = video_patchifier.patchify(denoise_mask_video)
video_latents = video_latents * denoise_mask_video + video_patchifier.patchify(input_latents_video) * (1.0 - denoise_mask_video)
video_timesteps = denoise_mask_video * video_timesteps
if in_context_video_latents is not None:
in_context_video_latents = video_patchifier.patchify(in_context_video_latents)