# Copyright 2025 The ACESTEO Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ACE-Step Audio VAE (AutoencoderOobleck CNN architecture).

This is a CNN-based VAE for audio waveform encoding/decoding.
It uses weight-normalized convolutions and Snake1d activations.

Does NOT depend on diffusers — pure nn.Module implementation.
"""

import math
from typing import Optional

import torch
import torch.nn as nn
from torch.nn.utils import weight_norm, remove_weight_norm


class Snake1d(nn.Module):
    """Snake activation: x + 1/(beta+eps) * sin(alpha*x)^2.

    ``alpha`` and ``beta`` are learnable per-channel parameters of shape
    ``(1, hidden_dim, 1)``, initialised to zero.

    Args:
        hidden_dim: Number of channels (size of dim 1 of the input).
        logscale: When True the parameters are stored in log space and
            exponentiated in ``forward`` (so zero init means alpha = beta = 1).
    """

    def __init__(self, hidden_dim: int, logscale: bool = True):
        super().__init__()
        self.alpha = nn.Parameter(torch.zeros(1, hidden_dim, 1))
        self.beta = nn.Parameter(torch.zeros(1, hidden_dim, 1))
        self.logscale = logscale

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the activation; the output has the same shape as the input."""
        shape = hidden_states.shape
        alpha = torch.exp(self.alpha) if self.logscale else self.alpha
        beta = torch.exp(self.beta) if self.logscale else self.beta

        # Flatten everything after the channel dim so the (1, C, 1) parameters
        # broadcast regardless of how many trailing dims the input has.
        hidden_states = hidden_states.reshape(shape[0], shape[1], -1)
        # 1e-9 guards the reciprocal against beta == 0 (possible when
        # logscale is False).
        hidden_states = hidden_states + (beta + 1e-9).reciprocal() * torch.sin(alpha * hidden_states).pow(2)
        return hidden_states.reshape(shape)


class OobleckResidualUnit(nn.Module):
    """Residual unit: Snake1d → Conv1d(dilated) → Snake1d → Conv1d(1×1) + skip.

    Args:
        dimension: Channel count of both input and output.
        dilation: Dilation of the kernel-size-7 convolution.
    """

    def __init__(self, dimension: int = 16, dilation: int = 1):
        super().__init__()
        # "Same" padding for a kernel of 7 at this dilation (= 3 * dilation).
        pad = ((7 - 1) * dilation) // 2

        self.snake1 = Snake1d(dimension)
        self.conv1 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad))
        self.snake2 = Snake1d(dimension)
        self.conv2 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=1))

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Return ``hidden_state`` plus the output of the two-conv branch."""
        output = self.conv1(self.snake1(hidden_state))
        output = self.conv2(self.snake2(output))

        # If the branch output is shorter than the input, centre-crop the skip
        # connection so the two can be added.
        padding = (hidden_state.shape[-1] - output.shape[-1]) // 2
        if padding > 0:
            hidden_state = hidden_state[..., padding:-padding]
        return hidden_state + output


class OobleckEncoderBlock(nn.Module):
    """Encoder block: 3 residual units (dilations 1, 3, 9) + Snake1d + strided conv.

    Args:
        input_dim: Channels entering the block.
        output_dim: Channels produced by the downsampling convolution.
        stride: Stride of the downsampling convolution (kernel is 2 * stride).
    """

    def __init__(self, input_dim: int, output_dim: int, stride: int = 1):
        super().__init__()
        self.res_unit1 = OobleckResidualUnit(input_dim, dilation=1)
        self.res_unit2 = OobleckResidualUnit(input_dim, dilation=3)
        self.res_unit3 = OobleckResidualUnit(input_dim, dilation=9)
        self.snake1 = Snake1d(input_dim)
        self.conv1 = weight_norm(
            nn.Conv1d(input_dim, output_dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2))
        )

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Run the residual stack, then the activation and strided conv."""
        hidden_state = self.res_unit1(hidden_state)
        hidden_state = self.res_unit2(hidden_state)
        hidden_state = self.snake1(self.res_unit3(hidden_state))
        return self.conv1(hidden_state)


