mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
* support mova inference * mova media_io * add unified audio_video api & fix bug of mono audio input for ltx * support mova train * mova docs * fix bug
797 lines
26 KiB
Python
797 lines
26 KiB
Python
import math
|
|
from typing import List, Union
|
|
import numpy as np
|
|
import torch
|
|
from torch import nn
|
|
from torch.nn.utils import weight_norm
|
|
import torch.nn.functional as F
|
|
from einops import rearrange
|
|
|
|
def WNConv1d(*args, **kwargs):
    """Build an ``nn.Conv1d`` and wrap it with weight normalization.

    All positional and keyword arguments are forwarded unchanged to
    ``nn.Conv1d``.
    """
    conv = nn.Conv1d(*args, **kwargs)
    return weight_norm(conv)
|
|
|
|
|
|
def WNConvTranspose1d(*args, **kwargs):
    """Build an ``nn.ConvTranspose1d`` and wrap it with weight normalization.

    All positional and keyword arguments are forwarded unchanged to
    ``nn.ConvTranspose1d``.
    """
    deconv = nn.ConvTranspose1d(*args, **kwargs)
    return weight_norm(deconv)
|
|
|
|
|
|
# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
    """Apply the snake activation ``x + sin(alpha * x)^2 / (alpha + 1e-9)``.

    ``x`` is flattened to three dimensions (the first two are kept, the rest
    are merged) so that ``alpha`` can broadcast against it, and the result is
    reshaped back to the original shape of ``x``.
    """
    original_shape = x.shape
    flat = x.reshape(original_shape[0], original_shape[1], -1)
    # The 1e-9 keeps the reciprocal finite when alpha is exactly zero.
    inv_alpha = (alpha + 1e-9).reciprocal()
    flat = flat + inv_alpha * torch.sin(alpha * flat).pow(2)
    return flat.reshape(original_shape)
|
|
|
|
|
|
class Snake1d(nn.Module):
    """Snake activation with one learnable ``alpha`` per channel."""

    def __init__(self, channels):
        super().__init__()
        # Shape (1, channels, 1) so alpha broadcasts over the first and last dims.
        initial_alpha = torch.ones((1, channels, 1))
        self.alpha = nn.Parameter(initial_alpha)

    def forward(self, x):
        """Return ``snake(x, self.alpha)``."""
        activated = snake(x, self.alpha)
        return activated
