mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
Ltx2.3 a2v & retake video and audio (#1346)
* temp commit * support ltx2 a2v * support ltx2.3 retake video and audio * add news * minor fix
This commit is contained in:
@@ -1,12 +1,11 @@
|
||||
|
||||
from fractions import Fraction
|
||||
import torch
|
||||
import torchaudio
|
||||
import av
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from io import BytesIO
|
||||
from collections.abc import Generator, Iterator
|
||||
|
||||
|
||||
def _resample_audio(
|
||||
@@ -69,9 +68,9 @@ def _prepare_audio_stream(container: av.container.Container, audio_sample_rate:
|
||||
audio_stream = container.add_stream("aac")
|
||||
supported_sample_rates = audio_stream.codec_context.codec.audio_rates
|
||||
if supported_sample_rates:
|
||||
best_rate = min(supported_sample_rates, key=lambda x: abs(x - audio_sample_rate))
|
||||
if best_rate != audio_sample_rate:
|
||||
print(f"Using closest supported audio sample rate: {best_rate}")
|
||||
best_rate = min(supported_sample_rates, key=lambda x: abs(x - audio_sample_rate))
|
||||
if best_rate != audio_sample_rate:
|
||||
print(f"Using closest supported audio sample rate: {best_rate}")
|
||||
else:
|
||||
best_rate = audio_sample_rate
|
||||
audio_stream.codec_context.sample_rate = best_rate
|
||||
@@ -79,6 +78,7 @@ def _prepare_audio_stream(container: av.container.Container, audio_sample_rate:
|
||||
audio_stream.codec_context.time_base = Fraction(1, best_rate)
|
||||
return audio_stream
|
||||
|
||||
|
||||
def write_video_audio_ltx2(
|
||||
video: list[Image.Image],
|
||||
audio: torch.Tensor | None,
|
||||
@@ -116,7 +116,7 @@ def write_video_audio_ltx2(
|
||||
stream.width = width
|
||||
stream.height = height
|
||||
stream.pix_fmt = "yuv420p"
|
||||
|
||||
|
||||
if audio is not None:
|
||||
if audio_sample_rate is None:
|
||||
raise ValueError("audio_sample_rate is required when audio is provided")
|
||||
@@ -137,6 +137,32 @@ def write_video_audio_ltx2(
|
||||
container.close()
|
||||
|
||||
|
||||
def resample_waveform(waveform: torch.Tensor, source_rate: int, target_rate: int) -> torch.Tensor:
    """Return *waveform* converted from ``source_rate`` to ``target_rate`` Hz.

    The input is returned untouched when the two rates already agree.
    torchaudio's resampler may promote the dtype, so the result is cast
    back to the input's dtype before returning.
    """
    if source_rate != target_rate:
        converted = torchaudio.functional.resample(waveform, source_rate, target_rate)
        waveform = converted.to(dtype=waveform.dtype)
    return waveform
|
||||
|
||||
|
||||
def read_audio_with_torchaudio(
    path: str,
    start_time: float = 0,
    duration: float | None = None,
    resample: bool = False,
    resample_rate: int = 48000,
) -> tuple[torch.Tensor, int]:
    """Load audio from *path* and return ``(waveform, sample_rate)``.

    When ``resample`` is set the waveform is first converted to
    ``resample_rate`` Hz. The returned tensor is then trimmed to the
    window starting at ``start_time`` seconds, spanning ``duration``
    seconds (or to the end of the clip when ``duration`` is None).

    Raises:
        ValueError: if ``start_time`` lies past the end of the audio.
    """
    waveform, sample_rate = torchaudio.load(path, channels_first=True)

    if resample:
        waveform = resample_waveform(waveform, sample_rate, resample_rate)
        sample_rate = resample_rate

    # Convert the second-based window into frame indices at the final rate.
    first_frame = int(start_time * sample_rate)
    if first_frame > waveform.shape[-1]:
        raise ValueError(f"start_time of {start_time} exceeds max duration of {waveform.shape[-1] / sample_rate:.2f}")

    last_frame = None if duration is None else int(duration * sample_rate + first_frame)
    return waveform[..., first_frame:last_frame], sample_rate
|
||||
|
||||
|
||||
def encode_single_frame(output_file: str, image_array: np.ndarray, crf: float) -> None:
|
||||
container = av.open(output_file, "w", format="mp4")
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user