Merge pull request #1343 from mi804/ltx2.3_multiref

Ltx2.3 multiref
This commit is contained in:
Hong Zhang
2026-03-10 17:31:05 +08:00
committed by GitHub
parent f3ebd6f714
commit c927062546
4 changed files with 113 additions and 36 deletions

View File

@@ -141,7 +141,6 @@ class LTX2AudioVideoPipeline(BasePipeline):
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
def denoise_stage(self, inputs_shared, inputs_posi, inputs_nega, units, cfg_scale=1.0, progress_bar_cmd=tqdm, skip_stage=False):
if skip_stage:
return inputs_shared, inputs_posi, inputs_nega
@@ -160,7 +159,6 @@ class LTX2AudioVideoPipeline(BasePipeline):
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio, **inputs_shared)
return inputs_shared, inputs_posi, inputs_nega
@torch.no_grad()
def __call__(
self,
@@ -231,7 +229,8 @@ class LTX2AudioVideoPipeline(BasePipeline):
video = self.vae_output_to_video(video)
self.load_models_to_device(['audio_vae_decoder', 'audio_vocoder'])
decoded_audio = self.audio_vae_decoder(inputs_shared["audio_latents"])
decoded_audio = self.audio_vocoder(decoded_audio).squeeze(0).float()
decoded_audio = self.audio_vocoder(decoded_audio)
decoded_audio = self.output_audio_format_check(decoded_audio)
return video, decoded_audio
@@ -394,8 +393,8 @@ class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit):
class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "initial_latents"),
output_params=("denoise_mask_video", "input_latents_video"),
input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "frame_rate", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "initial_latents"),
output_params=("denoise_mask_video", "input_latents_video", "ref_frames_latents", "ref_frames_positions"),
onload_model_names=("video_vae_encoder")
)
@@ -418,20 +417,30 @@ class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
denoise_mask[:, :, idx:idx + input_latent.shape[2], :, :] = 1.0 - input_strength
return initial_latents, denoise_mask
def process(self, pipe: LTX2AudioVideoPipeline, video_latents, input_images, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels, input_images_indexes=[0], input_images_strength=1.0, initial_latents=None):
def process(self, pipe: LTX2AudioVideoPipeline, video_latents, input_images, height, width, frame_rate, tiled, tile_size_in_pixels, tile_overlap_in_pixels, input_images_indexes=[0], input_images_strength=1.0, initial_latents=None):
if input_images is None or len(input_images) == 0:
return {}
else:
if len(input_images_indexes) != len(set(input_images_indexes)):
raise ValueError("Input images must have unique indexes.")
pipe.load_models_to_device(self.onload_model_names)
frame_conditions = {}
frame_conditions = {"input_latents_video": None, "denoise_mask_video": None, "ref_frames_latents": [], "ref_frames_positions": []}
for img, index in zip(input_images, input_images_indexes):
latents = self.get_image_latent(pipe, img, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels)
# first_frame
# first_frame by replacing latents
if index == 0:
input_latents_video, denoise_mask_video = self.apply_input_images_to_latents(video_latents, [latents], [0], input_images_strength, initial_latents)
frame_conditions.update({"input_latents_video": input_latents_video, "denoise_mask_video": denoise_mask_video})
# other frames by adding reference latents
else:
latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=VideoLatentShape.from_torch_shape(latents.shape), device=pipe.device)
video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, False).float()
video_positions[:, 0, ...] = (video_positions[:, 0, ...] + index) / frame_rate
video_positions = video_positions.to(pipe.torch_dtype)
frame_conditions["ref_frames_latents"].append(latents)
frame_conditions["ref_frames_positions"].append(video_positions)
if len(frame_conditions["ref_frames_latents"]) == 0:
frame_conditions.update({"ref_frames_latents": None, "ref_frames_positions": None})
return frame_conditions
@@ -553,6 +562,9 @@ def model_fn_ltx2(
# First Frame Conditioning
input_latents_video=None,
denoise_mask_video=None,
# Other Frames Conditioning
ref_frames_latents=None,
ref_frames_positions=None,
# In-Context Conditioning
in_context_video_latents=None,
in_context_video_positions=None,
@@ -568,17 +580,24 @@ def model_fn_ltx2(
video_latents = video_patchifier.patchify(video_latents)
seq_len_video = video_latents.shape[1]
video_timesteps = timestep.repeat(1, video_latents.shape[1], 1)
if denoise_mask_video is not None:
# Frist frame conditioning by replacing the video latents
if input_latents_video is not None:
denoise_mask_video = video_patchifier.patchify(denoise_mask_video)
video_latents = video_latents * denoise_mask_video + video_patchifier.patchify(input_latents_video) * (1.0 - denoise_mask_video)
video_timesteps = denoise_mask_video * video_timesteps
if in_context_video_latents is not None:
in_context_video_latents = video_patchifier.patchify(in_context_video_latents)
in_context_video_timesteps = timestep.repeat(1, in_context_video_latents.shape[1], 1) * 0.
video_latents = torch.cat([video_latents, in_context_video_latents], dim=1)
video_positions = torch.cat([video_positions, in_context_video_positions], dim=2)
video_timesteps = torch.cat([video_timesteps, in_context_video_timesteps], dim=1)
# Conditioning by replacing the video latents
total_ref_latents = ref_frames_latents if ref_frames_latents is not None else []
total_ref_positions = ref_frames_positions if ref_frames_positions is not None else []
total_ref_latents += [in_context_video_latents] if in_context_video_latents is not None else []
total_ref_positions += [in_context_video_positions] if in_context_video_positions is not None else []
if len(total_ref_latents) > 0:
for ref_frames_latent, ref_frames_position in zip(total_ref_latents, total_ref_positions):
ref_frames_latent = video_patchifier.patchify(ref_frames_latent)
ref_frames_timestep = timestep.repeat(1, ref_frames_latent.shape[1], 1) * 0.
video_latents = torch.cat([video_latents, ref_frames_latent], dim=1)
video_positions = torch.cat([video_positions, ref_frames_position], dim=2)
video_timesteps = torch.cat([video_timesteps, ref_frames_timestep], dim=1)
if audio_latents is not None:
_, c_a, _, mel_bins = audio_latents.shape