mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
Mova (#1337)
* Support MOVA inference * Add MOVA media_io * Add a unified audio/video API and fix a bug with mono audio input for LTX * Support MOVA training * Add MOVA docs * Fix bugs
This commit is contained in:
108
diffsynth/utils/data/audio.py
Normal file
108
diffsynth/utils/data/audio.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import torch
|
||||
import torchaudio
|
||||
from torchcodec.decoders import AudioDecoder
|
||||
from torchcodec.encoders import AudioEncoder
|
||||
|
||||
|
||||
def convert_to_mono(audio_tensor: torch.Tensor) -> torch.Tensor:
    """Down-mix audio to a single channel by averaging over the channel axis.

    Accepts [C, T] or [B, C, T]; returns [1, T] or [B, 1, T] respectively.
    """
    channel_axis = -2  # channel dim is second-to-last in both layouts
    return audio_tensor.mean(dim=channel_axis, keepdim=True)
|
||||
|
||||
|
||||
def convert_to_stereo(audio_tensor: torch.Tensor) -> torch.Tensor:
    """Up-mix audio to two channels.

    Accepts [C, T] or [B, C, T]. A mono input is duplicated onto both
    channels; anything already multi-channel is returned untouched.
    """
    if audio_tensor.size(-2) != 1:
        return audio_tensor
    if audio_tensor.dim() == 3:
        return audio_tensor.repeat(1, 2, 1)
    return audio_tensor.repeat(2, 1)
|
||||
|
||||
|
||||
def resample_waveform(waveform: torch.Tensor, source_rate: int, target_rate: int) -> torch.Tensor:
    """Resample *waveform* from ``source_rate`` to ``target_rate``.

    Returns the input unchanged when the rates already match. The result is
    cast back to the input dtype because torchaudio's resampler may promote
    the dtype internally.
    """
    if source_rate == target_rate:
        return waveform
    out = torchaudio.functional.resample(waveform, source_rate, target_rate)
    return out.to(dtype=waveform.dtype)
|
||||
|
||||
|
||||
def read_audio_with_torchcodec(
    path: str,
    start_time: float = 0,
    duration: float | None = None,
) -> tuple[torch.Tensor, int]:
    """Decode audio from a file with torchcodec, optionally windowed in time.

    Args:
        path (str): The file path to the audio file.
        start_time (float, optional): Start offset in seconds. Defaults to 0.
        duration (float | None, optional): Length in seconds to read; None
            reads to the end of the stream. Defaults to None.

    Returns:
        tuple[torch.Tensor, int]: The waveform (shape [C, T], channels by
        audio frames) and the stream's sample rate.
    """
    if duration is None:
        stop_seconds = None
    else:
        stop_seconds = start_time + duration
    decoder = AudioDecoder(path)
    samples = decoder.get_samples_played_in_range(start_seconds=start_time, stop_seconds=stop_seconds)
    return samples.data, decoder.metadata.sample_rate
|
||||
|
||||
|
||||
def read_audio(
|
||||
path: str,
|
||||
start_time: float = 0,
|
||||
duration: float | None = None,
|
||||
resample: bool = False,
|
||||
resample_rate: int = 48000,
|
||||
backend: str = "torchcodec",
|
||||
) -> tuple[torch.Tensor, int]:
|
||||
"""
|
||||
Read audio from file, with optional start time, duration, and resampling.
|
||||
|
||||
Args:
|
||||
path (str): The file path to the audio file.
|
||||
start_time (float, optional): The start time in seconds to read from. Defaults to 0.
|
||||
duration (float | None, optional): The duration in seconds to read. If None, reads until the end. Defaults to None.
|
||||
resample (bool, optional): Whether to resample the audio to a different sample rate. Defaults to False.
|
||||
resample_rate (int, optional): The target sample rate for resampling if resample is True. Defaults to 48000.
|
||||
backend (str, optional): The audio backend to use for reading. Defaults to "torchcodec".
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor, int]: A tuple containing the audio tensor and the sample rate.
|
||||
The audio tensor shape is [C, T] where C is the number of channels and T is the number of audio frames.
|
||||
"""
|
||||
if backend == "torchcodec":
|
||||
waveform, sample_rate = read_audio_with_torchcodec(path, start_time, duration)
|
||||
else:
|
||||
raise ValueError(f"Unsupported audio backend: {backend}")
|
||||
|
||||
if resample:
|
||||
waveform = resample_waveform(waveform, sample_rate, resample_rate)
|
||||
sample_rate = resample_rate
|
||||
|
||||
return waveform, sample_rate
|
||||
|
||||
|
||||
def save_audio(waveform: torch.Tensor, sample_rate: int, save_path: str, backend: str = "torchcodec"):
    """Persist an audio tensor to disk.

    Args:
        waveform (torch.Tensor): Audio to save, shape [C, T] or [B, C, T].
            For a batched input only the first item is written.
        sample_rate (int): Sample rate of the audio.
        save_path (str): Destination file path.
        backend (str, optional): Encoding backend; only "torchcodec" is
            supported. Defaults to "torchcodec".

    Raises:
        ValueError: If ``backend`` is not a supported backend name.
    """
    # Drop the batch dimension: [B, C, T] -> first item's [C, T].
    if waveform.dim() == 3:
        waveform = waveform[0]

    if backend != "torchcodec":
        raise ValueError(f"Unsupported audio backend: {backend}")
    AudioEncoder(waveform, sample_rate=sample_rate).to_file(dest=save_path)
|
||||
134
diffsynth/utils/data/audio_video.py
Normal file
134
diffsynth/utils/data/audio_video.py
Normal file
@@ -0,0 +1,134 @@
|
||||
import av
|
||||
from fractions import Fraction
|
||||
import torch
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from .audio import convert_to_stereo
|
||||
|
||||
|
||||
def _resample_audio(
    container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
) -> None:
    """Resample *frame_in* to the encoder's native parameters and mux it.

    The resampled frames are encoded on *audio_stream* and the encoder is
    flushed afterwards, so this is meant to be called once with the complete
    audio payload rather than per-chunk.
    """
    cc = audio_stream.codec_context

    # Use the encoder's format/layout/rate as the *target*
    target_format = cc.format or "fltp"  # AAC → usually fltp
    target_layout = cc.layout or "stereo"
    target_rate = cc.sample_rate or frame_in.sample_rate

    audio_resampler = av.audio.resampler.AudioResampler(
        format=target_format,
        layout=target_layout,
        rate=target_rate,
    )

    # Stamp monotonically increasing pts onto frames the resampler leaves
    # unstamped, counting in samples (stream time_base is 1/sample_rate).
    audio_next_pts = 0
    for rframe in audio_resampler.resample(frame_in):
        if rframe.pts is None:
            rframe.pts = audio_next_pts
            audio_next_pts += rframe.samples
        # NOTE(review): this overwrites the resampled frame's rate with the
        # *input* rate even when target_rate differs — presumably deliberate
        # for pts bookkeeping, but worth confirming against the muxer.
        rframe.sample_rate = frame_in.sample_rate
        container.mux(audio_stream.encode(rframe))

    # flush audio encoder
    for packet in audio_stream.encode():
        container.mux(packet)
|
||||
|
||||
|
||||
def _write_audio(
    container: av.container.Container, audio_stream: av.audio.AudioStream, samples: torch.Tensor, audio_sample_rate: int
) -> None:
    """Encode an audio tensor (shape [C, S] or [S]) as the container's audio track.

    Mono input is duplicated to stereo. Float input is assumed to lie in
    [-1, 1] — TODO confirm with callers — and is quantized to int16 before
    being handed to PyAV.
    """
    if samples.ndim == 1:
        samples = samples.unsqueeze(0)  # [S] -> [1, S]
    samples = convert_to_stereo(samples)
    assert samples.ndim == 2 and samples.shape[0] == 2, "audio samples must be [C, S] or [S], C must be 1 or 2"
    samples = samples.T  # [2, S] -> [S, 2] so the reshape below interleaves L/R
    # Convert to int16 packed for ingestion; resampler converts to encoder fmt.
    if samples.dtype != torch.int16:
        samples = torch.clip(samples, -1.0, 1.0)
        samples = (samples * 32767.0).to(torch.int16)

    # Packed s16 stereo expects one row of interleaved samples: shape [1, S*2].
    frame_in = av.AudioFrame.from_ndarray(
        samples.contiguous().reshape(1, -1).cpu().numpy(),
        format="s16",
        layout="stereo",
    )
    frame_in.sample_rate = audio_sample_rate

    _resample_audio(container, audio_stream, frame_in)
|
||||
|
||||
|
||||
def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
    """
    Prepare the audio stream for writing.

    Adds an AAC stream to *container*, snapping the requested sample rate to
    the nearest rate the encoder supports (when the codec advertises a list).
    """
    audio_stream = container.add_stream("aac")
    supported_sample_rates = audio_stream.codec_context.codec.audio_rates
    if supported_sample_rates:
        # Pick the supported rate closest to the requested one.
        best_rate = min(supported_sample_rates, key=lambda x: abs(x - audio_sample_rate))
        if best_rate != audio_sample_rate:
            print(f"Using closest supported audio sample rate: {best_rate}")
    else:
        # Codec reports no constraint; trust the caller's rate.
        best_rate = audio_sample_rate
    audio_stream.codec_context.sample_rate = best_rate
    audio_stream.codec_context.layout = "stereo"
    # One tick per sample keeps the pts-in-samples arithmetic exact.
    audio_stream.codec_context.time_base = Fraction(1, best_rate)
    return audio_stream
|
||||
|
||||
|
||||
def write_video_audio(
    video: list[Image.Image],
    audio: torch.Tensor | None,
    output_path: str,
    fps: int = 24,
    audio_sample_rate: int | None = None,
) -> None:
    """
    Writes a sequence of images and an audio tensor to a video file.

    Encodes the PIL frames as an H.264 (libx264) video stream and, when
    audio is provided, multiplexes it as an AAC stream into the same
    container via PyAV.

    Args:
        video (list[Image.Image]): Video frames; their count together with
            ``fps`` determines the video duration.
        audio (torch.Tensor | None): Audio samples, typically shaped
            (channels, samples) with 1 (mono) or 2 (stereo) channels.
            Pass None for a silent video.
        output_path (str): Destination file path (extension selects the
            container format).
        fps (int, optional): Video frame rate. Defaults to 24.
        audio_sample_rate (int | None, optional): Audio sample rate. If None
            and audio is provided, it is inferred from the audio length and
            the video duration.

    Raises:
        ValueError: If an audio tensor is provided but the sample rate cannot
            be determined.
    """
    duration = len(video) / fps
    # Fix: only infer a sample rate when there is audio to describe. The
    # previous unconditional `audio.shape[-1]` raised AttributeError when
    # audio was None and audio_sample_rate was left at its default.
    if audio is not None and audio_sample_rate is None:
        audio_sample_rate = int(audio.shape[-1] / duration)

    width, height = video[0].size
    container = av.open(output_path, mode="w")
    stream = container.add_stream("libx264", rate=int(fps))
    stream.width = width
    stream.height = height
    stream.pix_fmt = "yuv420p"

    if audio is not None:
        if audio_sample_rate is None:
            # Defensive: unreachable after the inference above; kept as a guard.
            raise ValueError("audio_sample_rate is required when audio is provided")
        audio_stream = _prepare_audio_stream(container, audio_sample_rate)

    for frame in tqdm(video, total=len(video)):
        frame = av.VideoFrame.from_image(frame)
        for packet in stream.encode(frame):
            container.mux(packet)

    # Flush encoder
    for packet in stream.encode():
        container.mux(packet)

    if audio is not None:
        _write_audio(container, audio_stream, audio, audio_sample_rate)

    container.close()
|
||||
@@ -1,166 +1,7 @@
|
||||
from fractions import Fraction
|
||||
import torch
|
||||
import torchaudio
|
||||
import av
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
def _resample_audio(
    container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
) -> None:
    """Resample *frame_in* to the encoder's native parameters and mux it.

    Encodes the resampled frames on *audio_stream* and flushes the encoder
    at the end, so it should be called once with the full audio payload.
    """
    cc = audio_stream.codec_context

    # Use the encoder's format/layout/rate as the *target*
    target_format = cc.format or "fltp"  # AAC → usually fltp
    target_layout = cc.layout or "stereo"
    target_rate = cc.sample_rate or frame_in.sample_rate

    audio_resampler = av.audio.resampler.AudioResampler(
        format=target_format,
        layout=target_layout,
        rate=target_rate,
    )

    # Assign monotonically increasing pts (counted in samples) to frames
    # the resampler leaves unstamped.
    audio_next_pts = 0
    for rframe in audio_resampler.resample(frame_in):
        if rframe.pts is None:
            rframe.pts = audio_next_pts
            audio_next_pts += rframe.samples
        # NOTE(review): overrides the frame's rate with the *input* rate even
        # when target_rate differs — confirm this matches the muxer's needs.
        rframe.sample_rate = frame_in.sample_rate
        container.mux(audio_stream.encode(rframe))

    # flush audio encoder
    for packet in audio_stream.encode():
        container.mux(packet)
|
||||
|
||||
|
||||
def _write_audio(
    container: av.container.Container, audio_stream: av.audio.AudioStream, samples: torch.Tensor, audio_sample_rate: int
) -> None:
    """Encode an audio tensor (shape [C, S] or [S]) onto *audio_stream* and mux it.

    Mono input is duplicated to stereo. Float input is assumed to lie in
    [-1, 1] and is quantized to int16 before being handed to PyAV.
    """
    if samples.ndim == 1:
        # Fix: `samples[:, None]` produced shape [S, 1], so the mono check
        # below never matched and the assert fired for every 1-D input.
        # Prepend the channel axis instead: [S] -> [1, S].
        samples = samples[None, :]
    if samples.shape[0] == 1:
        samples = samples.repeat(2, 1)
    assert samples.ndim == 2 and samples.shape[0] == 2, "audio samples must be [C, S] or [S], C must be 1 or 2"
    samples = samples.T  # [2, S] -> [S, 2] so the reshape below interleaves L/R
    # Convert to int16 packed for ingestion; resampler converts to encoder fmt.
    if samples.dtype != torch.int16:
        samples = torch.clip(samples, -1.0, 1.0)
        samples = (samples * 32767.0).to(torch.int16)

    # Packed s16 stereo expects one row of interleaved samples: shape [1, S*2].
    frame_in = av.AudioFrame.from_ndarray(
        samples.contiguous().reshape(1, -1).cpu().numpy(),
        format="s16",
        layout="stereo",
    )
    frame_in.sample_rate = audio_sample_rate

    _resample_audio(container, audio_stream, frame_in)
|
||||
|
||||
|
||||
def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
    """
    Prepare the audio stream for writing.

    Adds an AAC stream to *container* and snaps the requested sample rate to
    the closest rate the encoder supports, when the codec advertises a list.
    """
    audio_stream = container.add_stream("aac")
    supported_sample_rates = audio_stream.codec_context.codec.audio_rates
    if supported_sample_rates:
        # Nearest supported rate to the requested one.
        best_rate = min(supported_sample_rates, key=lambda x: abs(x - audio_sample_rate))
        if best_rate != audio_sample_rate:
            print(f"Using closest supported audio sample rate: {best_rate}")
    else:
        # No advertised constraint; use the caller's rate as-is.
        best_rate = audio_sample_rate
    audio_stream.codec_context.sample_rate = best_rate
    audio_stream.codec_context.layout = "stereo"
    # time_base of 1/sample_rate keeps pts-in-samples arithmetic exact.
    audio_stream.codec_context.time_base = Fraction(1, best_rate)
    return audio_stream
|
||||
|
||||
|
||||
def write_video_audio_ltx2(
    video: list[Image.Image],
    audio: torch.Tensor | None,
    output_path: str,
    fps: int = 24,
    audio_sample_rate: int | None = None,
) -> None:
    """
    Writes a sequence of images and an audio tensor to a video file.

    Encodes the PIL frames as an H.264 (libx264) video stream and, when
    audio is provided, multiplexes it as an AAC stream into the same
    container via PyAV.

    Args:
        video (list[Image.Image]): Video frames; their count together with
            ``fps`` determines the video duration.
        audio (torch.Tensor | None): Audio samples, typically shaped
            (channels, samples) with 1 (mono) or 2 (stereo) channels.
            Pass None for a silent video.
        output_path (str): Destination file path (extension selects the
            container format).
        fps (int, optional): Video frame rate. Defaults to 24.
        audio_sample_rate (int | None, optional): Audio sample rate. If None
            and audio is provided, it is inferred from the audio length and
            the video duration.

    Raises:
        ValueError: If an audio tensor is provided but the sample rate cannot
            be determined.
    """
    duration = len(video) / fps
    # Fix: only infer a sample rate when there is audio to describe. The
    # previous unconditional `audio.shape[-1]` raised AttributeError when
    # audio was None and audio_sample_rate was left at its default.
    if audio is not None and audio_sample_rate is None:
        audio_sample_rate = int(audio.shape[-1] / duration)

    width, height = video[0].size
    container = av.open(output_path, mode="w")
    stream = container.add_stream("libx264", rate=int(fps))
    stream.width = width
    stream.height = height
    stream.pix_fmt = "yuv420p"

    if audio is not None:
        if audio_sample_rate is None:
            # Defensive: unreachable after the inference above; kept as a guard.
            raise ValueError("audio_sample_rate is required when audio is provided")
        audio_stream = _prepare_audio_stream(container, audio_sample_rate)

    for frame in tqdm(video, total=len(video)):
        frame = av.VideoFrame.from_image(frame)
        for packet in stream.encode(frame):
            container.mux(packet)

    # Flush encoder
    for packet in stream.encode():
        container.mux(packet)

    if audio is not None:
        _write_audio(container, audio_stream, audio, audio_sample_rate)

    container.close()
|
||||
|
||||
|
||||
def resample_waveform(waveform: torch.Tensor, source_rate: int, target_rate: int) -> torch.Tensor:
    """Resample waveform to target sample rate if needed."""
    if source_rate != target_rate:
        converted = torchaudio.functional.resample(waveform, source_rate, target_rate)
        # torchaudio may promote the dtype; restore the caller's dtype.
        return converted.to(dtype=waveform.dtype)
    return waveform
|
||||
|
||||
|
||||
def read_audio_with_torchaudio(
    path: str,
    start_time: float = 0,
    duration: float | None = None,
    resample: bool = False,
    resample_rate: int = 48000,
) -> tuple[torch.Tensor, int]:
    """Load audio via torchaudio, optionally resampling and time-slicing.

    Slicing to [start_time, start_time + duration) happens after any
    resampling, so offsets are computed at the output rate.

    Returns:
        tuple[torch.Tensor, int]: The waveform ([C, T]) and its sample rate.

    Raises:
        ValueError: If start_time lies beyond the end of the audio.
    """
    waveform, sample_rate = torchaudio.load(path, channels_first=True)

    if resample:
        waveform = resample_waveform(waveform, sample_rate, resample_rate)
        sample_rate = resample_rate

    first_frame = int(start_time * sample_rate)
    if first_frame > waveform.shape[-1]:
        raise ValueError(f"start_time of {start_time} exceeds max duration of {waveform.shape[-1] / sample_rate:.2f}")
    last_frame = None
    if duration is not None:
        last_frame = int(duration * sample_rate + first_frame)
    return waveform[..., first_frame:last_frame], sample_rate
|
||||
from .audio_video import write_video_audio as write_video_audio_ltx2
|
||||
|
||||
|
||||
def encode_single_frame(output_file: str, image_array: np.ndarray, crf: float) -> None:
|
||||
|
||||
@@ -1 +1 @@
|
||||
from .xdit_context_parallel import usp_attn_forward, usp_dit_forward, get_sequence_parallel_world_size, initialize_usp
|
||||
from .xdit_context_parallel import usp_attn_forward, usp_dit_forward, get_sequence_parallel_world_size, initialize_usp, get_current_chunk, gather_all_chunks
|
||||
|
||||
@@ -143,4 +143,31 @@ def usp_attn_forward(self, x, freqs):
|
||||
|
||||
del q, k, v
|
||||
getattr(torch, parse_device_type(x.device)).empty_cache()
|
||||
return self.o(x)
|
||||
return self.o(x)
|
||||
|
||||
|
||||
def get_current_chunk(x, dim=1):
    """Return this rank's slice of *x* along *dim* for sequence parallelism.

    *x* is split into world_size chunks and every chunk is zero-padded up to
    the first chunk's length, so each rank holds a tensor of identical shape
    (presumably required for a later all_gather — see gather_all_chunks).
    """
    chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=dim)
    ndims = len(chunks[0].shape)
    # F.pad takes (left, right) pairs ordered last-dim-first; pad_end_index
    # addresses the "right side of `dim`" slot in that flat list.
    pad_list = [0] * (2 * ndims)
    pad_end_index = 2 * (ndims - 1 - dim) + 1
    # torch.chunk makes the leading chunk the largest, so padding everything
    # to its size equalizes the chunks.
    max_size = chunks[0].size(dim)
    chunks = [
        torch.nn.functional.pad(
            chunk,
            tuple(pad_list[:pad_end_index] + [max_size - chunk.size(dim)] + pad_list[pad_end_index+1:]),
            value=0
        )
        for chunk in chunks
    ]
    # NOTE(review): assumes torch.chunk yields exactly world_size chunks; a
    # sequence shorter than world_size yields fewer and this would IndexError.
    x = chunks[get_sequence_parallel_rank()]
    return x
|
||||
|
||||
|
||||
def gather_all_chunks(x, seq_len=None, dim=1):
    """Inverse of get_current_chunk: all_gather the per-rank chunks along *dim*.

    When *seq_len* is given, the gathered tensor is trimmed back to that
    length along *dim*, removing the zero padding added when the chunks were
    equalized.
    """
    x = get_sp_group().all_gather(x, dim=dim)
    if seq_len is not None:
        # Build a full-slice index and narrow only the gather dimension.
        slices = [slice(None)] * x.ndim
        slices[dim] = slice(0, seq_len)
        x = x[tuple(slices)]
    return x
|
||||
|
||||
Reference in New Issue
Block a user