Merge pull request #1343 from mi804/ltx2.3_multiref

Ltx2.3 multiref
This commit is contained in:
Hong Zhang
2026-03-10 17:31:05 +08:00
committed by GitHub
parent f3ebd6f714
commit c927062546
4 changed files with 113 additions and 36 deletions

View File

@@ -147,6 +147,19 @@ class BasePipeline(torch.nn.Module):
video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output] video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output]
return video return video
def output_audio_format_check(self, audio_output):
    """Normalize decoded audio to a stereo [C, T] float tensor.

    Accepts [C, T] or batched [1, C, T] input. Mono audio (C == 1) is
    duplicated onto two channels; stereo (C == 2) passes through.

    Args:
        audio_output: audio tensor shaped [C, T] or [1, C, T], with C in {1, 2}.

    Returns:
        A float32 tensor shaped [2, T].

    Raises:
        ValueError: if the channel dimension is neither 1 nor 2.
    """
    # Drop the leading batch dimension, if present.
    if audio_output.ndim == 3:
        audio_output = audio_output.squeeze(0)
    num_channels = audio_output.shape[0]
    # Promote mono to stereo by duplicating the single channel.
    if num_channels == 1:
        audio_output = audio_output.repeat(2, 1)
    elif num_channels != 2:
        raise ValueError("The output audio should be [C, T] or [1, C, T] or [2, C, T].")
    return audio_output.float()
def load_models_to_device(self, model_names): def load_models_to_device(self, model_names):
if self.vram_management_enabled: if self.vram_management_enabled:

View File

