mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 16:50:47 +00:00
support ltx2.3 inference
This commit is contained in:
@@ -1279,9 +1279,268 @@ class LTX2AudioDecoder(torch.nn.Module):
|
||||
return torch.tanh(h) if self.tanh_out else h
|
||||
|
||||
|
||||
def get_padding(kernel_size: int, dilation: int = 1) -> int:
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2
|
||||
# Adopted from https://github.com/NVIDIA/BigVGAN
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _sinc(x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.where(
|
||||
x == 0,
|
||||
torch.tensor(1.0, device=x.device, dtype=x.dtype),
|
||||
torch.sin(math.pi * x) / math.pi / x,
|
||||
)
|
||||
|
||||
|
||||
def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> torch.Tensor:
|
||||
even = kernel_size % 2 == 0
|
||||
half_size = kernel_size // 2
|
||||
delta_f = 4 * half_width
|
||||
amplitude = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
||||
if amplitude > 50.0:
|
||||
beta = 0.1102 * (amplitude - 8.7)
|
||||
elif amplitude >= 21.0:
|
||||
beta = 0.5842 * (amplitude - 21) ** 0.4 + 0.07886 * (amplitude - 21.0)
|
||||
else:
|
||||
beta = 0.0
|
||||
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
||||
time = torch.arange(-half_size, half_size) + 0.5 if even else torch.arange(kernel_size) - half_size
|
||||
if cutoff == 0:
|
||||
filter_ = torch.zeros_like(time)
|
||||
else:
|
||||
filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time)
|
||||
filter_ /= filter_.sum()
|
||||
return filter_.view(1, 1, kernel_size)
|
||||
|
||||
|
||||
class LowPassFilter1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
cutoff: float = 0.5,
|
||||
half_width: float = 0.6,
|
||||
stride: int = 1,
|
||||
padding: bool = True,
|
||||
padding_mode: str = "replicate",
|
||||
kernel_size: int = 12,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if cutoff < -0.0:
|
||||
raise ValueError("Minimum cutoff must be larger than zero.")
|
||||
if cutoff > 0.5:
|
||||
raise ValueError("A cutoff above 0.5 does not make sense.")
|
||||
self.kernel_size = kernel_size
|
||||
self.even = kernel_size % 2 == 0
|
||||
self.pad_left = kernel_size // 2 - int(self.even)
|
||||
self.pad_right = kernel_size // 2
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
self.padding_mode = padding_mode
|
||||
self.register_buffer("filter", kaiser_sinc_filter1d(cutoff, half_width, kernel_size))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
_, n_channels, _ = x.shape
|
||||
if self.padding:
|
||||
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
|
||||
return F.conv1d(x, self.filter.expand(n_channels, -1, -1), stride=self.stride, groups=n_channels)
|
||||
|
||||
|
||||
class UpSample1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
ratio: int = 2,
|
||||
kernel_size: int | None = None,
|
||||
persistent: bool = True,
|
||||
window_type: str = "kaiser",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.stride = ratio
|
||||
|
||||
if window_type == "hann":
|
||||
# Hann-windowed sinc filter equivalent to torchaudio.functional.resample
|
||||
rolloff = 0.99
|
||||
lowpass_filter_width = 6
|
||||
width = math.ceil(lowpass_filter_width / rolloff)
|
||||
self.kernel_size = 2 * width * ratio + 1
|
||||
self.pad = width
|
||||
self.pad_left = 2 * width * ratio
|
||||
self.pad_right = self.kernel_size - ratio
|
||||
time_axis = (torch.arange(self.kernel_size) / ratio - width) * rolloff
|
||||
time_clamped = time_axis.clamp(-lowpass_filter_width, lowpass_filter_width)
|
||||
window = torch.cos(time_clamped * math.pi / lowpass_filter_width / 2) ** 2
|
||||
sinc_filter = (torch.sinc(time_axis) * window * rolloff / ratio).view(1, 1, -1)
|
||||
else:
|
||||
# Kaiser-windowed sinc filter (BigVGAN default).
|
||||
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
self.pad = self.kernel_size // ratio - 1
|
||||
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
||||
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
||||
sinc_filter = kaiser_sinc_filter1d(
|
||||
cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
kernel_size=self.kernel_size,
|
||||
)
|
||||
|
||||
self.register_buffer("filter", sinc_filter, persistent=persistent)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
_, n_channels, _ = x.shape
|
||||
x = F.pad(x, (self.pad, self.pad), mode="replicate")
|
||||
filt = self.filter.to(dtype=x.dtype, device=x.device).expand(n_channels, -1, -1)
|
||||
x = self.ratio * F.conv_transpose1d(x, filt, stride=self.stride, groups=n_channels)
|
||||
return x[..., self.pad_left : -self.pad_right]
|
||||
|
||||
|
||||
class DownSample1d(nn.Module):
|
||||
def __init__(self, ratio: int = 2, kernel_size: int | None = None) -> None:
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
self.lowpass = LowPassFilter1d(
|
||||
cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
stride=ratio,
|
||||
kernel_size=self.kernel_size,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.lowpass(x)
|
||||
|
||||
|
||||
class Activation1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
activation: nn.Module,
|
||||
up_ratio: int = 2,
|
||||
down_ratio: int = 2,
|
||||
up_kernel_size: int = 12,
|
||||
down_kernel_size: int = 12,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.act = activation
|
||||
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
||||
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.upsample(x)
|
||||
x = self.act(x)
|
||||
return self.downsample(x)
|
||||
|
||||
|
||||
class Snake(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
alpha: float = 1.0,
|
||||
alpha_trainable: bool = True,
|
||||
alpha_logscale: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.alpha_logscale = alpha_logscale
|
||||
self.alpha = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha)
|
||||
self.alpha.requires_grad = alpha_trainable
|
||||
self.eps = 1e-9
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
|
||||
if self.alpha_logscale:
|
||||
alpha = torch.exp(alpha)
|
||||
return x + (1.0 / (alpha + self.eps)) * torch.sin(x * alpha).pow(2)
|
||||
|
||||
|
||||
class SnakeBeta(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
alpha: float = 1.0,
|
||||
alpha_trainable: bool = True,
|
||||
alpha_logscale: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.alpha_logscale = alpha_logscale
|
||||
self.alpha = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha)
|
||||
self.alpha.requires_grad = alpha_trainable
|
||||
self.beta = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha)
|
||||
self.beta.requires_grad = alpha_trainable
|
||||
self.eps = 1e-9
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
|
||||
beta = self.beta.unsqueeze(0).unsqueeze(-1)
|
||||
if self.alpha_logscale:
|
||||
alpha = torch.exp(alpha)
|
||||
beta = torch.exp(beta)
|
||||
return x + (1.0 / (beta + self.eps)) * torch.sin(x * alpha).pow(2)
|
||||
|
||||
|
||||
class AMPBlock1(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
kernel_size: int = 3,
|
||||
dilation: tuple[int, int, int] = (1, 3, 5),
|
||||
activation: str = "snake",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
act_cls = SnakeBeta if activation == "snakebeta" else Snake
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
),
|
||||
nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
),
|
||||
nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2]),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)),
|
||||
nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)),
|
||||
nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)),
|
||||
]
|
||||
)
|
||||
|
||||
self.acts1 = nn.ModuleList([Activation1d(act_cls(channels)) for _ in range(len(self.convs1))])
|
||||
self.acts2 = nn.ModuleList([Activation1d(act_cls(channels)) for _ in range(len(self.convs2))])
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.acts1, self.acts2, strict=True):
|
||||
xt = a1(x)
|
||||
xt = c1(xt)
|
||||
xt = a2(xt)
|
||||
xt = c2(xt)
|
||||
x = x + xt
|
||||
return x
|
||||
|
||||
|
||||
class LTX2Vocoder(torch.nn.Module):
|
||||
"""
|
||||
Vocoder model for synthesizing audio from Mel spectrograms.
|
||||
LTX2Vocoder model for synthesizing audio from Mel spectrograms.
|
||||
Args:
|
||||
resblock_kernel_sizes: List of kernel sizes for the residual blocks.
|
||||
This value is read from the checkpoint at `config.vocoder.resblock_kernel_sizes`.
|
||||
@@ -1293,28 +1552,33 @@ class LTX2Vocoder(torch.nn.Module):
|
||||
This value is read from the checkpoint at `config.vocoder.resblock_dilation_sizes`.
|
||||
upsample_initial_channel: Initial number of channels for the upsampling layers.
|
||||
This value is read from the checkpoint at `config.vocoder.upsample_initial_channel`.
|
||||
stereo: Whether to use stereo output.
|
||||
This value is read from the checkpoint at `config.vocoder.stereo`.
|
||||
resblock: Type of residual block to use.
|
||||
resblock: Type of residual block to use ("1", "2", or "AMP1").
|
||||
This value is read from the checkpoint at `config.vocoder.resblock`.
|
||||
output_sample_rate: Waveform sample rate.
|
||||
This value is read from the checkpoint at `config.vocoder.output_sample_rate`.
|
||||
output_sampling_rate: Waveform sample rate.
|
||||
This value is read from the checkpoint at `config.vocoder.output_sampling_rate`.
|
||||
activation: Activation type for BigVGAN v2 ("snake" or "snakebeta"). Only used when resblock="AMP1".
|
||||
use_tanh_at_final: Apply tanh at the output (when apply_final_activation=True).
|
||||
apply_final_activation: Whether to apply the final tanh/clamp activation.
|
||||
use_bias_at_final: Whether to use bias in the final conv layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
def __init__( # noqa: PLR0913
|
||||
self,
|
||||
resblock_kernel_sizes: List[int] | None = [3, 7, 11],
|
||||
upsample_rates: List[int] | None = [6, 5, 2, 2, 2],
|
||||
upsample_kernel_sizes: List[int] | None = [16, 15, 8, 4, 4],
|
||||
resblock_dilation_sizes: List[List[int]] | None = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
upsample_initial_channel: int = 1024,
|
||||
stereo: bool = True,
|
||||
resblock: str = "1",
|
||||
output_sample_rate: int = 24000,
|
||||
):
|
||||
output_sampling_rate: int = 24000,
|
||||
activation: str = "snake",
|
||||
use_tanh_at_final: bool = True,
|
||||
apply_final_activation: bool = True,
|
||||
use_bias_at_final: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# Initialize default values if not provided. Note that mutable default values are not supported.
|
||||
# Mutable default values are not supported as default arguments.
|
||||
if resblock_kernel_sizes is None:
|
||||
resblock_kernel_sizes = [3, 7, 11]
|
||||
if upsample_rates is None:
|
||||
@@ -1324,36 +1588,60 @@ class LTX2Vocoder(torch.nn.Module):
|
||||
if resblock_dilation_sizes is None:
|
||||
resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
||||
|
||||
self.output_sample_rate = output_sample_rate
|
||||
self.output_sampling_rate = output_sampling_rate
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_rates)
|
||||
in_channels = 128 if stereo else 64
|
||||
self.conv_pre = nn.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
|
||||
resblock_class = ResBlock1 if resblock == "1" else ResBlock2
|
||||
self.use_tanh_at_final = use_tanh_at_final
|
||||
self.apply_final_activation = apply_final_activation
|
||||
self.is_amp = resblock == "AMP1"
|
||||
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes, strict=True)):
|
||||
self.ups.append(
|
||||
nn.ConvTranspose1d(
|
||||
upsample_initial_channel // (2**i),
|
||||
upsample_initial_channel // (2 ** (i + 1)),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding=(kernel_size - stride) // 2,
|
||||
)
|
||||
# All production checkpoints are stereo: 128 input channels (2 stereo channels x 64 mel
|
||||
# bins each), 2 output channels.
|
||||
self.conv_pre = nn.Conv1d(
|
||||
in_channels=128,
|
||||
out_channels=upsample_initial_channel,
|
||||
kernel_size=7,
|
||||
stride=1,
|
||||
padding=3,
|
||||
)
|
||||
resblock_cls = ResBlock1 if resblock == "1" else AMPBlock1
|
||||
|
||||
self.ups = nn.ModuleList(
|
||||
nn.ConvTranspose1d(
|
||||
upsample_initial_channel // (2**i),
|
||||
upsample_initial_channel // (2 ** (i + 1)),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding=(kernel_size - stride) // 2,
|
||||
)
|
||||
for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes, strict=True))
|
||||
)
|
||||
|
||||
final_channels = upsample_initial_channel // (2 ** len(upsample_rates))
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i, _ in enumerate(self.ups):
|
||||
|
||||
for i in range(len(upsample_rates)):
|
||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes, strict=True):
|
||||
self.resblocks.append(resblock_class(ch, kernel_size, dilations))
|
||||
if self.is_amp:
|
||||
self.resblocks.append(resblock_cls(ch, kernel_size, dilations, activation=activation))
|
||||
else:
|
||||
self.resblocks.append(resblock_cls(ch, kernel_size, dilations))
|
||||
|
||||
out_channels = 2 if stereo else 1
|
||||
final_channels = upsample_initial_channel // (2**self.num_upsamples)
|
||||
self.conv_post = nn.Conv1d(final_channels, out_channels, 7, 1, padding=3)
|
||||
if self.is_amp:
|
||||
self.act_post: nn.Module = Activation1d(SnakeBeta(final_channels))
|
||||
else:
|
||||
self.act_post = nn.LeakyReLU()
|
||||
|
||||
self.upsample_factor = math.prod(layer.stride[0] for layer in self.ups)
|
||||
# All production checkpoints are stereo: this final conv maps `final_channels` to 2 output channels (stereo).
|
||||
self.conv_post = nn.Conv1d(
|
||||
in_channels=final_channels,
|
||||
out_channels=2,
|
||||
kernel_size=7,
|
||||
stride=1,
|
||||
padding=3,
|
||||
bias=use_bias_at_final,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
@@ -1374,7 +1662,8 @@ class LTX2Vocoder(torch.nn.Module):
|
||||
x = self.conv_pre(x)
|
||||
|
||||
for i in range(self.num_upsamples):
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
if not self.is_amp:
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
x = self.ups[i](x)
|
||||
start = i * self.num_kernels
|
||||
end = start + self.num_kernels
|
||||
@@ -1386,23 +1675,198 @@ class LTX2Vocoder(torch.nn.Module):
|
||||
[self.resblocks[idx](x) for idx in range(start, end)],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
x = block_outputs.mean(dim=0)
|
||||
|
||||
x = self.conv_post(F.leaky_relu(x))
|
||||
return torch.tanh(x)
|
||||
x = self.act_post(x)
|
||||
x = self.conv_post(x)
|
||||
|
||||
if self.apply_final_activation:
|
||||
x = torch.tanh(x) if self.use_tanh_at_final else torch.clamp(x, -1, 1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def decode_audio(latent: torch.Tensor, audio_decoder: "LTX2AudioDecoder", vocoder: "LTX2Vocoder") -> torch.Tensor:
|
||||
class _STFTFn(nn.Module):
|
||||
"""Implements STFT as a convolution with precomputed DFT x Hann-window bases.
|
||||
The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal
|
||||
Hann window are stored as buffers and loaded from the checkpoint. Using the exact
|
||||
bfloat16 bases from training ensures the mel values fed to the BWE generator are
|
||||
bit-identical to what it was trained on.
|
||||
"""
|
||||
Decode an audio latent representation using the provided audio decoder and vocoder.
|
||||
Args:
|
||||
latent: Input audio latent tensor.
|
||||
audio_decoder: Model to decode the latent to waveform features.
|
||||
vocoder: Model to convert decoded features to audio waveform.
|
||||
Returns:
|
||||
Decoded audio as a float tensor.
|
||||
|
||||
def __init__(self, filter_length: int, hop_length: int, win_length: int) -> None:
|
||||
super().__init__()
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length
|
||||
n_freqs = filter_length // 2 + 1
|
||||
self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length))
|
||||
self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length))
|
||||
|
||||
def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute magnitude and phase spectrogram from a batch of waveforms.
|
||||
Applies causal (left-only) padding of win_length - hop_length samples so that
|
||||
each output frame depends only on past and present input — no lookahead.
|
||||
Args:
|
||||
y: Waveform tensor of shape (B, T).
|
||||
Returns:
|
||||
magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
|
||||
phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
|
||||
"""
|
||||
if y.dim() == 2:
|
||||
y = y.unsqueeze(1) # (B, 1, T)
|
||||
left_pad = max(0, self.win_length - self.hop_length) # causal: left-only
|
||||
y = F.pad(y, (left_pad, 0))
|
||||
spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0)
|
||||
n_freqs = spec.shape[1] // 2
|
||||
real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
|
||||
magnitude = torch.sqrt(real**2 + imag**2)
|
||||
phase = torch.atan2(imag.float(), real.float()).to(real.dtype)
|
||||
return magnitude, phase
|
||||
|
||||
|
||||
class MelSTFT(nn.Module):
|
||||
"""Causal log-mel spectrogram module whose buffers are loaded from the checkpoint.
|
||||
Computes a log-mel spectrogram by running the causal STFT (_STFTFn) on the input
|
||||
waveform and projecting the linear magnitude spectrum onto the mel filterbank.
|
||||
The module's state dict layout matches the 'mel_stft.*' keys stored in the checkpoint
|
||||
(mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis).
|
||||
"""
|
||||
decoded_audio = audio_decoder(latent)
|
||||
decoded_audio = vocoder(decoded_audio).squeeze(0).float()
|
||||
return decoded_audio
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filter_length: int,
|
||||
hop_length: int,
|
||||
win_length: int,
|
||||
n_mel_channels: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.stft_fn = _STFTFn(filter_length, hop_length, win_length)
|
||||
|
||||
# Initialized to zeros; load_state_dict overwrites with the checkpoint's
|
||||
# exact bfloat16 filterbank (vocoder.mel_stft.mel_basis, shape [n_mels, n_freqs]).
|
||||
n_freqs = filter_length // 2 + 1
|
||||
self.register_buffer("mel_basis", torch.zeros(n_mel_channels, n_freqs))
|
||||
|
||||
def mel_spectrogram(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Compute log-mel spectrogram and auxiliary spectral quantities.
|
||||
Args:
|
||||
y: Waveform tensor of shape (B, T).
|
||||
Returns:
|
||||
log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames).
|
||||
magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
|
||||
phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
|
||||
energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames).
|
||||
"""
|
||||
magnitude, phase = self.stft_fn(y)
|
||||
energy = torch.norm(magnitude, dim=1)
|
||||
mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude)
|
||||
log_mel = torch.log(torch.clamp(mel, min=1e-5))
|
||||
return log_mel, magnitude, phase, energy
|
||||
|
||||
|
||||
class LTX2VocoderWithBWE(nn.Module):
|
||||
"""LTX2Vocoder with bandwidth extension (BWE) upsampling.
|
||||
Chains a mel-to-wav vocoder with a BWE module that upsamples the output
|
||||
to a higher sample rate. The BWE computes a mel spectrogram from the
|
||||
vocoder output, runs it through a second generator to predict a residual,
|
||||
and adds it to a sinc-resampled skip connection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_sampling_rate: int = 16000,
|
||||
output_sampling_rate: int = 48000,
|
||||
hop_length: int = 80,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.vocoder = LTX2Vocoder(
|
||||
resblock_kernel_sizes=[3, 7, 11],
|
||||
upsample_rates=[5, 2, 2, 2, 2, 2],
|
||||
upsample_kernel_sizes=[11, 4, 4, 4, 4, 4],
|
||||
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
upsample_initial_channel=1536,
|
||||
resblock="AMP1",
|
||||
activation="snakebeta",
|
||||
use_tanh_at_final=False,
|
||||
apply_final_activation=True,
|
||||
use_bias_at_final=False,
|
||||
output_sampling_rate=input_sampling_rate,
|
||||
)
|
||||
self.bwe_generator = LTX2Vocoder(
|
||||
resblock_kernel_sizes=[3, 7, 11],
|
||||
upsample_rates=[6, 5, 2, 2, 2],
|
||||
upsample_kernel_sizes=[12, 11, 4, 4, 4],
|
||||
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
upsample_initial_channel=512,
|
||||
resblock="AMP1",
|
||||
activation="snakebeta",
|
||||
use_tanh_at_final=False,
|
||||
apply_final_activation=False,
|
||||
use_bias_at_final=False,
|
||||
output_sampling_rate=output_sampling_rate,
|
||||
)
|
||||
|
||||
self.mel_stft = MelSTFT(
|
||||
filter_length=512,
|
||||
hop_length=hop_length,
|
||||
win_length=512,
|
||||
n_mel_channels=64,
|
||||
)
|
||||
self.input_sampling_rate = input_sampling_rate
|
||||
self.output_sampling_rate = output_sampling_rate
|
||||
self.hop_length = hop_length
|
||||
# Compute the resampler on CPU so the sinc filter is materialized even when
|
||||
# the model is constructed on meta device (SingleGPUModelBuilder pattern).
|
||||
# The filter is not stored in the checkpoint (persistent=False).
|
||||
with torch.device("cpu"):
|
||||
self.resampler = UpSample1d(
|
||||
ratio=output_sampling_rate // input_sampling_rate, persistent=False, window_type="hann"
|
||||
)
|
||||
|
||||
@property
|
||||
def conv_pre(self) -> nn.Conv1d:
|
||||
return self.vocoder.conv_pre
|
||||
|
||||
@property
|
||||
def conv_post(self) -> nn.Conv1d:
|
||||
return self.vocoder.conv_post
|
||||
|
||||
def _compute_mel(self, audio: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute log-mel spectrogram from waveform using causal STFT bases.
|
||||
Args:
|
||||
audio: Waveform tensor of shape (B, C, T).
|
||||
Returns:
|
||||
mel: Log-mel spectrogram of shape (B, C, n_mels, T_frames).
|
||||
"""
|
||||
batch, n_channels, _ = audio.shape
|
||||
flat = audio.reshape(batch * n_channels, -1) # (B*C, T)
|
||||
mel, _, _, _ = self.mel_stft.mel_spectrogram(flat) # (B*C, n_mels, T_frames)
|
||||
return mel.reshape(batch, n_channels, mel.shape[1], mel.shape[2]) # (B, C, n_mels, T_frames)
|
||||
|
||||
def forward(self, mel_spec: torch.Tensor) -> torch.Tensor:
|
||||
"""Run the full vocoder + BWE forward pass.
|
||||
Args:
|
||||
mel_spec: Mel spectrogram of shape (B, 2, T, mel_bins) for stereo
|
||||
or (B, T, mel_bins) for mono. Same format as LTX2Vocoder.forward.
|
||||
Returns:
|
||||
Waveform tensor of shape (B, out_channels, T_out) clipped to [-1, 1].
|
||||
"""
|
||||
x = self.vocoder(mel_spec)
|
||||
_, _, length_low_rate = x.shape
|
||||
output_length = length_low_rate * self.output_sampling_rate // self.input_sampling_rate
|
||||
|
||||
# Pad to multiple of hop_length for exact mel frame count
|
||||
remainder = length_low_rate % self.hop_length
|
||||
if remainder != 0:
|
||||
x = F.pad(x, (0, self.hop_length - remainder))
|
||||
|
||||
# Compute mel spectrogram from vocoder output: (B, C, n_mels, T_frames)
|
||||
mel = self._compute_mel(x)
|
||||
|
||||
# LTX2Vocoder.forward expects (B, C, T, mel_bins) — transpose before calling bwe_generator
|
||||
mel_for_bwe = mel.transpose(2, 3) # (B, C, T_frames, mel_bins)
|
||||
residual = self.bwe_generator(mel_for_bwe)
|
||||
skip = self.resampler(x)
|
||||
assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}"
|
||||
|
||||
return torch.clamp(residual + skip, -1, 1)[..., :output_length]
|
||||
|
||||
Reference in New Issue
Block a user