Merge pull request #1343 from mi804/ltx2.3_multiref

Ltx2.3 multiref
This commit is contained in:
Hong Zhang
2026-03-10 17:31:05 +08:00
committed by GitHub
parent f3ebd6f714
commit c927062546
4 changed files with 113 additions and 36 deletions

View File

@@ -43,13 +43,10 @@ def _write_audio(
) -> None:
if samples.ndim == 1:
samples = samples[:, None]
if samples.shape[1] != 2 and samples.shape[0] == 2:
samples = samples.T
if samples.shape[1] != 2:
raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.")
if samples.shape[0] == 1:
samples = samples.repeat(2, 1)
assert samples.ndim == 2 and samples.shape[0] == 2, "audio samples must be [C, S] or [S], C must be 1 or 2"
samples = samples.T
# Convert to int16 packed for ingestion; resampler converts to encoder fmt.
if samples.dtype != torch.int16:
samples = torch.clip(samples, -1.0, 1.0)
@@ -69,10 +66,17 @@ def _prepare_audio_stream(container: av.container.Container, audio_sample_rate:
"""
Prepare the audio stream for writing.
"""
audio_stream = container.add_stream("aac", rate=audio_sample_rate)
audio_stream.codec_context.sample_rate = audio_sample_rate
audio_stream = container.add_stream("aac")
supported_sample_rates = audio_stream.codec_context.codec.audio_rates
if supported_sample_rates:
best_rate = min(supported_sample_rates, key=lambda x: abs(x - audio_sample_rate))
if best_rate != audio_sample_rate:
print(f"Using closest supported audio sample rate: {best_rate}")
else:
best_rate = audio_sample_rate
audio_stream.codec_context.sample_rate = best_rate
audio_stream.codec_context.layout = "stereo"
audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate)
audio_stream.codec_context.time_base = Fraction(1, best_rate)
return audio_stream
def write_video_audio_ltx2(
@@ -80,8 +84,31 @@ def write_video_audio_ltx2(
audio: torch.Tensor | None,
output_path: str,
fps: int = 24,
audio_sample_rate: int | None = 24000,
audio_sample_rate: int | None = None,
) -> None:
"""
Writes a sequence of images and an audio tensor to a video file.
This function utilizes PyAV (or a similar multimedia library) to encode a list of PIL images into a video stream
and multiplex a PyTorch tensor as the audio stream into the output container.
Args:
video (list[Image.Image]): A list of PIL Image objects representing the video frames.
The length of this list determines the total duration of the video based on the FPS.
audio (torch.Tensor | None): The audio data as a PyTorch tensor.
The shape is typically (channels, samples). If no audio is required, pass None.
channels can be 1 or 2. 1 for mono, 2 for stereo.
output_path (str): The file path (including extension) where the output video will be saved.
fps (int, optional): The frame rate (frames per second) for the video. Defaults to 24.
audio_sample_rate (int | None, optional): The sample rate (e.g., 44100, 48000) for the audio.
If the audio tensor is provided and this is None, the function attempts to infer the rate
based on the audio tensor's length and the video duration.
Raises:
ValueError: If an audio tensor is provided but the sample rate cannot be determined.
"""
duration = len(video) / fps
if audio_sample_rate is None:
audio_sample_rate = int(audio.shape[-1] / duration)
width, height = video[0].size
container = av.open(output_path, mode="w")