Ltx2.3 a2v& retake video and audio (#1346)

* temp commit

* support ltx2 a2v

* support ltx2.3 retake video and audio

* add news

* minor fix
This commit is contained in:
Hong Zhang
2026-03-12 14:16:01 +08:00
committed by GitHub
parent c927062546
commit 4741542523
11 changed files with 453 additions and 30 deletions

View File

@@ -1,12 +1,11 @@
from fractions import Fraction
import torch
import torchaudio
import av
from tqdm import tqdm
from PIL import Image
import numpy as np
from io import BytesIO
from collections.abc import Generator, Iterator
def _resample_audio(
@@ -69,9 +68,9 @@ def _prepare_audio_stream(container: av.container.Container, audio_sample_rate:
audio_stream = container.add_stream("aac")
supported_sample_rates = audio_stream.codec_context.codec.audio_rates
if supported_sample_rates:
best_rate = min(supported_sample_rates, key=lambda x: abs(x - audio_sample_rate))
if best_rate != audio_sample_rate:
print(f"Using closest supported audio sample rate: {best_rate}")
best_rate = min(supported_sample_rates, key=lambda x: abs(x - audio_sample_rate))
if best_rate != audio_sample_rate:
print(f"Using closest supported audio sample rate: {best_rate}")
else:
best_rate = audio_sample_rate
audio_stream.codec_context.sample_rate = best_rate
@@ -79,6 +78,7 @@ def _prepare_audio_stream(container: av.container.Container, audio_sample_rate:
audio_stream.codec_context.time_base = Fraction(1, best_rate)
return audio_stream
def write_video_audio_ltx2(
video: list[Image.Image],
audio: torch.Tensor | None,
@@ -116,7 +116,7 @@ def write_video_audio_ltx2(
stream.width = width
stream.height = height
stream.pix_fmt = "yuv420p"
if audio is not None:
if audio_sample_rate is None:
raise ValueError("audio_sample_rate is required when audio is provided")
@@ -137,6 +137,32 @@ def write_video_audio_ltx2(
container.close()
def resample_waveform(waveform: torch.Tensor, source_rate: int, target_rate: int) -> torch.Tensor:
"""Resample waveform to target sample rate if needed."""
if source_rate == target_rate:
return waveform
resampled = torchaudio.functional.resample(waveform, source_rate, target_rate)
return resampled.to(dtype=waveform.dtype)
def read_audio_with_torchaudio(
path: str,
start_time: float = 0,
duration: float | None = None,
resample: bool = False,
resample_rate: int = 48000,
) -> tuple[torch.Tensor, int]:
waveform, sample_rate = torchaudio.load(path, channels_first=True)
if resample:
waveform = resample_waveform(waveform, sample_rate, resample_rate)
sample_rate = resample_rate
start_frame = int(start_time * sample_rate)
if start_frame > waveform.shape[-1]:
raise ValueError(f"start_time of {start_time} exceeds max duration of {waveform.shape[-1] / sample_rate:.2f}")
end_frame = None if duration is None else int(duration * sample_rate + start_frame)
return waveform[..., start_frame:end_frame], sample_rate
def encode_single_frame(output_file: str, image_array: np.ndarray, crf: float) -> None:
container = av.open(output_file, "w", format="mp4")
try: