ltx2.3 bugfix & ic lora (#1336)

* ltx2.3 ic lora inference&train

* temp commit

* fix first frame train-inference consistency

* minor fix
This commit is contained in:
Hong Zhang
2026-03-09 16:33:19 +08:00
committed by GitHub
parent f7d23c6551
commit 7bc5611fb8
12 changed files with 469 additions and 118 deletions

View File

@@ -108,18 +108,16 @@ class LTX2AudioVideoPipeline(BasePipeline):
if inputs_shared["use_two_stage_pipeline"]:
if inputs_shared.get("clear_lora_before_state_two", False):
self.clear_lora()
latent = self.video_vae_encoder.per_channel_statistics.un_normalize(inputs_shared["video_latents"])
latents = self.video_vae_encoder.per_channel_statistics.un_normalize(inputs_shared["video_latents"])
self.load_models_to_device('upsampler',)
latent = self.upsampler(latent)
latent = self.video_vae_encoder.per_channel_statistics.normalize(latent)
latents = self.upsampler(latents)
latents = self.video_vae_encoder.per_channel_statistics.normalize(latents)
self.scheduler.set_timesteps(special_case="stage2")
inputs_shared.update({k.replace("stage2_", ""): v for k, v in inputs_shared.items() if k.startswith("stage2_")})
denoise_mask_video = 1.0
# input image
if inputs_shared.get("input_images", None) is not None:
latent, denoise_mask_video, initial_latents = self.apply_input_images_to_latents(
latent, inputs_shared.pop("input_latents"), inputs_shared["input_images_indexes"],
inputs_shared["input_images_strength"], latent.clone())
initial_latents, denoise_mask_video = self.apply_input_images_to_latents(latents, initial_latents=latents, **inputs_shared.get("stage2_input_latents_apply_kwargs", {}))
inputs_shared.update({"input_latents_video": initial_latents, "denoise_mask_video": denoise_mask_video})
# remove in-context video control in stage 2
inputs_shared.pop("in_context_video_latents", None)
@@ -127,7 +125,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
# initialize latents for stage 2
inputs_shared["video_latents"] = self.scheduler.sigmas[0] * denoise_mask_video * inputs_shared[
"video_noise"] + (1 - self.scheduler.sigmas[0] * denoise_mask_video) * latent
"video_noise"] + (1 - self.scheduler.sigmas[0] * denoise_mask_video) * latents
inputs_shared["audio_latents"] = self.scheduler.sigmas[0] * inputs_shared["audio_noise"] + (
1 - self.scheduler.sigmas[0]) * inputs_shared["audio_latents"]
@@ -157,7 +155,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
denoising_strength: float = 1.0,
# Image-to-video
input_images: Optional[list[Image.Image]] = None,
input_images_indexes: Optional[list[int]] = None,
input_images_indexes: Optional[list[int]] = [0],
input_images_strength: Optional[float] = 1.0,
# In-Context Video Control
in_context_videos: Optional[list[list[Image.Image]]] = None,
@@ -238,17 +236,16 @@ class LTX2AudioVideoPipeline(BasePipeline):
decoded_audio = self.audio_vocoder(decoded_audio).squeeze(0).float()
return video, decoded_audio
def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength, initial_latents=None):
def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength=1.0, initial_latents=None, denoise_mask_video=None):
b, _, f, h, w = latents.shape
denoise_mask = torch.ones((b, 1, f, h, w), dtype=latents.dtype, device=latents.device)
denoise_mask = torch.ones((b, 1, f, h, w), dtype=latents.dtype, device=latents.device) if denoise_mask_video is None else denoise_mask_video
initial_latents = torch.zeros_like(latents) if initial_latents is None else initial_latents
for idx, input_latent in zip(input_indexes, input_latents):
idx = min(max(1 + (idx-1) // 8, 0), f - 1)
input_latent = input_latent.to(dtype=latents.dtype, device=latents.device)
initial_latents[:, :, idx:idx + input_latent.shape[2], :, :] = input_latent
denoise_mask[:, :, idx:idx + input_latent.shape[2], :, :] = 1.0 - input_strength
latents = latents * denoise_mask + initial_latents * (1.0 - denoise_mask)
return latents, denoise_mask, initial_latents
return initial_latents, denoise_mask
class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit):
@@ -414,7 +411,7 @@ class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "use_two_stage_pipeline"),
output_params=("video_latents", "denoise_mask_video", "input_latents_video", "stage2_input_latents"),
output_params=("denoise_mask_video", "input_latents_video", "stage2_input_latents_apply_kwargs"),
onload_model_names=("video_vae_encoder")
)
@@ -423,29 +420,39 @@ class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
image = torch.Tensor(np.array(image, dtype=np.float32)).to(dtype=pipe.torch_dtype, device=pipe.device)
image = image / 127.5 - 1.0
image = repeat(image, f"H W C -> B C F H W", B=1, F=1)
latent = pipe.video_vae_encoder.encode(image, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(pipe.device)
return latent
latents = pipe.video_vae_encoder.encode(image, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(pipe.device)
return latents
def get_frame_conditions(self, pipe: LTX2AudioVideoPipeline, input_images, input_images_indexes, input_images_strength, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels, video_latents=None, skip_apply=False):
frame_conditions = {}
for img, index in zip(input_images, input_images_indexes):
latents = self.get_image_latent(pipe, img, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels)
# first_frame
if index == 0:
if skip_apply:
frame_conditions = {"input_latents": [latents], "input_indexes": [0], "input_strength": input_images_strength}
else:
input_latents_video, denoise_mask_video = pipe.apply_input_images_to_latents(video_latents, [latents], [0], input_images_strength)
frame_conditions.update({"input_latents_video": input_latents_video, "denoise_mask_video": denoise_mask_video})
return frame_conditions
def process(self, pipe: LTX2AudioVideoPipeline, input_images, input_images_indexes, input_images_strength, video_latents, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels, use_two_stage_pipeline=False):
if input_images is None or len(input_images) == 0:
return {"video_latents": video_latents}
return {}
else:
if len(input_images_indexes) != len(set(input_images_indexes)):
raise ValueError("Input images must have unique indexes.")
pipe.load_models_to_device(self.onload_model_names)
output_dicts = {}
stage1_height = height // 2 if use_two_stage_pipeline else height
stage1_width = width // 2 if use_two_stage_pipeline else width
stage1_latents = [
self.get_image_latent(pipe, img, stage1_height, stage1_width, tiled, tile_size_in_pixels,
tile_overlap_in_pixels) for img in input_images
]
video_latents, denoise_mask_video, initial_latents = pipe.apply_input_images_to_latents(video_latents, stage1_latents, input_images_indexes, input_images_strength)
output_dicts.update({"video_latents": video_latents, "denoise_mask_video": denoise_mask_video, "input_latents_video": initial_latents})
stage_1_frame_conditions = self.get_frame_conditions(pipe, input_images, input_images_indexes, input_images_strength, stage1_height, stage1_width,
tiled, tile_size_in_pixels, tile_overlap_in_pixels, video_latents)
output_dicts.update(stage_1_frame_conditions)
if use_two_stage_pipeline:
stage2_latents = [
self.get_image_latent(pipe, img, height, width, tiled, tile_size_in_pixels,
tile_overlap_in_pixels) for img in input_images
]
output_dicts.update({"stage2_input_latents": stage2_latents})
stage2_input_latents_apply_kwargs = self.get_frame_conditions(pipe, input_images, input_images_indexes, input_images_strength, height, width,
tiled, tile_size_in_pixels, tile_overlap_in_pixels, skip_apply=True)
output_dicts.update({"stage2_input_latents_apply_kwargs": stage2_input_latents_apply_kwargs})
return output_dicts
@@ -508,6 +515,7 @@ def model_fn_ltx2(
audio_positions=None,
audio_patchifier=None,
timestep=None,
input_latents_video=None,
denoise_mask_video=None,
in_context_video_latents=None,
in_context_video_positions=None,
@@ -523,7 +531,9 @@ def model_fn_ltx2(
seq_len_video = video_latents.shape[1]
video_timesteps = timestep.repeat(1, video_latents.shape[1], 1)
if denoise_mask_video is not None:
video_timesteps = video_patchifier.patchify(denoise_mask_video) * video_timesteps
denoise_mask_video = video_patchifier.patchify(denoise_mask_video)
video_latents = video_latents * denoise_mask_video + video_patchifier.patchify(input_latents_video) * (1.0 - denoise_mask_video)
video_timesteps = denoise_mask_video * video_timesteps
if in_context_video_latents is not None:
in_context_video_latents = video_patchifier.patchify(in_context_video_latents)