import math
from typing import List, Union

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from torch.nn.utils import weight_norm


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


# Scripting this brings a 1.4x model speedup
@torch.jit.script
def snake(x, alpha):
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x


class Snake1d(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)


class VectorQuantize(nn.Module):
    """
    Implementation of VQ similar to Karpathy's repo:
    https://github.com/karpathy/deep-vector-quantization

    Additionally uses the following tricks from Improved VQGAN
    (https://arxiv.org/pdf/2110.04627.pdf):
        1. Factorized codes: Perform nearest-neighbor lookup in a
           low-dimensional space for improved codebook usage
        2. l2-normalized codes: Converts euclidean distance to cosine
           similarity, which improves training stability
    """

    def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
        self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
        self.codebook = nn.Embedding(codebook_size, codebook_dim)

    def forward(self, z):
        """Quantizes the input tensor using a fixed codebook and returns
        the corresponding codebook vectors.

        Parameters
        ----------
        z : Tensor[B x D x T]

        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input
        Tensor[1]
            Commitment loss to train encoder to predict vectors closer to codebook entries
        Tensor[1]
            Codebook loss to update the codebook
        Tensor[B x T]
            Codebook indices (quantized discrete representation of input)
        Tensor[B x D x T]
            Projected latents (continuous representation of input before quantization)
        """
        # Factorized codes (ViT-VQGAN): project input into low-dimensional space
        z_e = self.in_proj(z)  # z_e : (B x D x T)
        z_q, indices = self.decode_latents(z_e)

        commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
        codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])

        # No-op in the forward pass; straight-through gradient estimator
        # in the backward pass
        z_q = z_e + (z_q - z_e).detach()

        z_q = self.out_proj(z_q)

        return z_q, commitment_loss, codebook_loss, indices, z_e

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code(self, embed_id):
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_latents(self, latents):
        encodings = rearrange(latents, "b d t -> (b t) d")
        codebook = self.codebook.weight  # codebook: (N x D)

        # L2 normalize encodings and codebook (ViT-VQGAN)
        encodings = F.normalize(encodings)
        codebook = F.normalize(codebook)

        # Compute euclidean distance with codebook
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )
        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
        z_q = self.decode_code(indices)
        return z_q, indices
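

# Minimal usage sketch (hypothetical helper, not part of the module's API):
# exercises `VectorQuantize` on an arbitrary [B x D x T] input and checks the
# shapes documented in `forward`. All dimensions below are made up.
def _demo_vector_quantize():
    vq = VectorQuantize(input_dim=64, codebook_size=512, codebook_dim=8)
    z = torch.randn(2, 64, 100)  # B=2, D=64, T=100
    z_q, commitment_loss, codebook_loss, indices, z_e = vq(z)
    assert z_q.shape == z.shape        # quantized output keeps the input shape
    assert indices.shape == (2, 100)   # one code index per timestep
    assert z_e.shape == (2, 8, 100)    # projected latents live in codebook_dim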


class ResidualVectorQuantize(nn.Module):
    """
    Introduced in SoundStream: An end-to-end neural audio codec
    https://arxiv.org/abs/2107.03312
    """

    def __init__(
        self,
        input_dim: int = 512,
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: float = 0.0,
    ):
        super().__init__()
        if isinstance(codebook_dim, int):
            codebook_dim = [codebook_dim for _ in range(n_codebooks)]

        self.n_codebooks = n_codebooks
        self.codebook_dim = codebook_dim
        self.codebook_size = codebook_size

        self.quantizers = nn.ModuleList(
            [
                VectorQuantize(input_dim, codebook_size, codebook_dim[i])
                for i in range(n_codebooks)
            ]
        )
        self.quantizer_dropout = quantizer_dropout

    def forward(self, z, n_quantizers: int = None):
        """Quantizes the input tensor using a fixed set of `n` codebooks and
        returns the corresponding codebook vectors.

        Parameters
        ----------
        z : Tensor[B x D x T]
        n_quantizers : int, optional
            Number of quantizers to use (n_quantizers < self.n_codebooks,
            e.g. for quantizer dropout). Note: if `self.quantizer_dropout`
            is True, this argument is ignored in training mode and a random
            number of quantizers is used.

        Returns
        -------
        z_q : Tensor[B x D x T]
            Quantized continuous representation of input
        codes : Tensor[B x N x T]
            Codebook indices for each codebook
            (quantized discrete representation of input)
        latents : Tensor[B x N*D x T]
            Projected latents (continuous representation of input before quantization)
        commitment_loss : Tensor[1]
            Commitment loss to train encoder to predict vectors closer to codebook entries
        codebook_loss : Tensor[1]
            Codebook loss to update the codebook
        """
        z_q = 0
        residual = z
        commitment_loss = 0
        codebook_loss = 0

        codebook_indices = []
        latents = []

        if n_quantizers is None:
            n_quantizers = self.n_codebooks
        if self.training:
            n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
            dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
            n_dropout = int(z.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(z.device)

        for i, quantizer in enumerate(self.quantizers):
            if self.training is False and i >= n_quantizers:
                break

            z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
                residual
            )

            # Create mask to apply quantizer dropout
            mask = (
                torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
            )
            z_q = z_q + z_q_i * mask[:, None, None]
            residual = residual - z_q_i

            # Sum losses
            commitment_loss += (commitment_loss_i * mask).mean()
            codebook_loss += (codebook_loss_i * mask).mean()

            codebook_indices.append(indices_i)
            latents.append(z_e_i)

        codes = torch.stack(codebook_indices, dim=1)
        latents = torch.cat(latents, dim=1)

        return z_q, codes, latents, commitment_loss, codebook_loss

    def from_codes(self, codes: torch.Tensor):
        """Given the quantized codes, reconstruct the continuous representation.

        Parameters
        ----------
        codes : Tensor[B x N x T]
            Quantized discrete representation of input

        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input
        """
        z_q = 0.0
        z_p = []
        n_codebooks = codes.shape[1]
        for i in range(n_codebooks):
            z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
            z_p.append(z_p_i)

            z_q_i = self.quantizers[i].out_proj(z_p_i)
            z_q = z_q + z_q_i
        return z_q, torch.cat(z_p, dim=1), codes

    def from_latents(self, latents: torch.Tensor):
        """Given the unquantized latents, reconstruct the continuous
        representation after quantization.

        Parameters
        ----------
        latents : Tensor[B x N*D x T]
            Continuous representation of input after projection

        Returns
        -------
        Tensor[B x D x T]
            Quantized representation of the full projected space
        Tensor[B x N*D x T]
            Quantized representation of the latent space
        """
        z_q = 0
        z_p = []
        codes = []
        dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])

        # Use as many quantizers as fit in the channel dimension of `latents`
        n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
            0
        ]
        for i in range(n_codebooks):
            j, k = dims[i], dims[i + 1]
            z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
            z_p.append(z_p_i)
            codes.append(codes_i)

            z_q_i = self.quantizers[i].out_proj(z_p_i)
            z_q = z_q + z_q_i

        return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
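

# Minimal round-trip sketch (hypothetical helper, not part of the module's
# API): quantize with `ResidualVectorQuantize`, then rebuild the quantized
# representation from the discrete codes alone. Values should agree up to the
# float rounding introduced by the straight-through estimator.
def _demo_residual_vq_roundtrip():
    rvq = ResidualVectorQuantize(
        input_dim=512, n_codebooks=9, codebook_size=1024, codebook_dim=8
    )
    rvq.eval()  # use all quantizers deterministically
    z = torch.randn(2, 512, 50)
    z_q, codes, latents, commitment_loss, codebook_loss = rvq(z)
    assert codes.shape == (2, 9, 50)        # one index per codebook per timestep
    assert latents.shape == (2, 9 * 8, 50)  # concatenated low-dim projections
    z_q_rebuilt, _, _ = rvq.from_codes(codes)
    assert torch.allclose(z_q, z_q_rebuilt, atol=1e-5)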


class AbstractDistribution:
    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()


class DiracDistribution(AbstractDistribution):
    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def mode(self):
        return self.value


class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(
                device=self.parameters.device
            )

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(
            device=self.parameters.device
        )
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.mean(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=[1, 2],
                )
            else:
                return 0.5 * torch.mean(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=[1, 2],
                )

    def nll(self, sample, dims=[1, 2]):
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        return self.mean
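

# Minimal sketch of the continuous-mode posterior (hypothetical helper, not
# part of the module's API): `parameters` packs mean and log-variance along
# the channel axis, so the channel count must be even.
def _demo_diagonal_gaussian():
    params = torch.randn(2, 2 * 128, 50)  # [B x 2D x T] -> mean, logvar
    posterior = DiagonalGaussianDistribution(params)
    z = posterior.sample()                # reparameterized: mean + std * eps
    assert z.shape == (2, 128, 50)
    kl = posterior.kl()                   # KL against a standard normal
    assert kl.shape == (2,)               # one value per batch element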


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, torch.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = [
        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )


def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)


class ResidualUnit(nn.Module):
    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x):
        y = self.block(x)
        pad = (x.shape[-1] - y.shape[-1]) // 2
        if pad > 0:
            x = x[..., pad:-pad]
        return x + y


class EncoderBlock(nn.Module):
    def __init__(self, dim: int = 16, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            ResidualUnit(dim // 2, dilation=1),
            ResidualUnit(dim // 2, dilation=3),
            ResidualUnit(dim // 2, dilation=9),
            Snake1d(dim // 2),
            WNConv1d(
                dim // 2,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
        )

    def forward(self, x):
        return self.block(x)


class Encoder(nn.Module):
    def __init__(
        self,
        d_model: int = 64,
        strides: list = [2, 4, 8, 8],
        d_latent: int = 64,
    ):
        super().__init__()
        # Create first convolution
        self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]

        # Create EncoderBlocks that double channels as they downsample by `stride`
        for stride in strides:
            d_model *= 2
            self.block += [EncoderBlock(d_model, stride=stride)]

        # Create last convolution
        self.block += [
            Snake1d(d_model),
            WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
        ]

        # Wrap block into nn.Sequential
        self.block = nn.Sequential(*self.block)
        self.enc_dim = d_model

    def forward(self, x):
        return self.block(x)


class DecoderBlock(nn.Module):
    def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            Snake1d(input_dim),
            WNConvTranspose1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                output_padding=stride % 2,
            ),
            ResidualUnit(output_dim, dilation=1),
            ResidualUnit(output_dim, dilation=3),
            ResidualUnit(output_dim, dilation=9),
        )

    def forward(self, x):
        return self.block(x)


class Decoder(nn.Module):
    def __init__(
        self,
        input_channel,
        channels,
        rates,
        d_out: int = 1,
    ):
        super().__init__()

        # Add first conv layer
        layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]

        # Add upsampling + MRF blocks
        for i, stride in enumerate(rates):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            layers += [DecoderBlock(input_dim, output_dim, stride)]

        # Add final conv layer
        layers += [
            Snake1d(output_dim),
            WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
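

# Minimal shape sketch (hypothetical helper, not part of the module's API):
# the encoder downsamples by the product of its strides, and a mirrored
# decoder restores the original length. All dimensions below are made up.
def _demo_encoder_decoder_shapes():
    strides = [2, 4, 8, 8]
    hop = int(np.prod(strides))          # 512 input samples per latent frame
    encoder = Encoder(d_model=64, strides=strides, d_latent=64)
    decoder = Decoder(input_channel=64, channels=1024, rates=strides[::-1])
    audio = torch.randn(1, 1, hop * 10)  # length chosen as a multiple of hop
    z = encoder(audio)
    assert z.shape == (1, 64, 10)        # one latent frame per hop
    recon = decoder(z)
    assert recon.shape == audio.shape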


class DacVAE(nn.Module):
    def __init__(
        self,
        encoder_dim: int = 128,
        encoder_rates: List[int] = [2, 3, 4, 5, 8],
        latent_dim: int = 128,
        decoder_dim: int = 2048,
        decoder_rates: List[int] = [8, 5, 4, 3, 2],
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: bool = False,
        sample_rate: int = 48000,
        continuous: bool = True,
        use_weight_norm: bool = False,
    ):
        super().__init__()

        self.encoder_dim = encoder_dim
        self.encoder_rates = encoder_rates
        self.decoder_dim = decoder_dim
        self.decoder_rates = decoder_rates
        self.sample_rate = sample_rate
        self.continuous = continuous
        self.use_weight_norm = use_weight_norm

        if latent_dim is None:
            latent_dim = encoder_dim * (2 ** len(encoder_rates))

        self.latent_dim = latent_dim

        self.hop_length = np.prod(encoder_rates)
        self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)

        if not continuous:
            self.n_codebooks = n_codebooks
            self.codebook_size = codebook_size
            self.codebook_dim = codebook_dim
            self.quantizer = ResidualVectorQuantize(
                input_dim=latent_dim,
                n_codebooks=n_codebooks,
                codebook_size=codebook_size,
                codebook_dim=codebook_dim,
                quantizer_dropout=quantizer_dropout,
            )
        else:
            self.quant_conv = torch.nn.Conv1d(latent_dim, 2 * latent_dim, 1)
            self.post_quant_conv = torch.nn.Conv1d(latent_dim, latent_dim, 1)

        self.decoder = Decoder(
            latent_dim,
            decoder_dim,
            decoder_rates,
        )
        self.apply(init_weights)

        self.delay = self.get_delay()
        if not self.use_weight_norm:
            self.remove_weight_norm()

    def get_delay(self):
        # Any number works here; the delay is invariant to the input length
        l_out = self.get_output_length(0)
        L = l_out

        layers = []
        for layer in self.modules():
            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
                layers.append(layer)

        # Walk backwards through the network to invert the length arithmetic
        for layer in reversed(layers):
            d = layer.dilation[0]
            k = layer.kernel_size[0]
            s = layer.stride[0]

            if isinstance(layer, nn.ConvTranspose1d):
                L = ((L - d * (k - 1) - 1) / s) + 1
            elif isinstance(layer, nn.Conv1d):
                L = (L - 1) * s + d * (k - 1) + 1

            L = math.ceil(L)

        l_in = L

        return (l_in - l_out) // 2

    def get_output_length(self, input_length):
        L = input_length
        # Calculate output length
        for layer in self.modules():
            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
                d = layer.dilation[0]
                k = layer.kernel_size[0]
                s = layer.stride[0]

                if isinstance(layer, nn.Conv1d):
                    L = ((L - d * (k - 1) - 1) / s) + 1
                elif isinstance(layer, nn.ConvTranspose1d):
                    L = (L - 1) * s + d * (k - 1) + 1

                L = math.floor(L)
        return L

    @property
    def dtype(self):
        """Get the dtype of the model parameters."""
        # Return the dtype of the first parameter found
        for param in self.parameters():
            return param.dtype
        return torch.float32  # fallback

    @property
    def device(self):
        """Get the device of the model parameters."""
        # Return the device of the first parameter found
        for param in self.parameters():
            return param.device
        return torch.device("cpu")  # fallback

    def preprocess(self, audio_data, sample_rate):
        if sample_rate is None:
            sample_rate = self.sample_rate
        assert sample_rate == self.sample_rate

        # Right-pad so the length is a multiple of the hop length
        length = audio_data.shape[-1]
        right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
        audio_data = nn.functional.pad(audio_data, (0, right_pad))

        return audio_data

    def encode(
        self,
        audio_data: torch.Tensor,
        n_quantizers: int = None,
    ):
        """Encode given audio data and return quantized latent codes.

        Parameters
        ----------
        audio_data : Tensor[B x 1 x T]
            Audio data to encode
        n_quantizers : int, optional
            Number of quantizers to use, by default None
            If None, all quantizers are used.

        Returns
        -------
        z : Tensor[B x D x T] or DiagonalGaussianDistribution
            Quantized continuous representation of input (discrete mode), or
            the posterior distribution over latents (continuous mode)
        codes : Tensor[B x N x T]
            Codebook indices for each codebook (quantized discrete
            representation of input; None in continuous mode)
        latents : Tensor[B x N*D x T]
            Projected latents (continuous representation of input before
            quantization; None in continuous mode)
        commitment_loss : Tensor[1]
            Commitment loss to train encoder to predict vectors closer to
            codebook entries (0 in continuous mode)
        codebook_loss : Tensor[1]
            Codebook loss to update the codebook (0 in continuous mode)
        """
        z = self.encoder(audio_data)  # [B x D x T]
        if not self.continuous:
            z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
                z, n_quantizers
            )
        else:
            z = self.quant_conv(z)  # [B x 2D x T]
            z = DiagonalGaussianDistribution(z)
            codes, latents, commitment_loss, codebook_loss = None, None, 0, 0
        return z, codes, latents, commitment_loss, codebook_loss

    def decode(self, z: torch.Tensor):
        """Decode given latent codes and return audio data.

        Parameters
        ----------
        z : Tensor[B x D x T]
            Quantized continuous representation of input

        Returns
        -------
        Tensor[B x 1 x T]
            Decoded audio data
        """
        if not self.continuous:
            audio = self.decoder(z)
        else:
            z = self.post_quant_conv(z)
            audio = self.decoder(z)
        return audio

    def forward(
        self,
        audio_data: torch.Tensor,
        sample_rate: int = None,
        n_quantizers: int = None,
    ):
        """Model forward pass.

        Parameters
        ----------
        audio_data : Tensor[B x 1 x T]
            Audio data to encode
        sample_rate : int, optional
            Sample rate of audio data in Hz, by default None
            If None, defaults to `self.sample_rate`
        n_quantizers : int, optional
            Number of quantizers to use, by default None.
            If None, all quantizers are used.

        Returns
        -------
        dict
            In discrete mode, a dictionary with the following keys:

            "audio" : Tensor[B x 1 x length]
                Decoded audio data
            "z" : Tensor[B x D x T]
                Quantized continuous representation of input
            "codes" : Tensor[B x N x T]
                Codebook indices for each codebook
                (quantized discrete representation of input)
            "latents" : Tensor[B x N*D x T]
                Projected latents (continuous representation of input before quantization)
            "vq/commitment_loss" : Tensor[1]
                Commitment loss to train encoder to predict vectors closer to codebook entries
            "vq/codebook_loss" : Tensor[1]
                Codebook loss to update the codebook

            In continuous mode, a dictionary with the keys "audio", "z",
            and "kl_loss".
        """
        length = audio_data.shape[-1]
        audio_data = self.preprocess(audio_data, sample_rate)
        if not self.continuous:
            z, codes, latents, commitment_loss, codebook_loss = self.encode(
                audio_data, n_quantizers
            )
            x = self.decode(z)
            return {
                "audio": x[..., :length],
                "z": z,
                "codes": codes,
                "latents": latents,
                "vq/commitment_loss": commitment_loss,
                "vq/codebook_loss": codebook_loss,
            }
        else:
            posterior, _, _, _, _ = self.encode(audio_data, n_quantizers)
            z = posterior.sample()
            x = self.decode(z)
            kl_loss = posterior.kl().mean()
            return {
                "audio": x[..., :length],
                "z": z,
                "kl_loss": kl_loss,
            }

    def remove_weight_norm(self):
        """Remove weight_norm from all modules in the model.

        This fuses the weight_g and weight_v parameters into a single weight
        parameter, and should be called before inference for better
        performance.

        Returns
        -------
        self
            The model with weight_norm removed
        """
        from torch.nn.utils import remove_weight_norm

        num_removed = 0
        for name, module in list(self.named_modules()):
            if hasattr(module, "_forward_pre_hooks"):
                for hook_id, hook in list(module._forward_pre_hooks.items()):
                    if "WeightNorm" in str(type(hook)):
                        try:
                            remove_weight_norm(module)
                            num_removed += 1
                            # print(f"Removed weight_norm from: {name}")
                        except ValueError as e:
                            print(f"Failed to remove weight_norm from {name}: {e}")

        if num_removed > 0:
            # print(f"Successfully removed weight_norm from {num_removed} modules")
            self.use_weight_norm = False
        else:
            print("No weight_norm found in the model")

        return self
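

# Minimal end-to-end smoke test of the continuous (VAE) path, assuming a
# hypothetical small configuration so it runs quickly; the defaults above
# define the full-size model. One second of silence is arbitrary demo input.
if __name__ == "__main__":
    model = DacVAE(
        encoder_dim=32,
        encoder_rates=[2, 4, 8, 8],
        latent_dim=32,
        decoder_dim=512,
        decoder_rates=[8, 8, 4, 2],
        sample_rate=48000,
    )
    model.eval()
    audio = torch.zeros(1, 1, model.sample_rate)  # [B x 1 x T]
    with torch.no_grad():
        out = model(audio)
    print(out["audio"].shape, out["kl_loss"].item())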