|
|
|
|
|
|
class VectorQuantize(nn.Module):
    """
    Implementation of VQ similar to Karpathy's repo:
    https://github.com/karpathy/deep-vector-quantization
    Additionally uses following tricks from Improved VQGAN
    (https://arxiv.org/pdf/2110.04627.pdf):
        1. Factorized codes: the nearest-neighbour lookup is done in a
           low-dimensional projected space for improved codebook usage.
        2. l2-normalized codes: both encodings and codebook are normalized
           before the distance is computed, which improves training stability.
    """

    def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        # 1x1 convolutions into and out of the low-dimensional code space.
        self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
        self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
        self.codebook = nn.Embedding(codebook_size, codebook_dim)

    def forward(self, z):
        """Quantize ``z`` against the codebook.

        Parameters
        ----------
        z : Tensor[B x D x T]

        Returns
        -------
        Tensor[B x D x T]
            Quantized representation, projected back to the input dimension
        Tensor[B]
            Commitment loss (MSE reduced over dims 1 and 2, one value per item)
        Tensor[B]
            Codebook loss (MSE reduced over dims 1 and 2, one value per item)
        Tensor[B x T]
            Codebook indices (quantized discrete representation of input)
        Tensor[B x D x T]
            Projected latents (continuous representation before quantization)
        """
        # Factorized codes (ViT-VQGAN): project input into low-dimensional space
        projected = self.in_proj(z)
        quantized, indices = self.decode_latents(projected)

        commitment_loss = F.mse_loss(
            projected, quantized.detach(), reduction="none"
        ).mean([1, 2])
        codebook_loss = F.mse_loss(
            quantized, projected.detach(), reduction="none"
        ).mean([1, 2])

        # No-op in the forward pass; straight-through gradient estimator in
        # the backward pass.
        quantized = projected + (quantized - projected).detach()
        quantized = self.out_proj(quantized)

        return quantized, commitment_loss, codebook_loss, indices, projected

    def embed_code(self, embed_id):
        """Look up codebook vectors for the given indices."""
        weight = self.codebook.weight
        return F.embedding(embed_id, weight)

    def decode_code(self, embed_id):
        """Look up codebook vectors and swap dims 1 and 2 of the result."""
        embedded = self.embed_code(embed_id)
        return embedded.transpose(1, 2)

    def decode_latents(self, latents):
        """Return the nearest codebook vectors and their indices for ``latents``."""
        batch = latents.size(0)
        flat = rearrange(latents, "b d t -> (b t) d")

        # L2 normalize encodings and codebook (ViT-VQGAN)
        flat = F.normalize(flat)
        codes = F.normalize(self.codebook.weight)  # (N x D)

        # Squared euclidean distance to every codebook entry
        flat_sq = flat.pow(2).sum(1, keepdim=True)
        cross = (2 * flat) @ codes.t()
        codes_sq = codes.pow(2).sum(1, keepdim=True).t()
        dist = flat_sq - cross + codes_sq

        nearest = (-dist).max(1)[1]
        indices = rearrange(nearest, "(b t) -> b t", b=batch)
        z_q = self.decode_code(indices)
        return z_q, indices
