* support mova inference

* mova media_io

* add unified audio_video api & fix bug of mono audio input for ltx

* support mova train

* mova docs

* fix bug
This commit is contained in:
Hong Zhang
2026-03-13 13:06:07 +08:00
committed by GitHub
parent 4741542523
commit 681df93a85
37 changed files with 3102 additions and 181 deletions

View File

@@ -0,0 +1,108 @@
import torch
import torchaudio
from torchcodec.decoders import AudioDecoder
from torchcodec.encoders import AudioEncoder
def convert_to_mono(audio_tensor: torch.Tensor) -> torch.Tensor:
"""
Convert audio to mono by averaging channels.
Supports [C, T] or [B, C, T]. Output shape: [1, T] or [B, 1, T].
"""
return audio_tensor.mean(dim=-2, keepdim=True)
def convert_to_stereo(audio_tensor: torch.Tensor) -> torch.Tensor:
"""
Convert audio to stereo.
Supports [C, T] or [B, C, T]. Duplicate mono, keep stereo.
"""
if audio_tensor.size(-2) == 1:
return audio_tensor.repeat(1, 2, 1) if audio_tensor.dim() == 3 else audio_tensor.repeat(2, 1)
return audio_tensor
def resample_waveform(waveform: torch.Tensor, source_rate: int, target_rate: int) -> torch.Tensor:
"""Resample waveform to target sample rate if needed."""
if source_rate == target_rate:
return waveform
resampled = torchaudio.functional.resample(waveform, source_rate, target_rate)
return resampled.to(dtype=waveform.dtype)
def read_audio_with_torchcodec(
path: str,
start_time: float = 0,
duration: float | None = None,
) -> tuple[torch.Tensor, int]:
"""
Read audio from file natively using torchcodec, with optional start time and duration.
Args:
path (str): The file path to the audio file.
start_time (float, optional): The start time in seconds to read from. Defaults to 0.
duration (float | None, optional): The duration in seconds to read. If None, reads until the end. Defaults to None.
Returns:
tuple[torch.Tensor, int]: A tuple containing the audio tensor and the sample rate.
The audio tensor shape is [C, T] where C is the number of channels and T is the number of audio frames.
"""
decoder = AudioDecoder(path)
stop_seconds = None if duration is None else start_time + duration
waveform = decoder.get_samples_played_in_range(start_seconds=start_time, stop_seconds=stop_seconds).data
return waveform, decoder.metadata.sample_rate
def read_audio(
path: str,
start_time: float = 0,
duration: float | None = None,
resample: bool = False,
resample_rate: int = 48000,
backend: str = "torchcodec",
) -> tuple[torch.Tensor, int]:
"""
Read audio from file, with optional start time, duration, and resampling.
Args:
path (str): The file path to the audio file.
start_time (float, optional): The start time in seconds to read from. Defaults to 0.
duration (float | None, optional): The duration in seconds to read. If None, reads until the end. Defaults to None.
resample (bool, optional): Whether to resample the audio to a different sample rate. Defaults to False.
resample_rate (int, optional): The target sample rate for resampling if resample is True. Defaults to 48000.
backend (str, optional): The audio backend to use for reading. Defaults to "torchcodec".
Returns:
tuple[torch.Tensor, int]: A tuple containing the audio tensor and the sample rate.
The audio tensor shape is [C, T] where C is the number of channels and T is the number of audio frames.
"""
if backend == "torchcodec":
waveform, sample_rate = read_audio_with_torchcodec(path, start_time, duration)
else:
raise ValueError(f"Unsupported audio backend: {backend}")
if resample:
waveform = resample_waveform(waveform, sample_rate, resample_rate)
sample_rate = resample_rate
return waveform, sample_rate
def save_audio(waveform: torch.Tensor, sample_rate: int, save_path: str, backend: str = "torchcodec"):
"""
Save audio tensor to file.
Args:
waveform (torch.Tensor): The audio tensor to save. Shape can be [C, T] or [B, C, T].
sample_rate (int): The sample rate of the audio.
save_path (str): The file path to save the audio to.
backend (str, optional): The audio backend to use for saving. Defaults to "torchcodec".
"""
if waveform.dim() == 3:
waveform = waveform[0]
if backend == "torchcodec":
encoder = AudioEncoder(waveform, sample_rate=sample_rate)
encoder.to_file(dest=save_path)
else:
raise ValueError(f"Unsupported audio backend: {backend}")

View File