class OobleckDecoderBlock(nn.Module):
    """Decoder block: Snake1d + transposed conv + 3 residual units (dilations 1, 3, 9).

    Args:
        input_dim: Channels entering the block.
        output_dim: Channels produced by the transposed convolution and kept
            by the residual units.
        stride: Stride of the transposed convolution (kernel is 2 * stride).
    """

    def __init__(self, input_dim: int, output_dim: int, stride: int = 1):
        super().__init__()
        self.snake1 = Snake1d(input_dim)
        self.conv_t1 = weight_norm(
            nn.ConvTranspose1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            )
        )
        self.res_unit1 = OobleckResidualUnit(output_dim, dilation=1)
        self.res_unit2 = OobleckResidualUnit(output_dim, dilation=3)
        self.res_unit3 = OobleckResidualUnit(output_dim, dilation=9)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Run the activation and transposed conv, then the residual stack."""
        hidden_state = self.snake1(hidden_state)
        hidden_state = self.conv_t1(hidden_state)
        hidden_state = self.res_unit1(hidden_state)
        hidden_state = self.res_unit2(hidden_state)
        return self.res_unit3(hidden_state)


class OobleckEncoder(nn.Module):
    """Full encoder: audio → latent representation [B, encoder_hidden_size, T'].

    conv1 → [blocks] → snake1 → conv2

    Args:
        encoder_hidden_size: Base channel width; also the output channel count.
        audio_channels: Channels of the input waveform.
        downsampling_ratios: Stride of each encoder block, in order.
            A falsy value selects ``[2, 4, 4, 6, 10]``.
        channel_multiples: Multiplier on ``encoder_hidden_size`` for each
            block's output width. A falsy value selects ``[1, 2, 4, 8, 16]``.
    """

    def __init__(
        self,
        encoder_hidden_size: int = 128,
        audio_channels: int = 2,
        downsampling_ratios: Optional[list] = None,
        channel_multiples: Optional[list] = None,
    ):
        super().__init__()
        downsampling_ratios = downsampling_ratios or [2, 4, 4, 6, 10]
        channel_multiples = channel_multiples or [1, 2, 4, 8, 16]
        # Prepend 1 so that entry i is block i's input multiple and entry
        # i + 1 its output multiple.
        channel_multiples = [1] + channel_multiples

        self.conv1 = weight_norm(nn.Conv1d(audio_channels, encoder_hidden_size, kernel_size=7, padding=3))

        self.block = nn.ModuleList()
        for stride_index, stride in enumerate(downsampling_ratios):
            self.block.append(
                OobleckEncoderBlock(
                    input_dim=encoder_hidden_size * channel_multiples[stride_index],
                    output_dim=encoder_hidden_size * channel_multiples[stride_index + 1],
                    stride=stride,
                )
            )

        d_model = encoder_hidden_size * channel_multiples[-1]
        self.snake1 = Snake1d(d_model)
        self.conv2 = weight_norm(nn.Conv1d(d_model, encoder_hidden_size, kernel_size=3, padding=1))

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Map a waveform to ``encoder_hidden_size`` output channels."""
        hidden_state = self.conv1(hidden_state)
        for block in self.block:
            hidden_state = block(hidden_state)
        hidden_state = self.snake1(hidden_state)
        return self.conv2(hidden_state)


class OobleckDecoder(nn.Module):
    """Full decoder: latent → audio waveform [B, audio_channels, T].

    conv1 → [blocks] → snake1 → conv2(no bias)

    Args:
        channels: Base channel width.
        input_channels: Channels of the incoming latent.
        audio_channels: Channels of the produced waveform.
        upsampling_ratios: Stride of each decoder block, in order.
            A falsy value selects ``[10, 6, 4, 4, 2]``.
        channel_multiples: Multipliers on ``channels``; consumed from the last
            entry towards the first. A falsy value selects ``[1, 2, 4, 8, 16]``.
    """

    def __init__(
        self,
        channels: int = 128,
        input_channels: int = 64,
        audio_channels: int = 2,
        upsampling_ratios: Optional[list] = None,
        channel_multiples: Optional[list] = None,
    ):
        super().__init__()
        upsampling_ratios = upsampling_ratios or [10, 6, 4, 4, 2]
        channel_multiples = channel_multiples or [1, 2, 4, 8, 16]
        channel_multiples = [1] + channel_multiples

        self.conv1 = weight_norm(nn.Conv1d(input_channels, channels * channel_multiples[-1], kernel_size=7, padding=3))

        self.block = nn.ModuleList()
        for stride_index, stride in enumerate(upsampling_ratios):
            # Walk the multiples backwards: the widest layer comes first.
            self.block.append(
                OobleckDecoderBlock(
                    input_dim=channels * channel_multiples[len(upsampling_ratios) - stride_index],
                    output_dim=channels * channel_multiples[len(upsampling_ratios) - stride_index - 1],
                    stride=stride,
                )
            )

        self.snake1 = Snake1d(channels)
        # conv2 has no bias (matches checkpoint: only weight_g/weight_v, no bias key)
        self.conv2 = weight_norm(nn.Conv1d(channels, audio_channels, kernel_size=7, padding=3, bias=False))

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Map a latent to a waveform with ``audio_channels`` channels."""
        hidden_state = self.conv1(hidden_state)
        for block in self.block:
            hidden_state = block(hidden_state)
        hidden_state = self.snake1(hidden_state)
        return self.conv2(hidden_state)


class OobleckDiagonalGaussianDistribution:
    """Diagonal Gaussian parameterised by a tensor split in two along dim 1.

    The first half of ``parameters`` is the mean; the second half is a raw
    scale that is turned into a standard deviation via ``softplus(scale) + 1e-4``.

    Args:
        parameters: Tensor whose dim 1 holds mean and scale back to back.
        deterministic: When True, ``kl`` returns zero. ``sample`` does not
            consult this flag.
    """

    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
        self.parameters = parameters
        self.mean, self.scale = parameters.chunk(2, dim=1)
        self.std = nn.functional.softplus(self.scale) + 1e-4
        self.var = self.std * self.std
        self.logvar = torch.log(self.var)
        self.deterministic = deterministic

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
        """Draw ``mean + std * eps`` with ``eps`` standard normal.

        Args:
            generator: Optional RNG passed to ``torch.randn``.
        """
        # make sure sample is on the same device as the parameters and has same dtype
        noise = torch.randn(
            self.mean.shape,
            generator=generator,
            device=self.parameters.device,
            dtype=self.parameters.dtype,
        )
        return self.mean + self.std * noise

    def kl(self, other: Optional["OobleckDiagonalGaussianDistribution"] = None) -> torch.Tensor:
        """KL-style divergence from ``other``, summed over dim 1 then averaged.

        Args:
            other: Reference distribution. ``None`` means a standard normal.

        Returns:
            A scalar tensor, or a one-element zero tensor when
            ``deterministic`` is set. No 1/2 factor is applied in either
            branch.
        """
        if self.deterministic:
            # Keep the historical shape [1] but follow the parameters'
            # device and dtype instead of always returning CPU float32.
            return torch.zeros(1, device=self.parameters.device, dtype=self.parameters.dtype)

        if other is None:
            return (self.mean * self.mean + self.var - self.logvar - 1.0).sum(1).mean()

        normalized_diff = torch.pow(self.mean - other.mean, 2) / other.var
        var_ratio = self.var / other.var
        # log(var_other / var_self): this sign makes the result agree with the
        # `other is None` branch when `other` is a standard normal, and keeps
        # every element non-negative.
        logvar_diff = other.logvar - self.logvar

        kl = normalized_diff + var_ratio + logvar_diff - 1
        return kl.sum(1).mean()


class AceStepVAE(nn.Module):
    """Audio VAE for ACE-Step (AutoencoderOobleck architecture).

    Encodes audio waveform → latent, decodes latent → audio waveform.
    Uses Snake1d activations and weight-normalized convolutions.

    Args:
        encoder_hidden_size: Base width and output channel count of the encoder.
        downsampling_ratios: Encoder strides; the decoder uses them reversed.
            A falsy value selects ``[2, 4, 4, 6, 10]``.
        channel_multiples: Channel multipliers shared by encoder and decoder.
            A falsy value selects ``[1, 2, 4, 8, 16]``.
        decoder_channels: Base width of the decoder.
        decoder_input_channels: Channel count the decoder expects as input.
        audio_channels: Channels of the waveform.
        sampling_rate: Stored on the module; not used by the code in this file.
    """

    def __init__(
        self,
        encoder_hidden_size: int = 128,
        downsampling_ratios: Optional[list] = None,
        channel_multiples: Optional[list] = None,
        decoder_channels: int = 128,
        decoder_input_channels: int = 64,
        audio_channels: int = 2,
        sampling_rate: int = 48000,
    ):
        super().__init__()
        downsampling_ratios = downsampling_ratios or [2, 4, 4, 6, 10]
        channel_multiples = channel_multiples or [1, 2, 4, 8, 16]
        upsampling_ratios = downsampling_ratios[::-1]

        self.encoder = OobleckEncoder(
            encoder_hidden_size=encoder_hidden_size,
            audio_channels=audio_channels,
            downsampling_ratios=downsampling_ratios,
            channel_multiples=channel_multiples,
        )
        self.decoder = OobleckDecoder(
            channels=decoder_channels,
            input_channels=decoder_input_channels,
            audio_channels=audio_channels,
            upsampling_ratios=upsampling_ratios,
            channel_multiples=channel_multiples,
        )
        self.sampling_rate = sampling_rate

    def encode(self, x: torch.Tensor, generator: Optional[torch.Generator] = None) -> torch.Tensor:
        """Audio waveform [B, audio_channels, T] → sampled latent.

        The encoder output is split in half along dim 1 into mean and scale,
        so the latent has half of the encoder's output channels (64 with the
        default ``encoder_hidden_size`` of 128, which matches the default
        ``decoder_input_channels``).

        Args:
            x: Input waveform.
            generator: Optional RNG for reproducible sampling.
        """
        h = self.encoder(x)
        return OobleckDiagonalGaussianDistribution(h).sample(generator=generator)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Latent [B, decoder_input_channels, T] → audio waveform [B, audio_channels, T']."""
        return self.decoder(z)

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        """Full round-trip: encode → decode."""
        z = self.encode(sample)
        return self.decode(z)

    def remove_weight_norm(self):
        """Remove weight normalization from all conv layers (for export/inference).

        Safe to call more than once: layers that no longer carry the
        ``weight_g`` parameter are skipped.
        """
        for module in self.modules():
            if not isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)):
                continue
            # The bare name below resolves to the module-level function
            # imported from torch.nn.utils, not to this method.
            if hasattr(module, "weight_g"):
                remove_weight_norm(module)