|
|
|
|
|
|
class ResidualVectorQuantize(nn.Module):
    """
    Introduced in SoundStream: An end2end neural audio codec
    https://arxiv.org/abs/2107.03312
    """

    def __init__(
        self,
        input_dim: int = 512,
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: float = 0.0,
    ):
        super().__init__()
        # A single int means "use the same code dimension for every codebook".
        if isinstance(codebook_dim, int):
            codebook_dim = [codebook_dim] * n_codebooks

        self.n_codebooks = n_codebooks
        self.codebook_dim = codebook_dim
        self.codebook_size = codebook_size

        stages = [
            VectorQuantize(input_dim, codebook_size, codebook_dim[idx])
            for idx in range(n_codebooks)
        ]
        self.quantizers = nn.ModuleList(stages)
        self.quantizer_dropout = quantizer_dropout

    def forward(self, z, n_quantizers: int = None):
        """Quantize ``z`` with the stack of codebooks, each one working on the
        residual left by the previous ones.

        Parameters
        ----------
        z : Tensor[B x D x T]
        n_quantizers : int, optional
            No. of quantizers to use
            (n_quantizers < self.n_codebooks ex: for quantizer dropout)
            Note: in training mode this argument is ignored and a per-item
            count is drawn instead (see the body below).

        Returns
        -------
        tuple
            ``(z_q, codes, latents, commitment_loss, codebook_loss)``:

            z_q : Tensor[B x D x T]
                Quantized continuous representation of input
            codes : Tensor[B x N x T]
                Codebook indices for each codebook
                (quantized discrete representation of input)
            latents : Tensor[B x N*D x T]
                Projected latents (continuous representation of input before
                quantization)
            commitment_loss
                Sum over codebooks of the masked, batch-averaged commitment loss
            codebook_loss
                Sum over codebooks of the masked, batch-averaged codebook loss
        """
        batch = z.shape[0]
        z_q = 0
        residual = z
        commitment_loss = 0
        codebook_loss = 0

        codebook_indices = []
        latents = []

        if n_quantizers is None:
            n_quantizers = self.n_codebooks
        if self.training:
            # Default every item to n_codebooks + 1 (i.e. all codebooks active),
            # then give the first `n_dropout` items a random count in
            # [1, n_codebooks].
            n_quantizers = torch.ones((batch,)) * self.n_codebooks + 1
            dropout = torch.randint(1, self.n_codebooks + 1, (batch,))
            n_dropout = int(batch * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(z.device)

        for level, quantizer in enumerate(self.quantizers):
            # Outside training, stop once the requested number of codebooks
            # has been applied.
            if not self.training and level >= n_quantizers:
                break

            z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
                residual
            )

            # Mask selecting the batch items for which this codebook is active
            level_ids = torch.full((batch,), fill_value=level, device=z.device)
            mask = level_ids < n_quantizers
            z_q = z_q + z_q_i * mask[:, None, None]
            residual = residual - z_q_i

            # Sum losses
            commitment_loss += (commitment_loss_i * mask).mean()
            codebook_loss += (codebook_loss_i * mask).mean()

            codebook_indices.append(indices_i)
            latents.append(z_e_i)

        codes = torch.stack(codebook_indices, dim=1)
        latents = torch.cat(latents, dim=1)

        return z_q, codes, latents, commitment_loss, codebook_loss

    def from_codes(self, codes: torch.Tensor):
        """Given the quantized codes, reconstruct the continuous representation

        Parameters
        ----------
        codes : Tensor[B x N x T]
            Quantized discrete representation of input

        Returns
        -------
        tuple
            ``(z_q, z_p, codes)``: the summed output projections of all used
            codebooks, the per-codebook code vectors concatenated along dim 1,
            and the input ``codes`` unchanged.
        """
        z_q = 0.0
        per_codebook = []
        for idx in range(codes.shape[1]):
            quantizer = self.quantizers[idx]
            z_p_i = quantizer.decode_code(codes[:, idx, :])
            per_codebook.append(z_p_i)
            z_q = z_q + quantizer.out_proj(z_p_i)
        return z_q, torch.cat(per_codebook, dim=1), codes

    def from_latents(self, latents: torch.Tensor):
        """Given the unquantized latents, reconstruct the
        continuous representation after quantization.

        Parameters
        ----------
        latents : Tensor[B x N x T]
            Continuous representation of input after projection

        Returns
        -------
        tuple
            ``(z_q, z_p, codes)``: the summed output projections, the
            per-codebook quantized latents concatenated along dim 1, and the
            codebook indices stacked along dim 1.
        """
        z_q = 0
        per_codebook = []
        codes = []
        # Channel offsets at which each codebook's slice of `latents` starts.
        dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])

        # Number of codebooks whose slices fit entirely inside `latents`.
        n_codebooks = np.where(dims <= latents.shape[1])[0].max()
        for idx in range(n_codebooks):
            start, stop = dims[idx], dims[idx + 1]
            quantizer = self.quantizers[idx]
            z_p_i, codes_i = quantizer.decode_latents(latents[:, start:stop, :])
            per_codebook.append(z_p_i)
            codes.append(codes_i)
            z_q = z_q + quantizer.out_proj(z_p_i)

        return z_q, torch.cat(per_codebook, dim=1), torch.stack(codes, dim=1)
|
|
|
|
|
|
class AbstractDistribution:
    """Interface for distributions exposing ``sample`` and ``mode``."""

    def sample(self):
        """Draw a sample; subclasses must override."""
        raise NotImplementedError

    def mode(self):
        """Return the mode; subclasses must override."""
        raise NotImplementedError
|
|
|
|
|
|
class DiracDistribution(AbstractDistribution):
    """Distribution whose sample and mode are both a single stored value."""

    def __init__(self, value):
        self.value = value

    def mode(self):
        """Return the stored value."""
        return self.value

    def sample(self):
        """Return the stored value (no randomness involved)."""
        return self.value