@@ -0,0 +1,134 @@
import av
from fractions import Fraction
import torch
from PIL import Image
from tqdm import tqdm
from .audio import convert_to_stereo
def _resample_audio(
container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
) -> None:
cc = audio_stream.codec_context
# Use the encoder's format/layout/rate as the *target*
target_format = cc.format or "fltp" # AAC → usually fltp
target_layout = cc.layout or "stereo"
target_rate = cc.sample_rate or frame_in.sample_rate
audio_resampler = av.audio.resampler.AudioResampler(
format=target_format,
layout=target_layout,
rate=target_rate,
)
audio_next_pts = 0
for rframe in audio_resampler.resample(frame_in):
if rframe.pts is None:
rframe.pts = audio_next_pts
audio_next_pts += rframe.samples
rframe.sample_rate = frame_in.sample_rate
container.mux(audio_stream.encode(rframe))
# flush audio encoder
for packet in audio_stream.encode():
container.mux(packet)
def _write_audio(
container: av.container.Container, audio_stream: av.audio.AudioStream, samples: torch.Tensor, audio_sample_rate: int
) -> None:
if samples.ndim == 1:
samples = samples.unsqueeze(0)
samples = convert_to_stereo(samples)
assert samples.ndim == 2 and samples.shape[0] == 2, "audio samples must be [C, S] or [S], C must be 1 or 2"
samples = samples.T
# Convert to int16 packed for ingestion; resampler converts to encoder fmt.
if samples.dtype != torch.int16:
samples = torch.clip(samples, -1.0, 1.0)
samples = (samples * 32767.0).to(torch.int16)
frame_in = av.AudioFrame.from_ndarray(
samples.contiguous().reshape(1, -1).cpu().numpy(),
format="s16",
layout="stereo",
)
frame_in.sample_rate = audio_sample_rate
_resample_audio(container, audio_stream, frame_in)
def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
"""
Prepare the audio stream for writing.
"""
audio_stream = container.add_stream("aac")
supported_sample_rates = audio_stream.codec_context.codec.audio_rates
if supported_sample_rates:
best_rate = min(supported_sample_rates, key=lambda x: abs(x - audio_sample_rate))
if best_rate != audio_sample_rate:
print(f"Using closest supported audio sample rate: {best_rate}")
else:
best_rate = audio_sample_rate
audio_stream.codec_context.sample_rate = best_rate
audio_stream.codec_context.layout = "stereo"
audio_stream.codec_context.time_base = Fraction(1, best_rate)
return audio_stream
def write_video_audio(
video: list[Image.Image],
audio: torch.Tensor | None,
output_path: str,
fps: int = 24,
audio_sample_rate: int | None = None,
) -> None:
"""
Writes a sequence of images and an audio tensor to a video file.
This function utilizes PyAV (or a similar multimedia library) to encode a list of PIL images into a video stream
and multiplex a PyTorch tensor as the audio stream into the output container.
Args:
video (list[Image.Image]): A list of PIL Image objects representing the video frames.
The length of this list determines the total duration of the video based on the FPS.
audio (torch.Tensor | None): The audio data as a PyTorch tensor.
The shape is typically (channels, samples). If no audio is required, pass None.
channels can be 1 or 2. 1 for mono, 2 for stereo.
output_path (str): The file path (including extension) where the output video will be saved.
fps (int, optional): The frame rate (frames per second) for the video. Defaults to 24.
audio_sample_rate (int | None, optional): The sample rate (e.g., 44100, 48000) for the audio.
If the audio tensor is provided and this is None, the function attempts to infer the rate
based on the audio tensor's length and the video duration.
Raises:
ValueError: If an audio tensor is provided but the sample rate cannot be determined.
"""
duration = len(video) / fps
if audio_sample_rate is None:
audio_sample_rate = int(audio.shape[-1] / duration)
width, height = video[0].size
container = av.open(output_path, mode="w")
stream = container.add_stream("libx264", rate=int(fps))
stream.width = width
stream.height = height
stream.pix_fmt = "yuv420p"
if audio is not None:
if audio_sample_rate is None:
raise ValueError("audio_sample_rate is required when audio is provided")
audio_stream = _prepare_audio_stream(container, audio_sample_rate)
for frame in tqdm(video, total=len(video)):
frame = av.VideoFrame.from_image(frame)
for packet in stream.encode(frame):
container.mux(packet)
# Flush encoder
for packet in stream.encode():
container.mux(packet)
if audio is not None:
_write_audio(container, audio_stream, audio, audio_sample_rate)
container.close()

View File