@@ -141,7 +141,6 @@ class LTX2AudioVideoPipeline(BasePipeline):
pipe.vram_management_enabled = pipe.check_vram_management_state() pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe return pipe
def denoise_stage(self, inputs_shared, inputs_posi, inputs_nega, units, cfg_scale=1.0, progress_bar_cmd=tqdm, skip_stage=False): def denoise_stage(self, inputs_shared, inputs_posi, inputs_nega, units, cfg_scale=1.0, progress_bar_cmd=tqdm, skip_stage=False):
if skip_stage: if skip_stage:
return inputs_shared, inputs_posi, inputs_nega return inputs_shared, inputs_posi, inputs_nega
@@ -160,7 +159,6 @@ class LTX2AudioVideoPipeline(BasePipeline):
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio, **inputs_shared) inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio, **inputs_shared)
return inputs_shared, inputs_posi, inputs_nega return inputs_shared, inputs_posi, inputs_nega
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
self, self,
@@ -231,7 +229,8 @@ class LTX2AudioVideoPipeline(BasePipeline):
video = self.vae_output_to_video(video) video = self.vae_output_to_video(video)
self.load_models_to_device(['audio_vae_decoder', 'audio_vocoder']) self.load_models_to_device(['audio_vae_decoder', 'audio_vocoder'])
decoded_audio = self.audio_vae_decoder(inputs_shared["audio_latents"]) decoded_audio = self.audio_vae_decoder(inputs_shared["audio_latents"])
decoded_audio = self.audio_vocoder(decoded_audio).squeeze(0).float() decoded_audio = self.audio_vocoder(decoded_audio)
decoded_audio = self.output_audio_format_check(decoded_audio)
return video, decoded_audio return video, decoded_audio
@@ -394,8 +393,8 @@ class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit):
class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit): class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
def __init__(self): def __init__(self):
super().__init__( super().__init__(
input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "initial_latents"), input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "frame_rate", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "initial_latents"),
output_params=("denoise_mask_video", "input_latents_video"), output_params=("denoise_mask_video", "input_latents_video", "ref_frames_latents", "ref_frames_positions"),
onload_model_names=("video_vae_encoder") onload_model_names=("video_vae_encoder")
) )
@@ -418,20 +417,30 @@ class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
denoise_mask[:, :, idx:idx + input_latent.shape[2], :, :] = 1.0 - input_strength denoise_mask[:, :, idx:idx + input_latent.shape[2], :, :] = 1.0 - input_strength
return initial_latents, denoise_mask return initial_latents, denoise_mask
def process(self, pipe: LTX2AudioVideoPipeline, video_latents, input_images, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels, input_images_indexes=[0], input_images_strength=1.0, initial_latents=None): def process(self, pipe: LTX2AudioVideoPipeline, video_latents, input_images, height, width, frame_rate, tiled, tile_size_in_pixels, tile_overlap_in_pixels, input_images_indexes=[0], input_images_strength=1.0, initial_latents=None):
if input_images is None or len(input_images) == 0: if input_images is None or len(input_images) == 0:
return {} return {}
else: else:
if len(input_images_indexes) != len(set(input_images_indexes)): if len(input_images_indexes) != len(set(input_images_indexes)):
raise ValueError("Input images must have unique indexes.") raise ValueError("Input images must have unique indexes.")
pipe.load_models_to_device(self.onload_model_names) pipe.load_models_to_device(self.onload_model_names)
frame_conditions = {} frame_conditions = {"input_latents_video": None, "denoise_mask_video": None, "ref_frames_latents": [], "ref_frames_positions": []}
for img, index in zip(input_images, input_images_indexes): for img, index in zip(input_images, input_images_indexes):
latents = self.get_image_latent(pipe, img, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels) latents = self.get_image_latent(pipe, img, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels)
# first_frame # first_frame by replacing latents
if index == 0: if index == 0:
input_latents_video, denoise_mask_video = self.apply_input_images_to_latents(video_latents, [latents], [0], input_images_strength, initial_latents) input_latents_video, denoise_mask_video = self.apply_input_images_to_latents(video_latents, [latents], [0], input_images_strength, initial_latents)
frame_conditions.update({"input_latents_video": input_latents_video, "denoise_mask_video": denoise_mask_video}) frame_conditions.update({"input_latents_video": input_latents_video, "denoise_mask_video": denoise_mask_video})
# other frames by adding reference latents
else:
latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=VideoLatentShape.from_torch_shape(latents.shape), device=pipe.device)
video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, False).float()
video_positions[:, 0, ...] = (video_positions[:, 0, ...] + index) / frame_rate
video_positions = video_positions.to(pipe.torch_dtype)
frame_conditions["ref_frames_latents"].append(latents)
frame_conditions["ref_frames_positions"].append(video_positions)
if len(frame_conditions["ref_frames_latents"]) == 0:
frame_conditions.update({"ref_frames_latents": None, "ref_frames_positions": None})
return frame_conditions return frame_conditions
@@ -553,6 +562,9 @@ def model_fn_ltx2(
# First Frame Conditioning # First Frame Conditioning
input_latents_video=None, input_latents_video=None,
denoise_mask_video=None, denoise_mask_video=None,
# Other Frames Conditioning
ref_frames_latents=None,
ref_frames_positions=None,
# In-Context Conditioning # In-Context Conditioning
in_context_video_latents=None, in_context_video_latents=None,
in_context_video_positions=None, in_context_video_positions=None,
@@ -568,17 +580,24 @@ def model_fn_ltx2(
video_latents = video_patchifier.patchify(video_latents) video_latents = video_patchifier.patchify(video_latents)
seq_len_video = video_latents.shape[1] seq_len_video = video_latents.shape[1]
video_timesteps = timestep.repeat(1, video_latents.shape[1], 1) video_timesteps = timestep.repeat(1, video_latents.shape[1], 1)
if denoise_mask_video is not None: # First frame conditioning by replacing the video latents
if input_latents_video is not None:
denoise_mask_video = video_patchifier.patchify(denoise_mask_video) denoise_mask_video = video_patchifier.patchify(denoise_mask_video)
video_latents = video_latents * denoise_mask_video + video_patchifier.patchify(input_latents_video) * (1.0 - denoise_mask_video) video_latents = video_latents * denoise_mask_video + video_patchifier.patchify(input_latents_video) * (1.0 - denoise_mask_video)
video_timesteps = denoise_mask_video * video_timesteps video_timesteps = denoise_mask_video * video_timesteps
if in_context_video_latents is not None: # Conditioning by replacing the video latents
in_context_video_latents = video_patchifier.patchify(in_context_video_latents) total_ref_latents = ref_frames_latents if ref_frames_latents is not None else []
in_context_video_timesteps = timestep.repeat(1, in_context_video_latents.shape[1], 1) * 0. total_ref_positions = ref_frames_positions if ref_frames_positions is not None else []
video_latents = torch.cat([video_latents, in_context_video_latents], dim=1) total_ref_latents += [in_context_video_latents] if in_context_video_latents is not None else []
video_positions = torch.cat([video_positions, in_context_video_positions], dim=2) total_ref_positions += [in_context_video_positions] if in_context_video_positions is not None else []
video_timesteps = torch.cat([video_timesteps, in_context_video_timesteps], dim=1) if len(total_ref_latents) > 0:
for ref_frames_latent, ref_frames_position in zip(total_ref_latents, total_ref_positions):
ref_frames_latent = video_patchifier.patchify(ref_frames_latent)
ref_frames_timestep = timestep.repeat(1, ref_frames_latent.shape[1], 1) * 0.
video_latents = torch.cat([video_latents, ref_frames_latent], dim=1)
video_positions = torch.cat([video_positions, ref_frames_position], dim=2)
video_timesteps = torch.cat([video_timesteps, ref_frames_timestep], dim=1)
if audio_latents is not None: if audio_latents is not None:
_, c_a, _, mel_bins = audio_latents.shape _, c_a, _, mel_bins = audio_latents.shape

View File

@@ -43,13 +43,10 @@ def _write_audio(
) -> None: ) -> None:
if samples.ndim == 1: if samples.ndim == 1:
samples = samples[:, None] samples = samples[:, None]
if samples.shape[0] == 1:
if samples.shape[1] != 2 and samples.shape[0] == 2: samples = samples.repeat(2, 1)
samples = samples.T assert samples.ndim == 2 and samples.shape[0] == 2, "audio samples must be [C, S] or [S], C must be 1 or 2"
samples = samples.T
if samples.shape[1] != 2:
raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.")
# Convert to int16 packed for ingestion; resampler converts to encoder fmt. # Convert to int16 packed for ingestion; resampler converts to encoder fmt.
if samples.dtype != torch.int16: if samples.dtype != torch.int16:
samples = torch.clip(samples, -1.0, 1.0) samples = torch.clip(samples, -1.0, 1.0)
@@ -69,10 +66,17 @@ def _prepare_audio_stream(container: av.container.Container, audio_sample_rate:
""" """
Prepare the audio stream for writing. Prepare the audio stream for writing.
""" """
audio_stream = container.add_stream("aac", rate=audio_sample_rate) audio_stream = container.add_stream("aac")
audio_stream.codec_context.sample_rate = audio_sample_rate supported_sample_rates = audio_stream.codec_context.codec.audio_rates
if supported_sample_rates:
best_rate = min(supported_sample_rates, key=lambda x: abs(x - audio_sample_rate))
if best_rate != audio_sample_rate:
print(f"Using closest supported audio sample rate: {best_rate}")
else:
best_rate = audio_sample_rate
audio_stream.codec_context.sample_rate = best_rate
audio_stream.codec_context.layout = "stereo" audio_stream.codec_context.layout = "stereo"
audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate) audio_stream.codec_context.time_base = Fraction(1, best_rate)
return audio_stream return audio_stream
def write_video_audio_ltx2( def write_video_audio_ltx2(
@@ -80,8 +84,31 @@ def write_video_audio_ltx2(
audio: torch.Tensor | None, audio: torch.Tensor | None,
output_path: str, output_path: str,
fps: int = 24, fps: int = 24,
audio_sample_rate: int | None = 24000, audio_sample_rate: int | None = None,
) -> None: ) -> None:
"""
Writes a sequence of images and an audio tensor to a video file.
This function uses PyAV to encode a list of PIL images into a video stream
and multiplex a PyTorch tensor as the audio stream into the output container.
Args:
video (list[Image.Image]): A list of PIL Image objects representing the video frames.
The length of this list determines the total duration of the video based on the FPS.
audio (torch.Tensor | None): The audio data as a PyTorch tensor.
The shape is typically (channels, samples). If no audio is required, pass None.
channels can be 1 or 2. 1 for mono, 2 for stereo.
output_path (str): The file path (including extension) where the output video will be saved.
fps (int, optional): The frame rate (frames per second) for the video. Defaults to 24.
audio_sample_rate (int | None, optional): The sample rate (e.g., 44100, 48000) for the audio.
If the audio tensor is provided and this is None, the function attempts to infer the rate
based on the audio tensor's length and the video duration.
Raises:
ValueError: If an audio tensor is provided but the sample rate cannot be determined.
"""
duration = len(video) / fps
if audio_sample_rate is None:
audio_sample_rate = int(audio.shape[-1] / duration)
width, height = video[0].size width, height = video[0].size
container = av.open(output_path, mode="w") container = av.open(output_path, mode="w")

View File

@@ -26,7 +26,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-distilled-lora-384.safetensors"), stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-distilled-lora-384.safetensors"),
) )
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”" prompt = "Two cute orange cats, wearing boxing gloves, stand in a boxing ring and fight each other. They are punching each other fast and yelling: 'I will win!'"
negative_prompt = ( negative_prompt = (
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
@@ -41,12 +41,9 @@ negative_prompt = (
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
) )
height, width, num_frames = 512 * 2, 768 * 2, 121 height, width, num_frames = 512 * 2, 768 * 2, 121
dataset_snapshot_download( dataset_snapshot_download("DiffSynth-Studio/example_video_dataset", allow_file_pattern="ltx2/*", local_dir="data/example_video_dataset")
dataset_id="DiffSynth-Studio/examples_in_diffsynth", first_frame = Image.open("data/example_video_dataset/ltx2/first_frame.png").convert("RGB").resize((width, height))
local_dir="./", last_frame = Image.open("data/example_video_dataset/ltx2/last_frame.png").convert("RGB").resize((width, height))
allow_file_pattern=["data/examples/ltx-2/first_frame.jpg"]
)
image = Image.open("data/examples/ltx-2/first_frame.jpg").convert("RGB").resize((width, height))
# first frame # first frame
video, audio = pipe( video, audio = pipe(
prompt=prompt, prompt=prompt,
@@ -57,7 +54,7 @@ video, audio = pipe(
num_frames=num_frames, num_frames=num_frames,
tiled=True, tiled=True,
use_two_stage_pipeline=True, use_two_stage_pipeline=True,
input_images=[image], input_images=[first_frame],
input_images_indexes=[0], input_images_indexes=[0],
input_images_strength=1.0, input_images_strength=1.0,
) )
@@ -66,5 +63,26 @@ write_video_audio_ltx2(
audio=audio, audio=audio,
output_path='ltx2.3_twostage_i2av_first.mp4', output_path='ltx2.3_twostage_i2av_first.mp4',
fps=24, fps=24,
audio_sample_rate=pipe.audio_vocoder.output_sampling_rate, )
pipe.clear_lora()
# This example uses the first and last frames for demonstration. However, you can use any frames by setting input_images and input_images_indexes. Note that each value in input_images_indexes must be within the range of num_frames.
video, audio = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
seed=42,
height=height,
width=width,
num_frames=num_frames,
tiled=True,
use_two_stage_pipeline=True,
input_images=[first_frame, last_frame],
input_images_indexes=[0, num_frames-1],
input_images_strength=1.0,
)
write_video_audio_ltx2(
video=video,
audio=audio,
output_path='ltx2.3_twostage_i2av_first_last.mp4',
fps=24,
) )