|
|
|
|
|
|
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian described by a tensor of stacked mean and log-variance.

    ``parameters`` is split into two equal chunks along dim 1: the first is the
    mean, the second the log-variance (clamped to ``[-30, 20]``).  With
    ``deterministic=True`` the standard deviation and variance are forced to
    zero, so ``sample()`` returns the mean.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # zeros_like already matches the device and dtype of the mean.
            self.var = self.std = torch.zeros_like(self.mean)

    def _zero(self):
        """Return a one-element zero tensor on the mean's device and dtype."""
        return self.mean.new_zeros(1)

    def sample(self):
        """Return ``mean + std * noise`` with standard-normal noise.

        The noise is drawn exactly as before (default dtype, default device)
        and then moved to the mean's device *and* dtype, so half-precision
        latents are no longer silently promoted to float32.
        """
        noise = torch.randn(self.mean.shape).to(
            device=self.mean.device, dtype=self.mean.dtype
        )
        return self.mean + self.std * noise

    def kl(self, other=None):
        """Return the KL divergence, averaged over dims 1 and 2.

        With ``other=None`` the divergence is taken against a standard normal;
        otherwise against ``other`` (which must expose ``mean``, ``var`` and
        ``logvar``).  In deterministic mode a one-element zero tensor is
        returned.
        """
        if self.deterministic:
            return self._zero()
        if other is None:
            return 0.5 * torch.mean(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=[1, 2],
            )
        return 0.5 * torch.mean(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var
            - 1.0
            - self.logvar
            + other.logvar,
            dim=[1, 2],
        )

    def nll(self, sample, dims=(1, 2)):
        """Return the negative log-likelihood of ``sample``, summed over ``dims``.

        In deterministic mode a one-element zero tensor is returned.
        """
        if self.deterministic:
            return self._zero()
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=list(dims),
        )

    def mode(self):
        """Return the mean of the distribution."""
        return self.mean
|
|
|
|
|
|
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    # Use the first tensor argument as the dtype/device reference.
    reference = next(
        (arg for arg in (mean1, logvar1, mean2, logvar2) if isinstance(arg, torch.Tensor)),
        None,
    )
    assert reference is not None, "at least one argument must be a Tensor"

    def _as_tensor(value):
        # Broadcasting converts scalars elsewhere, but torch.exp() needs a Tensor.
        if isinstance(value, torch.Tensor):
            return value
        return torch.tensor(value).to(reference)

    logvar1 = _as_tensor(logvar1)
    logvar2 = _as_tensor(logvar2)

    return 0.5 * (
        -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )
|
|
|
|
|
|
def init_weights(m):
    """Initialise an ``nn.Conv1d``; leave every other module untouched.

    The weight is drawn from a truncated normal with ``std=0.02`` and the bias
    is set to zero.  Convolutions created with ``bias=False`` have
    ``m.bias is None``; the bias step is skipped for them instead of crashing.
    """
    if not isinstance(m, nn.Conv1d):
        return
    nn.init.trunc_normal_(m.weight, std=0.02)
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)
|
|
|
|
|
|
class ResidualUnit(nn.Module):
    """Residual block: Snake -> dilated 7-tap conv -> Snake -> 1x1 conv."""

    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        kernel_size = 7
        padding = ((kernel_size - 1) * dilation) // 2
        layers = [
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=kernel_size, dilation=dilation, padding=padding),
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=1),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x + block(x)``, trimming ``x`` if the block output is shorter."""
        y = self.block(x)
        # If the block shortened the last dim, trim the same amount from each
        # end of the skip input before adding.
        trim = (x.shape[-1] - y.shape[-1]) // 2
        skip = x[..., trim:-trim] if trim > 0 else x
        return skip + y
|
|
|
|
|
|
class EncoderBlock(nn.Module):
    """Three residual units on ``dim // 2`` channels followed by a strided
    convolution that outputs ``dim`` channels."""

    def __init__(self, dim: int = 16, stride: int = 1):
        super().__init__()
        half = dim // 2
        layers = [ResidualUnit(half, dilation=rate) for rate in (1, 3, 9)]
        layers.append(Snake1d(half))
        layers.append(
            WNConv1d(
                half,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            )
        )
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the block."""
        out = self.block(x)
        return out