@@ -1,166 +1,7 @@
from fractions import Fraction
import torch
import torchaudio
import av
from tqdm import tqdm
from PIL import Image
import numpy as np
from io import BytesIO
def _resample_audio(
container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
) -> None:
cc = audio_stream.codec_context
# Use the encoder's format/layout/rate as the *target*
target_format = cc.format or "fltp" # AAC → usually fltp
target_layout = cc.layout or "stereo"
target_rate = cc.sample_rate or frame_in.sample_rate
audio_resampler = av.audio.resampler.AudioResampler(
format=target_format,
layout=target_layout,
rate=target_rate,
)
audio_next_pts = 0
for rframe in audio_resampler.resample(frame_in):
if rframe.pts is None:
rframe.pts = audio_next_pts
audio_next_pts += rframe.samples
rframe.sample_rate = frame_in.sample_rate
container.mux(audio_stream.encode(rframe))
# flush audio encoder
for packet in audio_stream.encode():
container.mux(packet)
def _write_audio(
container: av.container.Container, audio_stream: av.audio.AudioStream, samples: torch.Tensor, audio_sample_rate: int
) -> None:
if samples.ndim == 1:
samples = samples[:, None]
if samples.shape[0] == 1:
samples = samples.repeat(2, 1)
assert samples.ndim == 2 and samples.shape[0] == 2, "audio samples must be [C, S] or [S], C must be 1 or 2"
samples = samples.T
# Convert to int16 packed for ingestion; resampler converts to encoder fmt.
if samples.dtype != torch.int16:
samples = torch.clip(samples, -1.0, 1.0)
samples = (samples * 32767.0).to(torch.int16)
frame_in = av.AudioFrame.from_ndarray(
samples.contiguous().reshape(1, -1).cpu().numpy(),
format="s16",
layout="stereo",
)
frame_in.sample_rate = audio_sample_rate
_resample_audio(container, audio_stream, frame_in)
def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
"""
Prepare the audio stream for writing.
"""
audio_stream = container.add_stream("aac")
supported_sample_rates = audio_stream.codec_context.codec.audio_rates
if supported_sample_rates:
best_rate = min(supported_sample_rates, key=lambda x: abs(x - audio_sample_rate))
if best_rate != audio_sample_rate:
print(f"Using closest supported audio sample rate: {best_rate}")
else:
best_rate = audio_sample_rate
audio_stream.codec_context.sample_rate = best_rate
audio_stream.codec_context.layout = "stereo"
audio_stream.codec_context.time_base = Fraction(1, best_rate)
return audio_stream
def write_video_audio_ltx2(
video: list[Image.Image],
audio: torch.Tensor | None,
output_path: str,
fps: int = 24,
audio_sample_rate: int | None = None,
) -> None:
"""
Writes a sequence of images and an audio tensor to a video file.
This function utilizes PyAV (or a similar multimedia library) to encode a list of PIL images into a video stream
and multiplex a PyTorch tensor as the audio stream into the output container.
Args:
video (list[Image.Image]): A list of PIL Image objects representing the video frames.
The length of this list determines the total duration of the video based on the FPS.
audio (torch.Tensor | None): The audio data as a PyTorch tensor.
The shape is typically (channels, samples). If no audio is required, pass None.
channels can be 1 or 2. 1 for mono, 2 for stereo.
output_path (str): The file path (including extension) where the output video will be saved.
fps (int, optional): The frame rate (frames per second) for the video. Defaults to 24.
audio_sample_rate (int | None, optional): The sample rate (e.g., 44100, 48000) for the audio.
If the audio tensor is provided and this is None, the function attempts to infer the rate
based on the audio tensor's length and the video duration.
Raises:
ValueError: If an audio tensor is provided but the sample rate cannot be determined.
"""
duration = len(video) / fps
if audio_sample_rate is None:
audio_sample_rate = int(audio.shape[-1] / duration)
width, height = video[0].size
container = av.open(output_path, mode="w")
stream = container.add_stream("libx264", rate=int(fps))
stream.width = width
stream.height = height
stream.pix_fmt = "yuv420p"
if audio is not None:
if audio_sample_rate is None:
raise ValueError("audio_sample_rate is required when audio is provided")
audio_stream = _prepare_audio_stream(container, audio_sample_rate)
for frame in tqdm(video, total=len(video)):
frame = av.VideoFrame.from_image(frame)
for packet in stream.encode(frame):
container.mux(packet)
# Flush encoder
for packet in stream.encode():
container.mux(packet)
if audio is not None:
_write_audio(container, audio_stream, audio, audio_sample_rate)
container.close()
def resample_waveform(waveform: torch.Tensor, source_rate: int, target_rate: int) -> torch.Tensor:
"""Resample waveform to target sample rate if needed."""
if source_rate == target_rate:
return waveform
resampled = torchaudio.functional.resample(waveform, source_rate, target_rate)
return resampled.to(dtype=waveform.dtype)
def read_audio_with_torchaudio(
path: str,
start_time: float = 0,
duration: float | None = None,
resample: bool = False,
resample_rate: int = 48000,
) -> tuple[torch.Tensor, int]:
waveform, sample_rate = torchaudio.load(path, channels_first=True)
if resample:
waveform = resample_waveform(waveform, sample_rate, resample_rate)
sample_rate = resample_rate
start_frame = int(start_time * sample_rate)
if start_frame > waveform.shape[-1]:
raise ValueError(f"start_time of {start_time} exceeds max duration of {waveform.shape[-1] / sample_rate:.2f}")
end_frame = None if duration is None else int(duration * sample_rate + start_frame)
return waveform[..., start_frame:end_frame], sample_rate
from .audio_video import write_video_audio as write_video_audio_ltx2
def encode_single_frame(output_file: str, image_array: np.ndarray, crf: float) -> None: