from fractions import Fraction

import av
import torch
from PIL import Image
from tqdm import tqdm


def _resample_audio(
    container: av.container.Container,
    audio_stream: av.audio.AudioStream,
    frame_in: av.AudioFrame,
) -> None:
    """Resample ``frame_in`` to the encoder's settings, then encode and mux it.

    The target format, layout and rate are read from the stream's codec
    context, falling back to ``"fltp"``, ``"stereo"`` and the input frame's
    sample rate when the codec context leaves them unset. After all resampled
    frames have been encoded, the audio encoder is flushed.

    Args:
        container: Output container that receives the encoded packets.
        audio_stream: Audio stream (encoder) belonging to ``container``.
        frame_in: Audio frame holding the samples to write.
    """
    cc = audio_stream.codec_context

    # Use the encoder's format/layout/rate as the *target*
    target_format = cc.format or "fltp"  # AAC → usually fltp
    target_layout = cc.layout or "stereo"
    target_rate = cc.sample_rate or frame_in.sample_rate

    audio_resampler = av.audio.resampler.AudioResampler(
        format=target_format,
        layout=target_layout,
        rate=target_rate,
    )

    resampled = list(audio_resampler.resample(frame_in))
    # Passing None asks the resampler to hand back whatever it still buffers,
    # so trailing samples are not dropped. None entries are skipped
    # defensively in case the resampler echoes the flush marker back.
    resampled.extend(
        flushed
        for flushed in audio_resampler.resample(None)
        if flushed is not None
    )

    audio_next_pts = 0
    for rframe in resampled:
        # Frames without a timestamp get a running sample count as pts.
        if rframe.pts is None:
            rframe.pts = audio_next_pts
        audio_next_pts += rframe.samples
        # NOTE(review): this labels the frame with the *input* rate; it only
        # matches the resampler output when target_rate equals the input
        # rate — confirm before using a different encoder rate.
        rframe.sample_rate = frame_in.sample_rate
        container.mux(audio_stream.encode(rframe))

    # flush audio encoder
    for packet in audio_stream.encode():
        container.mux(packet)


def _write_audio(
    container: av.container.Container,
    audio_stream: av.audio.AudioStream,
    samples: torch.Tensor,
    audio_sample_rate: int,
) -> None:
    """Convert a stereo waveform tensor to an s16 frame and write it out.

    Args:
        container: Output container that receives the encoded packets.
        audio_stream: Audio stream (encoder) belonging to ``container``.
        samples: Waveform shaped ``(num_samples, 2)`` or ``(2, num_samples)``.
            Tensors that are not already ``torch.int16`` are clipped to
            ``[-1.0, 1.0]`` and scaled by 32767 before the int16 cast.
        audio_sample_rate: Sample rate assigned to the input frame.

    Raises:
        ValueError: If ``samples`` cannot be arranged with 2 channels in the
            last dimension.
    """
    if samples.ndim == 1:
        samples = samples[:, None]

    # Accept channel-first input by transposing it to (num_samples, 2).
    if samples.shape[1] != 2 and samples.shape[0] == 2:
        samples = samples.T

    if samples.shape[1] != 2:
        raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.")

    # Detach so the later .numpy() call also works for tensors tracking grad.
    samples = samples.detach()

    # Convert to int16 packed for ingestion; resampler converts to encoder fmt.
    if samples.dtype != torch.int16:
        # 32767 is not representable in float16/bfloat16 (it rounds to 32768),
        # so scaling in half precision overflows int16 at positive full scale.
        if samples.dtype in (torch.float16, torch.bfloat16):
            samples = samples.to(torch.float32)
        samples = torch.clip(samples, -1.0, 1.0)
        samples = (samples * 32767.0).to(torch.int16)

    # Flatten (num_samples, 2) row-major into a single interleaved plane.
    interleaved = samples.contiguous().reshape(1, -1).cpu().numpy()
    frame_in = av.AudioFrame.from_ndarray(
        interleaved,
        format="s16",
        layout="stereo",
    )
    frame_in.sample_rate = audio_sample_rate

    _resample_audio(container, audio_stream, frame_in)


def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
    """
    Prepare the audio stream for writing.

    Adds an ``"aac"`` stream to ``container`` and sets its codec context to
    stereo at ``audio_sample_rate`` with a time base of
    ``1 / audio_sample_rate``.

    Args:
        container: Output container to add the stream to.
        audio_sample_rate: Sample rate in Hz for the stream.

    Returns:
        The newly added audio stream.
    """
    audio_stream = container.add_stream("aac", rate=audio_sample_rate)
    audio_stream.codec_context.sample_rate = audio_sample_rate
    audio_stream.codec_context.layout = "stereo"
    audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate)
    return audio_stream


def write_video_audio_ltx2(
    video: list[Image.Image],
    audio: torch.Tensor | None,
    output_path: str,
    fps: int = 24,
    audio_sample_rate: int | None = 24000,
) -> None:
    """Encode frames (and optional stereo audio) into a file at ``output_path``.

    Video is encoded with ``libx264`` in ``yuv420p`` at the size of the first
    frame. When ``audio`` is given, an AAC stream is added and the audio is
    written after all video packets have been muxed.

    Args:
        video: Non-empty list of PIL images; the first one sets width/height.
        audio: Stereo waveform tensor, or ``None`` to write video only.
        output_path: Destination path passed to ``av.open``.
        fps: Frame rate of the video stream (converted with ``int``).
        audio_sample_rate: Audio sample rate; required when ``audio`` is given.

    Raises:
        ValueError: If ``video`` is empty, if ``audio`` is provided without
            ``audio_sample_rate``, or if ``audio`` does not have 2 channels.
    """
    # Validate before opening so a bad call does not create the output file.
    if len(video) == 0:
        raise ValueError("video must contain at least one frame")
    if audio is not None and audio_sample_rate is None:
        raise ValueError("audio_sample_rate is required when audio is provided")

    width, height = video[0].size

    # The context manager closes the container even if encoding fails.
    with av.open(output_path, mode="w") as container:
        stream = container.add_stream("libx264", rate=int(fps))
        stream.width = width
        stream.height = height
        stream.pix_fmt = "yuv420p"

        # The audio stream is added before any packet is muxed.
        audio_stream = None
        if audio is not None:
            audio_stream = _prepare_audio_stream(container, audio_sample_rate)

        for image in tqdm(video, total=len(video)):
            video_frame = av.VideoFrame.from_image(image)
            for packet in stream.encode(video_frame):
                container.mux(packet)

        # Flush encoder
        for packet in stream.encode():
            container.mux(packet)

        if audio is not None:
            _write_audio(container, audio_stream, audio, audio_sample_rate)