|
|
|
|
|
|
class Encoder(nn.Module):
    """Convolutional encoder: input conv, one ``EncoderBlock`` per stride, and
    a final Snake + conv that maps to ``d_latent`` channels."""

    def __init__(
        self,
        d_model: int = 64,
        strides: list = [2, 4, 8, 8],
        d_latent: int = 64,
    ):
        super().__init__()
        width = d_model

        # First convolution takes a single input channel.
        layers = [WNConv1d(1, width, kernel_size=7, padding=3)]

        # EncoderBlocks double the channel count as they downsample by `stride`.
        for stride in strides:
            width *= 2
            layers.append(EncoderBlock(width, stride=stride))

        # Last activation + convolution into the latent dimension.
        layers.append(Snake1d(width))
        layers.append(WNConv1d(width, d_latent, kernel_size=3, padding=1))

        # Wrap the layers into nn.Sequential.
        self.block = nn.Sequential(*layers)
        self.enc_dim = width

    def forward(self, x):
        """Run ``x`` through the encoder stack."""
        out = self.block(x)
        return out
|
|
|
|
|
|
class DecoderBlock(nn.Module):
    """Snake + strided transposed convolution followed by three residual units."""

    def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
        super().__init__()
        upsample = WNConvTranspose1d(
            input_dim,
            output_dim,
            kernel_size=2 * stride,
            stride=stride,
            padding=math.ceil(stride / 2),
            output_padding=stride % 2,
        )
        layers = [Snake1d(input_dim), upsample]
        layers.extend(ResidualUnit(output_dim, dilation=rate) for rate in (1, 3, 9))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the block."""
        out = self.block(x)
        return out
|
|
|
|
|
|
class Decoder(nn.Module):
    """Convolutional decoder: input conv, one ``DecoderBlock`` per rate (each
    halving the channel count), then Snake + conv to ``d_out`` channels + Tanh.

    Parameters
    ----------
    input_channel : int
        Number of channels of the input passed to the first convolution.
    channels : int
        Channel count after the first convolution.
    rates : sequence of int
        Stride of each ``DecoderBlock``.  May be empty, in which case no
        ``DecoderBlock`` is added.
    d_out : int
        Number of output channels.
    """

    def __init__(
        self,
        input_channel,
        channels,
        rates,
        d_out: int = 1,
    ):
        super().__init__()

        # Add first conv layer
        layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]

        # With no rates the final layers operate directly on `channels`;
        # previously `output_dim` was unbound in that case (NameError).
        output_dim = channels

        # Add upsampling + MRF blocks
        for i, stride in enumerate(rates):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            layers += [DecoderBlock(input_dim, output_dim, stride)]

        # Add final conv layer
        layers += [
            Snake1d(output_dim),
            WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the decoder stack."""
        return self.model(x)
