mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
Mova (#1337)
* support mova inference * mova media_io * add unified audio_video api & fix bug of mono audio input for ltx * support mova train * mova docs * fix bug
This commit is contained in:
57
diffsynth/models/mova_audio_dit.py
Normal file
57
diffsynth/models/mova_audio_dit.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .wan_video_dit import WanModel, precompute_freqs_cis, sinusoidal_embedding_1d
|
||||
from einops import rearrange
|
||||
from ..core import gradient_checkpoint_forward
|
||||
|
||||
def precompute_freqs_cis_1d(dim: int, end: int = 16384, theta: float = 10000.0):
|
||||
f_freqs_cis = precompute_freqs_cis(dim, end, theta)
|
||||
return f_freqs_cis.chunk(3, dim=-1)
|
||||
|
||||
class MovaAudioDit(WanModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
head_dim = kwargs.get("dim", 1536) // kwargs.get("num_heads", 12)
|
||||
self.freqs = precompute_freqs_cis_1d(head_dim)
|
||||
self.patch_embedding = nn.Conv1d(
|
||||
kwargs.get("in_dim", 128), kwargs.get("dim", 1536), kernel_size=[1], stride=[1]
|
||||
)
|
||||
|
||||
def precompute_freqs_cis(self, dim: int, end: int = 16384, theta: float = 10000.0):
|
||||
self.f_freqs_cis = precompute_freqs_cis_1d(dim, end, theta)
|
||||
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
use_gradient_checkpointing: bool = False,
|
||||
use_gradient_checkpointing_offload: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
|
||||
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
|
||||
context = self.text_embedding(context)
|
||||
x, (f, ) = self.patchify(x)
|
||||
freqs = torch.cat([
|
||||
self.freqs[0][:f].view(f, -1).expand(f, -1),
|
||||
self.freqs[1][:f].view(f, -1).expand(f, -1),
|
||||
self.freqs[2][:f].view(f, -1).expand(f, -1),
|
||||
], dim=-1).reshape(f, 1, -1).to(x.device)
|
||||
|
||||
for block in self.blocks:
|
||||
x = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
x, context, t_mod, freqs,
|
||||
)
|
||||
x = self.head(x, t)
|
||||
x = self.unpatchify(x, (f, ))
|
||||
return x
|
||||
|
||||
def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
|
||||
return rearrange(
|
||||
x, 'b f (p c) -> b c (f p)',
|
||||
f=grid_size[0],
|
||||
p=self.patch_size[0]
|
||||
)
|
||||
796
diffsynth/models/mova_audio_vae.py
Normal file
796
diffsynth/models/mova_audio_vae.py
Normal file
@@ -0,0 +1,796 @@
|
||||
import math
|
||||
from typing import List, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn.utils import weight_norm
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
def WNConv1d(*args, **kwargs):
|
||||
return weight_norm(nn.Conv1d(*args, **kwargs))
|
||||
|
||||
|
||||
def WNConvTranspose1d(*args, **kwargs):
|
||||
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
||||
|
||||
|
||||
# Scripting this brings model speed up 1.4x
|
||||
@torch.jit.script
|
||||
def snake(x, alpha):
|
||||
shape = x.shape
|
||||
x = x.reshape(shape[0], shape[1], -1)
|
||||
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
|
||||
x = x.reshape(shape)
|
||||
return x
|
||||
|
||||
|
||||
class Snake1d(nn.Module):
|
||||
def __init__(self, channels):
|
||||
super().__init__()
|
||||
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
|
||||
|
||||
def forward(self, x):
|
||||
return snake(x, self.alpha)
|
||||
|
||||
|
||||
class VectorQuantize(nn.Module):
|
||||
"""
|
||||
Implementation of VQ similar to Karpathy's repo:
|
||||
https://github.com/karpathy/deep-vector-quantization
|
||||
Additionally uses following tricks from Improved VQGAN
|
||||
(https://arxiv.org/pdf/2110.04627.pdf):
|
||||
1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
|
||||
for improved codebook usage
|
||||
2. l2-normalized codes: Converts euclidean distance to cosine similarity which
|
||||
improves training stability
|
||||
"""
|
||||
|
||||
def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
|
||||
super().__init__()
|
||||
self.codebook_size = codebook_size
|
||||
self.codebook_dim = codebook_dim
|
||||
|
||||
self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
|
||||
self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
|
||||
self.codebook = nn.Embedding(codebook_size, codebook_dim)
|
||||
|
||||
def forward(self, z):
|
||||
"""Quantized the input tensor using a fixed codebook and returns
|
||||
the corresponding codebook vectors
|
||||
|
||||
Parameters
|
||||
----------
|
||||
z : Tensor[B x D x T]
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tensor[B x D x T]
|
||||
Quantized continuous representation of input
|
||||
Tensor[1]
|
||||
Commitment loss to train encoder to predict vectors closer to codebook
|
||||
entries
|
||||
Tensor[1]
|
||||
Codebook loss to update the codebook
|
||||
Tensor[B x T]
|
||||
Codebook indices (quantized discrete representation of input)
|
||||
Tensor[B x D x T]
|
||||
Projected latents (continuous representation of input before quantization)
|
||||
"""
|
||||
|
||||
# Factorized codes (ViT-VQGAN) Project input into low-dimensional space
|
||||
z_e = self.in_proj(z) # z_e : (B x D x T)
|
||||
z_q, indices = self.decode_latents(z_e)
|
||||
|
||||
commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
|
||||
codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
|
||||
|
||||
z_q = (
|
||||
z_e + (z_q - z_e).detach()
|
||||
) # noop in forward pass, straight-through gradient estimator in backward pass
|
||||
|
||||
z_q = self.out_proj(z_q)
|
||||
|
||||
return z_q, commitment_loss, codebook_loss, indices, z_e
|
||||
|
||||
def embed_code(self, embed_id):
|
||||
return F.embedding(embed_id, self.codebook.weight)
|
||||
|
||||
def decode_code(self, embed_id):
|
||||
return self.embed_code(embed_id).transpose(1, 2)
|
||||
|
||||
def decode_latents(self, latents):
|
||||
encodings = rearrange(latents, "b d t -> (b t) d")
|
||||
codebook = self.codebook.weight # codebook: (N x D)
|
||||
|
||||
# L2 normalize encodings and codebook (ViT-VQGAN)
|
||||
encodings = F.normalize(encodings)
|
||||
codebook = F.normalize(codebook)
|
||||
|
||||
# Compute euclidean distance with codebook
|
||||
dist = (
|
||||
encodings.pow(2).sum(1, keepdim=True)
|
||||
- 2 * encodings @ codebook.t()
|
||||
+ codebook.pow(2).sum(1, keepdim=True).t()
|
||||
)
|
||||
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
||||
z_q = self.decode_code(indices)
|
||||
return z_q, indices
|
||||
|
||||
|
||||
class ResidualVectorQuantize(nn.Module):
|
||||
"""
|
||||
Introduced in SoundStream: An end2end neural audio codec
|
||||
https://arxiv.org/abs/2107.03312
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int = 512,
|
||||
n_codebooks: int = 9,
|
||||
codebook_size: int = 1024,
|
||||
codebook_dim: Union[int, list] = 8,
|
||||
quantizer_dropout: float = 0.0,
|
||||
):
|
||||
super().__init__()
|
||||
if isinstance(codebook_dim, int):
|
||||
codebook_dim = [codebook_dim for _ in range(n_codebooks)]
|
||||
|
||||
self.n_codebooks = n_codebooks
|
||||
self.codebook_dim = codebook_dim
|
||||
self.codebook_size = codebook_size
|
||||
|
||||
self.quantizers = nn.ModuleList(
|
||||
[
|
||||
VectorQuantize(input_dim, codebook_size, codebook_dim[i])
|
||||
for i in range(n_codebooks)
|
||||
]
|
||||
)
|
||||
self.quantizer_dropout = quantizer_dropout
|
||||
|
||||
def forward(self, z, n_quantizers: int = None):
|
||||
"""Quantized the input tensor using a fixed set of `n` codebooks and returns
|
||||
the corresponding codebook vectors
|
||||
Parameters
|
||||
----------
|
||||
z : Tensor[B x D x T]
|
||||
n_quantizers : int, optional
|
||||
No. of quantizers to use
|
||||
(n_quantizers < self.n_codebooks ex: for quantizer dropout)
|
||||
Note: if `self.quantizer_dropout` is True, this argument is ignored
|
||||
when in training mode, and a random number of quantizers is used.
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
A dictionary with the following keys:
|
||||
|
||||
"z" : Tensor[B x D x T]
|
||||
Quantized continuous representation of input
|
||||
"codes" : Tensor[B x N x T]
|
||||
Codebook indices for each codebook
|
||||
(quantized discrete representation of input)
|
||||
"latents" : Tensor[B x N*D x T]
|
||||
Projected latents (continuous representation of input before quantization)
|
||||
"vq/commitment_loss" : Tensor[1]
|
||||
Commitment loss to train encoder to predict vectors closer to codebook
|
||||
entries
|
||||
"vq/codebook_loss" : Tensor[1]
|
||||
Codebook loss to update the codebook
|
||||
"""
|
||||
z_q = 0
|
||||
residual = z
|
||||
commitment_loss = 0
|
||||
codebook_loss = 0
|
||||
|
||||
codebook_indices = []
|
||||
latents = []
|
||||
|
||||
if n_quantizers is None:
|
||||
n_quantizers = self.n_codebooks
|
||||
if self.training:
|
||||
n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
|
||||
dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
|
||||
n_dropout = int(z.shape[0] * self.quantizer_dropout)
|
||||
n_quantizers[:n_dropout] = dropout[:n_dropout]
|
||||
n_quantizers = n_quantizers.to(z.device)
|
||||
|
||||
for i, quantizer in enumerate(self.quantizers):
|
||||
if self.training is False and i >= n_quantizers:
|
||||
break
|
||||
|
||||
z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
|
||||
residual
|
||||
)
|
||||
|
||||
# Create mask to apply quantizer dropout
|
||||
mask = (
|
||||
torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
|
||||
)
|
||||
z_q = z_q + z_q_i * mask[:, None, None]
|
||||
residual = residual - z_q_i
|
||||
|
||||
# Sum losses
|
||||
commitment_loss += (commitment_loss_i * mask).mean()
|
||||
codebook_loss += (codebook_loss_i * mask).mean()
|
||||
|
||||
codebook_indices.append(indices_i)
|
||||
latents.append(z_e_i)
|
||||
|
||||
codes = torch.stack(codebook_indices, dim=1)
|
||||
latents = torch.cat(latents, dim=1)
|
||||
|
||||
return z_q, codes, latents, commitment_loss, codebook_loss
|
||||
|
||||
def from_codes(self, codes: torch.Tensor):
|
||||
"""Given the quantized codes, reconstruct the continuous representation
|
||||
Parameters
|
||||
----------
|
||||
codes : Tensor[B x N x T]
|
||||
Quantized discrete representation of input
|
||||
Returns
|
||||
-------
|
||||
Tensor[B x D x T]
|
||||
Quantized continuous representation of input
|
||||
"""
|
||||
z_q = 0.0
|
||||
z_p = []
|
||||
n_codebooks = codes.shape[1]
|
||||
for i in range(n_codebooks):
|
||||
z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
|
||||
z_p.append(z_p_i)
|
||||
|
||||
z_q_i = self.quantizers[i].out_proj(z_p_i)
|
||||
z_q = z_q + z_q_i
|
||||
return z_q, torch.cat(z_p, dim=1), codes
|
||||
|
||||
def from_latents(self, latents: torch.Tensor):
|
||||
"""Given the unquantized latents, reconstruct the
|
||||
continuous representation after quantization.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
latents : Tensor[B x N x T]
|
||||
Continuous representation of input after projection
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tensor[B x D x T]
|
||||
Quantized representation of full-projected space
|
||||
Tensor[B x D x T]
|
||||
Quantized representation of latent space
|
||||
"""
|
||||
z_q = 0
|
||||
z_p = []
|
||||
codes = []
|
||||
dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
|
||||
|
||||
n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
|
||||
0
|
||||
]
|
||||
for i in range(n_codebooks):
|
||||
j, k = dims[i], dims[i + 1]
|
||||
z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
|
||||
z_p.append(z_p_i)
|
||||
codes.append(codes_i)
|
||||
|
||||
z_q_i = self.quantizers[i].out_proj(z_p_i)
|
||||
z_q = z_q + z_q_i
|
||||
|
||||
return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
|
||||
|
||||
|
||||
class AbstractDistribution:
|
||||
def sample(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def mode(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class DiracDistribution(AbstractDistribution):
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def sample(self):
|
||||
return self.value
|
||||
|
||||
def mode(self):
|
||||
return self.value
|
||||
|
||||
|
||||
class DiagonalGaussianDistribution(object):
|
||||
def __init__(self, parameters, deterministic=False):
|
||||
self.parameters = parameters
|
||||
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
||||
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
||||
self.deterministic = deterministic
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
self.var = torch.exp(self.logvar)
|
||||
if self.deterministic:
|
||||
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
|
||||
|
||||
def sample(self):
|
||||
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
|
||||
return x
|
||||
|
||||
def kl(self, other=None):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.0])
|
||||
else:
|
||||
if other is None:
|
||||
return 0.5 * torch.mean(
|
||||
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
|
||||
dim=[1, 2],
|
||||
)
|
||||
else:
|
||||
return 0.5 * torch.mean(
|
||||
torch.pow(self.mean - other.mean, 2) / other.var
|
||||
+ self.var / other.var
|
||||
- 1.0
|
||||
- self.logvar
|
||||
+ other.logvar,
|
||||
dim=[1, 2],
|
||||
)
|
||||
|
||||
def nll(self, sample, dims=[1, 2]):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.0])
|
||||
logtwopi = np.log(2.0 * np.pi)
|
||||
return 0.5 * torch.sum(
|
||||
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
||||
dim=dims,
|
||||
)
|
||||
|
||||
def mode(self):
|
||||
return self.mean
|
||||
|
||||
|
||||
def normal_kl(mean1, logvar1, mean2, logvar2):
|
||||
"""
|
||||
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
|
||||
Compute the KL divergence between two gaussians.
|
||||
Shapes are automatically broadcasted, so batches can be compared to
|
||||
scalars, among other use cases.
|
||||
"""
|
||||
tensor = None
|
||||
for obj in (mean1, logvar1, mean2, logvar2):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
tensor = obj
|
||||
break
|
||||
assert tensor is not None, "at least one argument must be a Tensor"
|
||||
|
||||
# Force variances to be Tensors. Broadcasting helps convert scalars to
|
||||
# Tensors, but it does not work for torch.exp().
|
||||
logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)]
|
||||
|
||||
return 0.5 * (
|
||||
-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
|
||||
)
|
||||
|
||||
|
||||
def init_weights(m):
|
||||
if isinstance(m, nn.Conv1d):
|
||||
nn.init.trunc_normal_(m.weight, std=0.02)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
|
||||
class ResidualUnit(nn.Module):
|
||||
def __init__(self, dim: int = 16, dilation: int = 1):
|
||||
super().__init__()
|
||||
pad = ((7 - 1) * dilation) // 2
|
||||
self.block = nn.Sequential(
|
||||
Snake1d(dim),
|
||||
WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
|
||||
Snake1d(dim),
|
||||
WNConv1d(dim, dim, kernel_size=1),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.block(x)
|
||||
pad = (x.shape[-1] - y.shape[-1]) // 2
|
||||
if pad > 0:
|
||||
x = x[..., pad:-pad]
|
||||
return x + y
|
||||
|
||||
|
||||
class EncoderBlock(nn.Module):
|
||||
def __init__(self, dim: int = 16, stride: int = 1):
|
||||
super().__init__()
|
||||
self.block = nn.Sequential(
|
||||
ResidualUnit(dim // 2, dilation=1),
|
||||
ResidualUnit(dim // 2, dilation=3),
|
||||
ResidualUnit(dim // 2, dilation=9),
|
||||
Snake1d(dim // 2),
|
||||
WNConv1d(
|
||||
dim // 2,
|
||||
dim,
|
||||
kernel_size=2 * stride,
|
||||
stride=stride,
|
||||
padding=math.ceil(stride / 2),
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int = 64,
|
||||
strides: list = [2, 4, 8, 8],
|
||||
d_latent: int = 64,
|
||||
):
|
||||
super().__init__()
|
||||
# Create first convolution
|
||||
self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
|
||||
|
||||
# Create EncoderBlocks that double channels as they downsample by `stride`
|
||||
for stride in strides:
|
||||
d_model *= 2
|
||||
self.block += [EncoderBlock(d_model, stride=stride)]
|
||||
|
||||
# Create last convolution
|
||||
self.block += [
|
||||
Snake1d(d_model),
|
||||
WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
|
||||
]
|
||||
|
||||
# Wrap black into nn.Sequential
|
||||
self.block = nn.Sequential(*self.block)
|
||||
self.enc_dim = d_model
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class DecoderBlock(nn.Module):
|
||||
def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
|
||||
super().__init__()
|
||||
self.block = nn.Sequential(
|
||||
Snake1d(input_dim),
|
||||
WNConvTranspose1d(
|
||||
input_dim,
|
||||
output_dim,
|
||||
kernel_size=2 * stride,
|
||||
stride=stride,
|
||||
padding=math.ceil(stride / 2),
|
||||
output_padding=stride % 2,
|
||||
),
|
||||
ResidualUnit(output_dim, dilation=1),
|
||||
ResidualUnit(output_dim, dilation=3),
|
||||
ResidualUnit(output_dim, dilation=9),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_channel,
|
||||
channels,
|
||||
rates,
|
||||
d_out: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Add first conv layer
|
||||
layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
|
||||
|
||||
# Add upsampling + MRF blocks
|
||||
for i, stride in enumerate(rates):
|
||||
input_dim = channels // 2**i
|
||||
output_dim = channels // 2 ** (i + 1)
|
||||
layers += [DecoderBlock(input_dim, output_dim, stride)]
|
||||
|
||||
# Add final conv layer
|
||||
layers += [
|
||||
Snake1d(output_dim),
|
||||
WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
|
||||
nn.Tanh(),
|
||||
]
|
||||
|
||||
self.model = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self.model(x)
|
||||
|
||||
|
||||
class DacVAE(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder_dim: int = 128,
|
||||
encoder_rates: List[int] = [2, 3, 4, 5, 8],
|
||||
latent_dim: int = 128,
|
||||
decoder_dim: int = 2048,
|
||||
decoder_rates: List[int] = [8, 5, 4, 3, 2],
|
||||
n_codebooks: int = 9,
|
||||
codebook_size: int = 1024,
|
||||
codebook_dim: Union[int, list] = 8,
|
||||
quantizer_dropout: bool = False,
|
||||
sample_rate: int = 48000,
|
||||
continuous: bool = True,
|
||||
use_weight_norm: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.encoder_dim = encoder_dim
|
||||
self.encoder_rates = encoder_rates
|
||||
self.decoder_dim = decoder_dim
|
||||
self.decoder_rates = decoder_rates
|
||||
self.sample_rate = sample_rate
|
||||
self.continuous = continuous
|
||||
self.use_weight_norm = use_weight_norm
|
||||
|
||||
if latent_dim is None:
|
||||
latent_dim = encoder_dim * (2 ** len(encoder_rates))
|
||||
|
||||
self.latent_dim = latent_dim
|
||||
|
||||
self.hop_length = np.prod(encoder_rates)
|
||||
self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)
|
||||
|
||||
if not continuous:
|
||||
self.n_codebooks = n_codebooks
|
||||
self.codebook_size = codebook_size
|
||||
self.codebook_dim = codebook_dim
|
||||
self.quantizer = ResidualVectorQuantize(
|
||||
input_dim=latent_dim,
|
||||
n_codebooks=n_codebooks,
|
||||
codebook_size=codebook_size,
|
||||
codebook_dim=codebook_dim,
|
||||
quantizer_dropout=quantizer_dropout,
|
||||
)
|
||||
else:
|
||||
self.quant_conv = torch.nn.Conv1d(latent_dim, 2 * latent_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv1d(latent_dim, latent_dim, 1)
|
||||
|
||||
self.decoder = Decoder(
|
||||
latent_dim,
|
||||
decoder_dim,
|
||||
decoder_rates,
|
||||
)
|
||||
self.sample_rate = sample_rate
|
||||
self.apply(init_weights)
|
||||
|
||||
self.delay = self.get_delay()
|
||||
|
||||
if not self.use_weight_norm:
|
||||
self.remove_weight_norm()
|
||||
|
||||
def get_delay(self):
|
||||
# Any number works here, delay is invariant to input length
|
||||
l_out = self.get_output_length(0)
|
||||
L = l_out
|
||||
|
||||
layers = []
|
||||
for layer in self.modules():
|
||||
if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
|
||||
layers.append(layer)
|
||||
|
||||
for layer in reversed(layers):
|
||||
d = layer.dilation[0]
|
||||
k = layer.kernel_size[0]
|
||||
s = layer.stride[0]
|
||||
|
||||
if isinstance(layer, nn.ConvTranspose1d):
|
||||
L = ((L - d * (k - 1) - 1) / s) + 1
|
||||
elif isinstance(layer, nn.Conv1d):
|
||||
L = (L - 1) * s + d * (k - 1) + 1
|
||||
|
||||
L = math.ceil(L)
|
||||
|
||||
l_in = L
|
||||
|
||||
return (l_in - l_out) // 2
|
||||
|
||||
def get_output_length(self, input_length):
|
||||
L = input_length
|
||||
# Calculate output length
|
||||
for layer in self.modules():
|
||||
if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
|
||||
d = layer.dilation[0]
|
||||
k = layer.kernel_size[0]
|
||||
s = layer.stride[0]
|
||||
|
||||
if isinstance(layer, nn.Conv1d):
|
||||
L = ((L - d * (k - 1) - 1) / s) + 1
|
||||
elif isinstance(layer, nn.ConvTranspose1d):
|
||||
L = (L - 1) * s + d * (k - 1) + 1
|
||||
|
||||
L = math.floor(L)
|
||||
return L
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
"""Get the dtype of the model parameters."""
|
||||
# Return the dtype of the first parameter found
|
||||
for param in self.parameters():
|
||||
return param.dtype
|
||||
return torch.float32 # fallback
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
"""Get the device of the model parameters."""
|
||||
# Return the device of the first parameter found
|
||||
for param in self.parameters():
|
||||
return param.device
|
||||
return torch.device('cpu') # fallback
|
||||
|
||||
def preprocess(self, audio_data, sample_rate):
|
||||
if sample_rate is None:
|
||||
sample_rate = self.sample_rate
|
||||
assert sample_rate == self.sample_rate
|
||||
|
||||
length = audio_data.shape[-1]
|
||||
right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
|
||||
audio_data = nn.functional.pad(audio_data, (0, right_pad))
|
||||
|
||||
return audio_data
|
||||
|
||||
def encode(
|
||||
self,
|
||||
audio_data: torch.Tensor,
|
||||
n_quantizers: int = None,
|
||||
):
|
||||
"""Encode given audio data and return quantized latent codes
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio_data : Tensor[B x 1 x T]
|
||||
Audio data to encode
|
||||
n_quantizers : int, optional
|
||||
Number of quantizers to use, by default None
|
||||
If None, all quantizers are used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
A dictionary with the following keys:
|
||||
"z" : Tensor[B x D x T]
|
||||
Quantized continuous representation of input
|
||||
"codes" : Tensor[B x N x T]
|
||||
Codebook indices for each codebook
|
||||
(quantized discrete representation of input)
|
||||
"latents" : Tensor[B x N*D x T]
|
||||
Projected latents (continuous representation of input before quantization)
|
||||
"vq/commitment_loss" : Tensor[1]
|
||||
Commitment loss to train encoder to predict vectors closer to codebook
|
||||
entries
|
||||
"vq/codebook_loss" : Tensor[1]
|
||||
Codebook loss to update the codebook
|
||||
"length" : int
|
||||
Number of samples in input audio
|
||||
"""
|
||||
z = self.encoder(audio_data) # [B x D x T]
|
||||
if not self.continuous:
|
||||
z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers)
|
||||
else:
|
||||
z = self.quant_conv(z) # [B x 2D x T]
|
||||
z = DiagonalGaussianDistribution(z)
|
||||
codes, latents, commitment_loss, codebook_loss = None, None, 0, 0
|
||||
|
||||
return z, codes, latents, commitment_loss, codebook_loss
|
||||
|
||||
def decode(self, z: torch.Tensor):
|
||||
"""Decode given latent codes and return audio data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
z : Tensor[B x D x T]
|
||||
Quantized continuous representation of input
|
||||
length : int, optional
|
||||
Number of samples in output audio, by default None
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
A dictionary with the following keys:
|
||||
"audio" : Tensor[B x 1 x length]
|
||||
Decoded audio data.
|
||||
"""
|
||||
if not self.continuous:
|
||||
audio = self.decoder(z)
|
||||
else:
|
||||
z = self.post_quant_conv(z)
|
||||
audio = self.decoder(z)
|
||||
|
||||
return audio
|
||||
|
||||
def forward(
|
||||
self,
|
||||
audio_data: torch.Tensor,
|
||||
sample_rate: int = None,
|
||||
n_quantizers: int = None,
|
||||
):
|
||||
"""Model forward pass
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio_data : Tensor[B x 1 x T]
|
||||
Audio data to encode
|
||||
sample_rate : int, optional
|
||||
Sample rate of audio data in Hz, by default None
|
||||
If None, defaults to `self.sample_rate`
|
||||
n_quantizers : int, optional
|
||||
Number of quantizers to use, by default None.
|
||||
If None, all quantizers are used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
A dictionary with the following keys:
|
||||
"z" : Tensor[B x D x T]
|
||||
Quantized continuous representation of input
|
||||
"codes" : Tensor[B x N x T]
|
||||
Codebook indices for each codebook
|
||||
(quantized discrete representation of input)
|
||||
"latents" : Tensor[B x N*D x T]
|
||||
Projected latents (continuous representation of input before quantization)
|
||||
"vq/commitment_loss" : Tensor[1]
|
||||
Commitment loss to train encoder to predict vectors closer to codebook
|
||||
entries
|
||||
"vq/codebook_loss" : Tensor[1]
|
||||
Codebook loss to update the codebook
|
||||
"length" : int
|
||||
Number of samples in input audio
|
||||
"audio" : Tensor[B x 1 x length]
|
||||
Decoded audio data.
|
||||
"""
|
||||
length = audio_data.shape[-1]
|
||||
audio_data = self.preprocess(audio_data, sample_rate)
|
||||
if not self.continuous:
|
||||
z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers)
|
||||
|
||||
x = self.decode(z)
|
||||
return {
|
||||
"audio": x[..., :length],
|
||||
"z": z,
|
||||
"codes": codes,
|
||||
"latents": latents,
|
||||
"vq/commitment_loss": commitment_loss,
|
||||
"vq/codebook_loss": codebook_loss,
|
||||
}
|
||||
else:
|
||||
posterior, _, _, _, _ = self.encode(audio_data, n_quantizers)
|
||||
z = posterior.sample()
|
||||
x = self.decode(z)
|
||||
|
||||
kl_loss = posterior.kl()
|
||||
kl_loss = kl_loss.mean()
|
||||
|
||||
return {
|
||||
"audio": x[..., :length],
|
||||
"z": z,
|
||||
"kl_loss": kl_loss,
|
||||
}
|
||||
|
||||
def remove_weight_norm(self):
|
||||
"""
|
||||
Remove weight_norm from all modules in the model.
|
||||
This fuses the weight_g and weight_v parameters into a single weight parameter.
|
||||
Should be called before inference for better performance.
|
||||
Returns:
|
||||
self: The model with weight_norm removed
|
||||
"""
|
||||
from torch.nn.utils import remove_weight_norm
|
||||
num_removed = 0
|
||||
for name, module in list(self.named_modules()):
|
||||
if hasattr(module, "_forward_pre_hooks"):
|
||||
for hook_id, hook in list(module._forward_pre_hooks.items()):
|
||||
if "WeightNorm" in str(type(hook)):
|
||||
try:
|
||||
remove_weight_norm(module)
|
||||
num_removed += 1
|
||||
# print(f"Removed weight_norm from: {name}")
|
||||
except ValueError as e:
|
||||
print(f"Failed to remove weight_norm from {name}: {e}")
|
||||
if num_removed > 0:
|
||||
# print(f"Successfully removed weight_norm from {num_removed} modules")
|
||||
self.use_weight_norm = False
|
||||
else:
|
||||
print("No weight_norm found in the model")
|
||||
return self
|
||||
595
diffsynth/models/mova_dual_tower_bridge.py
Normal file
595
diffsynth/models/mova_dual_tower_bridge.py
Normal file
@@ -0,0 +1,595 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
from einops import rearrange
|
||||
from .wan_video_dit import AttentionModule, RMSNorm
|
||||
from ..core import gradient_checkpoint_forward
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
||||
|
||||
def __init__(self, base: float, dim: int, device=None):
|
||||
super().__init__()
|
||||
self.base = base
|
||||
self.dim = dim
|
||||
self.attention_scaling = 1.0
|
||||
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x, position_ids):
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
|
||||
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos() * self.attention_scaling
|
||||
sin = emb.sin() * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
@torch.compile(fullgraph=True)
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`, *optional*):
|
||||
Deprecated and unused.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
class PerFrameAttentionPooling(nn.Module):
|
||||
"""
|
||||
Per-frame multi-head attention pooling.
|
||||
|
||||
Given a flattened token sequence [B, L, D] and grid size (T, H, W), perform a
|
||||
single-query attention pooling over the H*W tokens for each time frame, producing
|
||||
[B, T, D].
|
||||
|
||||
Inspired by SigLIP's Multihead Attention Pooling head (without MLP/residual stack).
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, "dim must be divisible by num_heads"
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.probe = nn.Parameter(torch.randn(1, 1, dim))
|
||||
nn.init.normal_(self.probe, std=0.02)
|
||||
|
||||
self.attention = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)
|
||||
self.layernorm = nn.LayerNorm(dim, eps=eps)
|
||||
|
||||
def forward(self, x: torch.Tensor, grid_size: Tuple[int, int, int]) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: [B, L, D], where L = T*H*W
|
||||
grid_size: (T, H, W)
|
||||
Returns:
|
||||
pooled: [B, T, D]
|
||||
"""
|
||||
B, L, D = x.shape
|
||||
T, H, W = grid_size
|
||||
assert D == self.dim, f"Channel dimension mismatch: D={D} vs dim={self.dim}"
|
||||
assert L == T * H * W, f"Flattened length mismatch: L={L} vs T*H*W={T*H*W}"
|
||||
|
||||
S = H * W
|
||||
# Re-arrange tokens grouped by frame.
|
||||
x_bt_s_d = x.view(B, T, S, D).contiguous().view(B * T, S, D) # [B*T, S, D]
|
||||
|
||||
# A learnable probe as the query (one query per frame).
|
||||
probe = self.probe.expand(B * T, -1, -1) # [B*T, 1, D]
|
||||
|
||||
# Attention pooling: query=probe, key/value=H*W tokens within the frame.
|
||||
pooled_bt_1_d = self.attention(probe, x_bt_s_d, x_bt_s_d, need_weights=False)[0] # [B*T, 1, D]
|
||||
pooled_bt_d = pooled_bt_1_d.squeeze(1) # [B*T, D]
|
||||
|
||||
# Restore to [B, T, D].
|
||||
pooled = pooled_bt_d.view(B, T, D)
|
||||
pooled = self.layernorm(pooled)
|
||||
return pooled
|
||||
|
||||
|
||||
class CrossModalInteractionController:
|
||||
"""
|
||||
Strategy class that controls interactions between two towers.
|
||||
Manages the interaction mapping between visual DiT (e.g. 30 layers) and audio DiT (e.g. 30 layers).
|
||||
"""
|
||||
|
||||
def __init__(self, visual_layers: int = 30, audio_layers: int = 30):
|
||||
self.visual_layers = visual_layers
|
||||
self.audio_layers = audio_layers
|
||||
self.min_layers = min(visual_layers, audio_layers)
|
||||
|
||||
def get_interaction_layers(self, strategy: str = "shallow_focus") -> Dict[str, List[Tuple[int, int]]]:
|
||||
"""
|
||||
Get interaction layer mappings.
|
||||
|
||||
Args:
|
||||
strategy: interaction strategy
|
||||
- "shallow_focus": emphasize shallow layers to avoid deep-layer asymmetry
|
||||
- "distributed": distributed interactions across the network
|
||||
- "progressive": dense shallow interactions, sparse deeper interactions
|
||||
- "custom": custom interaction layers
|
||||
|
||||
Returns:
|
||||
A dict containing mappings for 'v2a' (visual -> audio) and 'a2v' (audio -> visual).
|
||||
"""
|
||||
|
||||
if strategy == "shallow_focus":
|
||||
# Emphasize the first ~1/3 layers to avoid deep-layer asymmetry.
|
||||
num_interact = min(10, self.min_layers // 3)
|
||||
interact_layers = list(range(0, num_interact))
|
||||
|
||||
elif strategy == "distributed":
|
||||
# Distribute interactions across the network (every few layers).
|
||||
step = 3
|
||||
interact_layers = list(range(0, self.min_layers, step))
|
||||
|
||||
elif strategy == "progressive":
|
||||
# Progressive: dense shallow interactions, sparse deeper interactions.
|
||||
shallow = list(range(0, min(8, self.min_layers))) # Dense for the first 8 layers.
|
||||
if self.min_layers > 8:
|
||||
deep = list(range(8, self.min_layers, 3)) # Every 3 layers afterwards.
|
||||
interact_layers = shallow + deep
|
||||
else:
|
||||
interact_layers = shallow
|
||||
|
||||
elif strategy == "custom":
|
||||
# Custom strategy: adjust as needed.
|
||||
interact_layers = [0, 2, 4, 6, 8, 12, 16, 20] # Explicit layer indices.
|
||||
interact_layers = [i for i in interact_layers if i < self.min_layers]
|
||||
|
||||
elif strategy == "full":
|
||||
interact_layers = list(range(0, self.min_layers))
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown interaction strategy: {strategy}")
|
||||
|
||||
# Build bidirectional mapping.
|
||||
mapping = {
|
||||
'v2a': [(i, i) for i in interact_layers], # visual layer i -> audio layer i
|
||||
'a2v': [(i, i) for i in interact_layers] # audio layer i -> visual layer i
|
||||
}
|
||||
|
||||
return mapping
|
||||
|
||||
def should_interact(self, layer_idx: int, direction: str, interaction_mapping: Dict) -> bool:
|
||||
"""
|
||||
Check whether a given layer should interact.
|
||||
|
||||
Args:
|
||||
layer_idx: current layer index
|
||||
direction: interaction direction ('v2a' or 'a2v')
|
||||
interaction_mapping: interaction mapping table
|
||||
|
||||
Returns:
|
||||
bool: whether to interact
|
||||
"""
|
||||
if direction not in interaction_mapping:
|
||||
return False
|
||||
|
||||
return any(src == layer_idx for src, _ in interaction_mapping[direction])
|
||||
|
||||
|
||||
class ConditionalCrossAttention(nn.Module):
|
||||
def __init__(self, dim: int, kv_dim: int, num_heads: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.q_dim = dim
|
||||
self.kv_dim = kv_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = self.q_dim // num_heads
|
||||
|
||||
self.q = nn.Linear(dim, dim)
|
||||
self.k = nn.Linear(kv_dim, dim)
|
||||
self.v = nn.Linear(kv_dim, dim)
|
||||
self.o = nn.Linear(dim, dim)
|
||||
self.norm_q = RMSNorm(dim, eps=eps)
|
||||
self.norm_k = RMSNorm(dim, eps=eps)
|
||||
|
||||
self.attn = AttentionModule(self.num_heads)
|
||||
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor, x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
|
||||
ctx = y
|
||||
q = self.norm_q(self.q(x))
|
||||
k = self.norm_k(self.k(ctx))
|
||||
v = self.v(ctx)
|
||||
if x_freqs is not None:
|
||||
x_cos, x_sin = x_freqs
|
||||
B, L, _ = q.shape
|
||||
q_view = rearrange(q, 'b l (h d) -> b l h d', d=self.head_dim)
|
||||
x_cos = x_cos.to(q_view.dtype).to(q_view.device)
|
||||
x_sin = x_sin.to(q_view.dtype).to(q_view.device)
|
||||
# Expect x_cos/x_sin shape: [B or 1, L, head_dim]
|
||||
q_view, _ = apply_rotary_pos_emb(q_view, q_view, x_cos, x_sin, unsqueeze_dim=2)
|
||||
q = rearrange(q_view, 'b l h d -> b l (h d)')
|
||||
if y_freqs is not None:
|
||||
y_cos, y_sin = y_freqs
|
||||
Bc, Lc, _ = k.shape
|
||||
k_view = rearrange(k, 'b l (h d) -> b l h d', d=self.head_dim)
|
||||
y_cos = y_cos.to(k_view.dtype).to(k_view.device)
|
||||
y_sin = y_sin.to(k_view.dtype).to(k_view.device)
|
||||
# Expect y_cos/y_sin shape: [B or 1, L, head_dim]
|
||||
_, k_view = apply_rotary_pos_emb(k_view, k_view, y_cos, y_sin, unsqueeze_dim=2)
|
||||
k = rearrange(k_view, 'b l h d -> b l (h d)')
|
||||
x = self.attn(q, k, v)
|
||||
return self.o(x)
|
||||
|
||||
|
||||
# from diffusers.models.attention import AdaLayerNorm
|
||||
class AdaLayerNorm(nn.Module):
|
||||
r"""
|
||||
Norm layer modified to incorporate timestep embeddings.
|
||||
|
||||
Parameters:
|
||||
embedding_dim (`int`): The size of each embedding vector.
|
||||
num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
|
||||
output_dim (`int`, *optional*):
|
||||
norm_elementwise_affine (`bool`, defaults to `False):
|
||||
norm_eps (`bool`, defaults to `False`):
|
||||
chunk_dim (`int`, defaults to `0`):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
num_embeddings: Optional[int] = None,
|
||||
output_dim: Optional[int] = None,
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-5,
|
||||
chunk_dim: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.chunk_dim = chunk_dim
|
||||
output_dim = output_dim or embedding_dim * 2
|
||||
|
||||
if num_embeddings is not None:
|
||||
self.emb = nn.Embedding(num_embeddings, embedding_dim)
|
||||
else:
|
||||
self.emb = None
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, output_dim)
|
||||
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
if self.emb is not None:
|
||||
temb = self.emb(timestep)
|
||||
|
||||
temb = self.linear(self.silu(temb))
|
||||
|
||||
if self.chunk_dim == 2:
|
||||
scale, shift = temb.chunk(2, dim=2)
|
||||
# print(f"{x.shape = }, {scale.shape = }, {shift.shape = }")
|
||||
elif self.chunk_dim == 1:
|
||||
# This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
|
||||
# other if-branch. This branch is specific to CogVideoX and OmniGen for now.
|
||||
shift, scale = temb.chunk(2, dim=1)
|
||||
shift = shift[:, None, :]
|
||||
scale = scale[:, None, :]
|
||||
else:
|
||||
scale, shift = temb.chunk(2, dim=0)
|
||||
|
||||
x = self.norm(x) * (1 + scale) + shift
|
||||
return x
|
||||
|
||||
|
||||
class ConditionalCrossAttentionBlock(nn.Module):
|
||||
"""
|
||||
A thin wrapper around ConditionalCrossAttention.
|
||||
Applies LayerNorm to the conditioning input `y` before cross-attention.
|
||||
"""
|
||||
def __init__(self, dim: int, kv_dim: int, num_heads: int, eps: float = 1e-6, pooled_adaln: bool = False):
|
||||
super().__init__()
|
||||
self.y_norm = nn.LayerNorm(kv_dim, eps=eps)
|
||||
self.inner = ConditionalCrossAttention(dim=dim, kv_dim=kv_dim, num_heads=num_heads, eps=eps)
|
||||
self.pooled_adaln = pooled_adaln
|
||||
if pooled_adaln:
|
||||
self.per_frame_pooling = PerFrameAttentionPooling(kv_dim, num_heads=num_heads, eps=eps)
|
||||
self.adaln = AdaLayerNorm(kv_dim, output_dim=dim*2, chunk_dim=2)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
video_grid_size: Optional[Tuple[int, int, int]] = None,
|
||||
) -> torch.Tensor:
|
||||
if self.pooled_adaln:
|
||||
assert video_grid_size is not None, "video_grid_size must not be None"
|
||||
pooled_y = self.per_frame_pooling(y, video_grid_size)
|
||||
# Interpolate pooled_y along its temporal dimension to match x's sequence length.
|
||||
if pooled_y.shape[1] != x.shape[1]:
|
||||
pooled_y = F.interpolate(
|
||||
pooled_y.permute(0, 2, 1), # [B, C, T]
|
||||
size=x.shape[1],
|
||||
mode='linear',
|
||||
align_corners=False,
|
||||
).permute(0, 2, 1) # [B, T, C]
|
||||
x = self.adaln(x, temb=pooled_y)
|
||||
y = self.y_norm(y)
|
||||
return self.inner(x=x, y=y, x_freqs=x_freqs, y_freqs=y_freqs)
|
||||
|
||||
|
||||
class DualTowerConditionalBridge(nn.Module):
|
||||
"""
|
||||
Dual-tower conditional bridge.
|
||||
"""
|
||||
def __init__(self,
|
||||
visual_layers: int = 40,
|
||||
audio_layers: int = 30,
|
||||
visual_hidden_dim: int = 5120, # visual DiT hidden state dimension
|
||||
audio_hidden_dim: int = 1536, # audio DiT hidden state dimension
|
||||
audio_fps: float = 50.0,
|
||||
head_dim: int = 128, # attention head dimension
|
||||
interaction_strategy: str = "full",
|
||||
apply_cross_rope: bool = True, # whether to apply RoPE in cross-attention
|
||||
apply_first_frame_bias_in_rope: bool = False, # whether to account for 1/video_fps bias for the first frame in RoPE alignment
|
||||
trainable_condition_scale: bool = False,
|
||||
pooled_adaln: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.visual_hidden_dim = visual_hidden_dim
|
||||
self.audio_hidden_dim = audio_hidden_dim
|
||||
self.audio_fps = audio_fps
|
||||
self.head_dim = head_dim
|
||||
self.apply_cross_rope = apply_cross_rope
|
||||
self.apply_first_frame_bias_in_rope = apply_first_frame_bias_in_rope
|
||||
self.trainable_condition_scale = trainable_condition_scale
|
||||
self.pooled_adaln = pooled_adaln
|
||||
if self.trainable_condition_scale:
|
||||
self.condition_scale = nn.Parameter(torch.tensor([1.0], dtype=torch.float32))
|
||||
else:
|
||||
self.condition_scale = 1.0
|
||||
|
||||
self.controller = CrossModalInteractionController(visual_layers, audio_layers)
|
||||
self.interaction_mapping = self.controller.get_interaction_layers(interaction_strategy)
|
||||
|
||||
# Conditional cross-attention modules operating at the DiT hidden-state level.
|
||||
self.audio_to_video_conditioners = nn.ModuleDict() # audio hidden states -> visual DiT conditioning
|
||||
self.video_to_audio_conditioners = nn.ModuleDict() # visual hidden states -> audio DiT conditioning
|
||||
|
||||
# Build conditioners for layers that should interact.
|
||||
# audio hidden states condition the visual DiT
|
||||
self.rotary = RotaryEmbedding(base=10000.0, dim=head_dim)
|
||||
for v_layer, _ in self.interaction_mapping['a2v']:
|
||||
self.audio_to_video_conditioners[str(v_layer)] = ConditionalCrossAttentionBlock(
|
||||
dim=visual_hidden_dim, # 3072 (visual DiT hidden states)
|
||||
kv_dim=audio_hidden_dim, # 1536 (audio DiT hidden states)
|
||||
num_heads=visual_hidden_dim // head_dim, # derive number of heads from hidden dim
|
||||
pooled_adaln=False # a2v typically does not need pooled AdaLN
|
||||
)
|
||||
|
||||
# visual hidden states condition the audio DiT
|
||||
for a_layer, _ in self.interaction_mapping['v2a']:
|
||||
self.video_to_audio_conditioners[str(a_layer)] = ConditionalCrossAttentionBlock(
|
||||
dim=audio_hidden_dim, # 1536 (audio DiT hidden states)
|
||||
kv_dim=visual_hidden_dim, # 3072 (visual DiT hidden states)
|
||||
num_heads=audio_hidden_dim // head_dim, # safe head count derivation
|
||||
pooled_adaln=self.pooled_adaln
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def build_aligned_freqs(self,
|
||||
video_fps: float,
|
||||
grid_size: Tuple[int, int, int],
|
||||
audio_steps: int,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Build aligned RoPE (cos, sin) pairs based on video fps, video grid size (f_v, h, w),
|
||||
and audio sequence length `audio_steps` (with fixed audio fps = 44100/2048).
|
||||
|
||||
Returns:
|
||||
visual_freqs: (cos_v, sin_v), shape [1, f_v*h*w, head_dim]
|
||||
audio_freqs: (cos_a, sin_a), shape [1, audio_steps, head_dim]
|
||||
"""
|
||||
f_v, h, w = grid_size
|
||||
L_v = f_v * h * w
|
||||
L_a = int(audio_steps)
|
||||
|
||||
device = device or next(self.parameters()).device
|
||||
dtype = dtype or torch.float32
|
||||
|
||||
# Audio positions: 0,1,2,...,L_a-1 (audio as reference).
|
||||
audio_pos = torch.arange(L_a, device=device, dtype=torch.float32).unsqueeze(0)
|
||||
|
||||
# Video positions: align video frames to audio-step units.
|
||||
# FIXME(dhyu): hard-coded VAE temporal stride = 4
|
||||
if self.apply_first_frame_bias_in_rope:
|
||||
# Account for the "first frame lasts 1/video_fps" bias.
|
||||
video_effective_fps = float(video_fps) / 4.0
|
||||
if f_v > 0:
|
||||
t_starts = torch.zeros((f_v,), device=device, dtype=torch.float32)
|
||||
if f_v > 1:
|
||||
t_starts[1:] = (1.0 / float(video_fps)) + torch.arange(f_v - 1, device=device, dtype=torch.float32) * (1.0 / video_effective_fps)
|
||||
else:
|
||||
t_starts = torch.zeros((0,), device=device, dtype=torch.float32)
|
||||
# Convert to audio-step units.
|
||||
video_pos_per_frame = t_starts * float(self.audio_fps)
|
||||
else:
|
||||
# No first-frame bias: uniform alignment.
|
||||
scale = float(self.audio_fps) / float(video_fps / 4.0)
|
||||
video_pos_per_frame = torch.arange(f_v, device=device, dtype=torch.float32) * scale
|
||||
# Flatten to f*h*w; tokens within the same frame share the same time position.
|
||||
video_pos = video_pos_per_frame.repeat_interleave(h * w).unsqueeze(0)
|
||||
|
||||
# print(f"video fps: {video_fps}, audio fps: {self.audio_fps}, scale: {scale}")
|
||||
# print(f"video pos: {video_pos.shape}, audio pos: {audio_pos.shape}")
|
||||
|
||||
# Build dummy x to produce cos/sin, dim=head_dim.
|
||||
dummy_v = torch.zeros((1, L_v, self.head_dim), device=device, dtype=dtype)
|
||||
dummy_a = torch.zeros((1, L_a, self.head_dim), device=device, dtype=dtype)
|
||||
|
||||
cos_v, sin_v = self.rotary(dummy_v, position_ids=video_pos)
|
||||
cos_a, sin_a = self.rotary(dummy_a, position_ids=audio_pos)
|
||||
|
||||
return (cos_v, sin_v), (cos_a, sin_a)
|
||||
|
||||
def should_interact(self, layer_idx: int, direction: str) -> bool:
|
||||
return self.controller.should_interact(layer_idx, direction, self.interaction_mapping)
|
||||
|
||||
def apply_conditional_control(
|
||||
self,
|
||||
layer_idx: int,
|
||||
direction: str,
|
||||
primary_hidden_states: torch.Tensor,
|
||||
condition_hidden_states: torch.Tensor,
|
||||
x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
condition_scale: Optional[float] = None,
|
||||
video_grid_size: Optional[Tuple[int, int, int]] = None,
|
||||
use_gradient_checkpointing: Optional[bool] = False,
|
||||
use_gradient_checkpointing_offload: Optional[bool] = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply conditional control (at the DiT hidden-state level).
|
||||
|
||||
Args:
|
||||
layer_idx: current layer index
|
||||
direction: conditioning direction
|
||||
- 'a2v': audio hidden states -> visual DiT
|
||||
- 'v2a': visual hidden states -> audio DiT
|
||||
primary_hidden_states: primary DiT hidden states [B, L, hidden_dim]
|
||||
condition_hidden_states: condition DiT hidden states [B, L, hidden_dim]
|
||||
condition_scale: conditioning strength (similar to CFG scale)
|
||||
|
||||
Returns:
|
||||
Conditioned primary DiT hidden states [B, L, hidden_dim]
|
||||
"""
|
||||
|
||||
if not self.controller.should_interact(layer_idx, direction, self.interaction_mapping):
|
||||
return primary_hidden_states
|
||||
|
||||
if direction == 'a2v':
|
||||
# audio hidden states condition the visual DiT
|
||||
conditioner = self.audio_to_video_conditioners[str(layer_idx)]
|
||||
|
||||
elif direction == 'v2a':
|
||||
# visual hidden states condition the audio DiT
|
||||
conditioner = self.video_to_audio_conditioners[str(layer_idx)]
|
||||
else:
|
||||
raise ValueError(f"Invalid direction: {direction}")
|
||||
|
||||
conditioned_features = gradient_checkpoint_forward(
|
||||
conditioner,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
x=primary_hidden_states,
|
||||
y=condition_hidden_states,
|
||||
x_freqs=x_freqs,
|
||||
y_freqs=y_freqs,
|
||||
video_grid_size=video_grid_size,
|
||||
)
|
||||
|
||||
if self.trainable_condition_scale and condition_scale is not None:
|
||||
print(
|
||||
"[WARN] This model has a trainable condition_scale, but an external "
|
||||
f"condition_scale={condition_scale} was provided. The trainable condition_scale "
|
||||
"will be ignored in favor of the external value."
|
||||
)
|
||||
|
||||
scale = condition_scale if condition_scale is not None else self.condition_scale
|
||||
|
||||
primary_hidden_states = primary_hidden_states + conditioned_features * scale
|
||||
|
||||
return primary_hidden_states
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer_idx: int,
|
||||
visual_hidden_states: torch.Tensor,
|
||||
audio_hidden_states: torch.Tensor,
|
||||
*,
|
||||
x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
a2v_condition_scale: Optional[float] = None,
|
||||
v2a_condition_scale: Optional[float] = None,
|
||||
condition_scale: Optional[float] = None,
|
||||
video_grid_size: Optional[Tuple[int, int, int]] = None,
|
||||
use_gradient_checkpointing: Optional[bool] = False,
|
||||
use_gradient_checkpointing_offload: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Apply bidirectional conditional control to both visual/audio towers.
|
||||
|
||||
Args:
|
||||
layer_idx: current layer index
|
||||
visual_hidden_states: visual DiT hidden states
|
||||
audio_hidden_states: audio DiT hidden states
|
||||
x_freqs / y_freqs: cross-modal RoPE (cos, sin) pairs.
|
||||
If provided, x_freqs is assumed to correspond to the primary tower and y_freqs
|
||||
to the conditioning tower.
|
||||
a2v_condition_scale: audio->visual conditioning strength (overrides global condition_scale)
|
||||
v2a_condition_scale: visual->audio conditioning strength (overrides global condition_scale)
|
||||
condition_scale: fallback conditioning strength when per-direction scale is None
|
||||
video_grid_size: (F, H, W), used on the audio side when pooled_adaln is enabled
|
||||
|
||||
Returns:
|
||||
(visual_hidden_states, audio_hidden_states), both conditioned in their respective directions.
|
||||
"""
|
||||
|
||||
visual_conditioned = self.apply_conditional_control(
|
||||
layer_idx=layer_idx,
|
||||
direction="a2v",
|
||||
primary_hidden_states=visual_hidden_states,
|
||||
condition_hidden_states=audio_hidden_states,
|
||||
x_freqs=x_freqs,
|
||||
y_freqs=y_freqs,
|
||||
condition_scale=a2v_condition_scale if a2v_condition_scale is not None else condition_scale,
|
||||
video_grid_size=video_grid_size,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)
|
||||
|
||||
audio_conditioned = self.apply_conditional_control(
|
||||
layer_idx=layer_idx,
|
||||
direction="v2a",
|
||||
primary_hidden_states=audio_hidden_states,
|
||||
condition_hidden_states=visual_hidden_states,
|
||||
x_freqs=y_freqs,
|
||||
y_freqs=x_freqs,
|
||||
condition_scale=v2a_condition_scale if v2a_condition_scale is not None else condition_scale,
|
||||
video_grid_size=video_grid_size,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)
|
||||
|
||||
return visual_conditioned, audio_conditioned
|
||||
@@ -99,18 +99,30 @@ def rope_apply(x, freqs, num_heads):
|
||||
return x_out.to(x.dtype)
|
||||
|
||||
|
||||
def set_to_torch_norm(models):
|
||||
for model in models:
|
||||
for module in model.modules():
|
||||
if isinstance(module, RMSNorm):
|
||||
module.use_torch_norm = True
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim, eps=1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
self.use_torch_norm = False
|
||||
self.normalized_shape = (dim,)
|
||||
|
||||
def norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
dtype = x.dtype
|
||||
return self.norm(x.float()).to(dtype) * self.weight
|
||||
if self.use_torch_norm:
|
||||
return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)
|
||||
else:
|
||||
return self.norm(x.float()).to(dtype) * self.weight
|
||||
|
||||
|
||||
class AttentionModule(nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user