|
|
|
|
|
|
class DacVAE(nn.Module):
    """Audio autoencoder with either a continuous (diagonal Gaussian) latent
    or a residual-vector-quantized latent, selected by ``continuous``."""

    def __init__(
        self,
        encoder_dim: int = 128,
        encoder_rates: List[int] = [2, 3, 4, 5, 8],
        latent_dim: int = 128,
        decoder_dim: int = 2048,
        decoder_rates: List[int] = [8, 5, 4, 3, 2],
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: bool = False,
        sample_rate: int = 48000,
        continuous: bool = True,
        use_weight_norm: bool = False,
    ):
        super().__init__()

        self.encoder_dim = encoder_dim
        self.encoder_rates = encoder_rates
        self.decoder_dim = decoder_dim
        self.decoder_rates = decoder_rates
        self.sample_rate = sample_rate
        self.continuous = continuous
        self.use_weight_norm = use_weight_norm

        # Without an explicit latent size, use the encoder's final width.
        if latent_dim is None:
            latent_dim = encoder_dim * (2 ** len(encoder_rates))
        self.latent_dim = latent_dim

        # Product of all encoder strides; preprocess() pads inputs to a
        # multiple of this.
        self.hop_length = np.prod(encoder_rates)
        self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)

        if continuous:
            # 1x1 convs: the first doubles the channels (mean + log-variance).
            self.quant_conv = torch.nn.Conv1d(latent_dim, 2 * latent_dim, 1)
            self.post_quant_conv = torch.nn.Conv1d(latent_dim, latent_dim, 1)
        else:
            self.n_codebooks = n_codebooks
            self.codebook_size = codebook_size
            self.codebook_dim = codebook_dim
            self.quantizer = ResidualVectorQuantize(
                input_dim=latent_dim,
                n_codebooks=n_codebooks,
                codebook_size=codebook_size,
                codebook_dim=codebook_dim,
                quantizer_dropout=quantizer_dropout,
            )

        self.decoder = Decoder(
            latent_dim,
            decoder_dim,
            decoder_rates,
        )
        self.apply(init_weights)

        self.delay = self.get_delay()

        if not self.use_weight_norm:
            self.remove_weight_norm()

    def get_delay(self):
        """Return half the difference between the input length required to
        produce ``get_output_length(0)`` samples and that output length."""
        # Any number works here, delay is invariant to input length
        l_out = self.get_output_length(0)
        length = l_out

        conv_layers = [
            layer
            for layer in self.modules()
            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d))
        ]

        # Walk the convolutions backwards, inverting each length formula.
        for layer in reversed(conv_layers):
            d = layer.dilation[0]
            k = layer.kernel_size[0]
            s = layer.stride[0]

            if isinstance(layer, nn.ConvTranspose1d):
                length = ((length - d * (k - 1) - 1) / s) + 1
            elif isinstance(layer, nn.Conv1d):
                length = (length - 1) * s + d * (k - 1) + 1

            length = math.ceil(length)

        l_in = length
        return (l_in - l_out) // 2

    def get_output_length(self, input_length):
        """Propagate ``input_length`` through every (transposed) convolution in
        ``self.modules()`` order, ignoring padding, and return the result."""
        length = input_length
        for layer in self.modules():
            if not isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
                continue
            d = layer.dilation[0]
            k = layer.kernel_size[0]
            s = layer.stride[0]

            if isinstance(layer, nn.Conv1d):
                length = ((length - d * (k - 1) - 1) / s) + 1
            elif isinstance(layer, nn.ConvTranspose1d):
                length = (length - 1) * s + d * (k - 1) + 1

            length = math.floor(length)
        return length

    @property
    def dtype(self):
        """Get the dtype of the model parameters."""
        # dtype of the first parameter found, float32 if there are none
        first = next(iter(self.parameters()), None)
        if first is None:
            return torch.float32
        return first.dtype

    @property
    def device(self):
        """Get the device of the model parameters."""
        # device of the first parameter found, CPU if there are none
        first = next(iter(self.parameters()), None)
        if first is None:
            return torch.device('cpu')
        return first.device

    def preprocess(self, audio_data, sample_rate):
        """Right-pad ``audio_data`` with zeros so that its last dimension is a
        multiple of ``self.hop_length``.

        ``sample_rate`` defaults to ``self.sample_rate`` when None and must
        equal it.
        """
        if sample_rate is None:
            sample_rate = self.sample_rate
        assert sample_rate == self.sample_rate

        length = audio_data.shape[-1]
        padded_length = math.ceil(length / self.hop_length) * self.hop_length
        right_pad = padded_length - length
        return nn.functional.pad(audio_data, (0, right_pad))

    def encode(
        self,
        audio_data: torch.Tensor,
        n_quantizers: int = None,
    ):
        """Encode given audio data.

        Parameters
        ----------
        audio_data : Tensor[B x 1 x T]
            Audio data to encode
        n_quantizers : int, optional
            Number of quantizers to use, by default None
            If None, all quantizers are used. Only used when the model is not
            continuous.

        Returns
        -------
        tuple
            ``(z, codes, latents, commitment_loss, codebook_loss)``.

            In continuous mode ``z`` is a ``DiagonalGaussianDistribution``
            built from the ``quant_conv`` output, ``codes`` and ``latents``
            are None and both losses are 0.

            Otherwise the tuple is exactly what ``self.quantizer`` returns.
        """
        z = self.encoder(audio_data)  # [B x D x T]
        if self.continuous:
            moments = self.quant_conv(z)  # [B x 2D x T]
            posterior = DiagonalGaussianDistribution(moments)
            return posterior, None, None, 0, 0

        z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers)
        return z, codes, latents, commitment_loss, codebook_loss

    def decode(self, z: torch.Tensor):
        """Decode a latent tensor and return the decoder output.

        Parameters
        ----------
        z : Tensor[B x D x T]
            Latent representation. In continuous mode it is first passed
            through ``post_quant_conv``.

        Returns
        -------
        Tensor
            Output of ``self.decoder``.
        """
        if self.continuous:
            z = self.post_quant_conv(z)
        return self.decoder(z)

    def forward(
        self,
        audio_data: torch.Tensor,
        sample_rate: int = None,
        n_quantizers: int = None,
    ):
        """Model forward pass

        Parameters
        ----------
        audio_data : Tensor[B x 1 x T]
            Audio data to encode
        sample_rate : int, optional
            Sample rate of audio data in Hz, by default None
            If None, defaults to `self.sample_rate`
        n_quantizers : int, optional
            Number of quantizers to use, by default None.
            If None, all quantizers are used.

        Returns
        -------
        dict
            In continuous mode, the keys are:

            "audio" : decoder output, cut to the original input length
            "z" : latent sampled from the posterior
            "kl_loss" : mean of the posterior's KL term

            Otherwise the keys are:

            "audio" : decoder output, cut to the original input length
            "z" : Tensor[B x D x T]
                Quantized continuous representation of input
            "codes" : Tensor[B x N x T]
                Codebook indices for each codebook
            "latents" : Tensor[B x N*D x T]
                Projected latents (before quantization)
            "vq/commitment_loss" : commitment loss from the quantizer
            "vq/codebook_loss" : codebook loss from the quantizer
        """
        length = audio_data.shape[-1]
        audio_data = self.preprocess(audio_data, sample_rate)

        if self.continuous:
            posterior = self.encode(audio_data, n_quantizers)[0]
            z = posterior.sample()
            x = self.decode(z)
            kl_loss = posterior.kl().mean()
            return {
                "audio": x[..., :length],
                "z": z,
                "kl_loss": kl_loss,
            }

        z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers)
        x = self.decode(z)
        return {
            "audio": x[..., :length],
            "z": z,
            "codes": codes,
            "latents": latents,
            "vq/commitment_loss": commitment_loss,
            "vq/codebook_loss": codebook_loss,
        }

    def remove_weight_norm(self):
        """
        Remove weight_norm from all modules in the model.
        This fuses the weight_g and weight_v parameters into a single weight parameter.
        Should be called before inference for better performance.
        Returns:
            self: The model with weight_norm removed
        """
        from torch.nn.utils import remove_weight_norm as strip_weight_norm

        num_removed = 0
        for name, module in list(self.named_modules()):
            if not hasattr(module, "_forward_pre_hooks"):
                continue
            # Weight norm is detected via the type name of its forward pre-hook.
            for hook in list(module._forward_pre_hooks.values()):
                if "WeightNorm" not in str(type(hook)):
                    continue
                try:
                    strip_weight_norm(module)
                    num_removed += 1
                except ValueError as e:
                    print(f"Failed to remove weight_norm from {name}: {e}")

        if num_removed > 0:
            self.use_weight_norm = False
        else:
            print("No weight_norm found in the model")
        return self
|