* support mova inference

* mova media_io

* add unified audio_video api & fix bug of mono audio input for ltx

* support mova train

* mova docs

* fix bug
This commit is contained in:
Hong Zhang
2026-03-13 13:06:07 +08:00
committed by GitHub
parent 4741542523
commit 681df93a85
37 changed files with 3102 additions and 181 deletions

View File

@@ -0,0 +1,57 @@
import torch
import torch.nn as nn
from .wan_video_dit import WanModel, precompute_freqs_cis, sinusoidal_embedding_1d
from einops import rearrange
from ..core import gradient_checkpoint_forward
def precompute_freqs_cis_1d(dim: int, end: int = 16384, theta: float = 10000.0):
f_freqs_cis = precompute_freqs_cis(dim, end, theta)
return f_freqs_cis.chunk(3, dim=-1)
class MovaAudioDit(WanModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
head_dim = kwargs.get("dim", 1536) // kwargs.get("num_heads", 12)
self.freqs = precompute_freqs_cis_1d(head_dim)
self.patch_embedding = nn.Conv1d(
kwargs.get("in_dim", 128), kwargs.get("dim", 1536), kernel_size=[1], stride=[1]
)
def precompute_freqs_cis(self, dim: int, end: int = 16384, theta: float = 10000.0):
self.f_freqs_cis = precompute_freqs_cis_1d(dim, end, theta)
def forward(self,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
**kwargs,
):
t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
context = self.text_embedding(context)
x, (f, ) = self.patchify(x)
freqs = torch.cat([
self.freqs[0][:f].view(f, -1).expand(f, -1),
self.freqs[1][:f].view(f, -1).expand(f, -1),
self.freqs[2][:f].view(f, -1).expand(f, -1),
], dim=-1).reshape(f, 1, -1).to(x.device)
for block in self.blocks:
x = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
x, context, t_mod, freqs,
)
x = self.head(x, t)
x = self.unpatchify(x, (f, ))
return x
def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
return rearrange(
x, 'b f (p c) -> b c (f p)',
f=grid_size[0],
p=self.patch_size[0]
)

View File

@@ -0,0 +1,796 @@
import math
from typing import List, Union
import numpy as np
import torch
from torch import nn
from torch.nn.utils import weight_norm
import torch.nn.functional as F
from einops import rearrange
def WNConv1d(*args, **kwargs):
return weight_norm(nn.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
shape = x.shape
x = x.reshape(shape[0], shape[1], -1)
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
x = x.reshape(shape)
return x
class Snake1d(nn.Module):
def __init__(self, channels):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
def forward(self, x):
return snake(x, self.alpha)
class VectorQuantize(nn.Module):
"""
Implementation of VQ similar to Karpathy's repo:
https://github.com/karpathy/deep-vector-quantization
Additionally uses following tricks from Improved VQGAN
(https://arxiv.org/pdf/2110.04627.pdf):
1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
for improved codebook usage
2. l2-normalized codes: Converts euclidean distance to cosine similarity which
improves training stability
"""
def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
super().__init__()
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
self.codebook = nn.Embedding(codebook_size, codebook_dim)
def forward(self, z):
"""Quantized the input tensor using a fixed codebook and returns
the corresponding codebook vectors
Parameters
----------
z : Tensor[B x D x T]
Returns
-------
Tensor[B x D x T]
Quantized continuous representation of input
Tensor[1]
Commitment loss to train encoder to predict vectors closer to codebook
entries
Tensor[1]
Codebook loss to update the codebook
Tensor[B x T]
Codebook indices (quantized discrete representation of input)
Tensor[B x D x T]
Projected latents (continuous representation of input before quantization)
"""
# Factorized codes (ViT-VQGAN) Project input into low-dimensional space
z_e = self.in_proj(z) # z_e : (B x D x T)
z_q, indices = self.decode_latents(z_e)
commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
z_q = (
z_e + (z_q - z_e).detach()
) # noop in forward pass, straight-through gradient estimator in backward pass
z_q = self.out_proj(z_q)
return z_q, commitment_loss, codebook_loss, indices, z_e
def embed_code(self, embed_id):
return F.embedding(embed_id, self.codebook.weight)
def decode_code(self, embed_id):
return self.embed_code(embed_id).transpose(1, 2)
def decode_latents(self, latents):
encodings = rearrange(latents, "b d t -> (b t) d")
codebook = self.codebook.weight # codebook: (N x D)
# L2 normalize encodings and codebook (ViT-VQGAN)
encodings = F.normalize(encodings)
codebook = F.normalize(codebook)
# Compute euclidean distance with codebook
dist = (
encodings.pow(2).sum(1, keepdim=True)
- 2 * encodings @ codebook.t()
+ codebook.pow(2).sum(1, keepdim=True).t()
)
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
z_q = self.decode_code(indices)
return z_q, indices
class ResidualVectorQuantize(nn.Module):
"""
Introduced in SoundStream: An end2end neural audio codec
https://arxiv.org/abs/2107.03312
"""
def __init__(
self,
input_dim: int = 512,
n_codebooks: int = 9,
codebook_size: int = 1024,
codebook_dim: Union[int, list] = 8,
quantizer_dropout: float = 0.0,
):
super().__init__()
if isinstance(codebook_dim, int):
codebook_dim = [codebook_dim for _ in range(n_codebooks)]
self.n_codebooks = n_codebooks
self.codebook_dim = codebook_dim
self.codebook_size = codebook_size
self.quantizers = nn.ModuleList(
[
VectorQuantize(input_dim, codebook_size, codebook_dim[i])
for i in range(n_codebooks)
]
)
self.quantizer_dropout = quantizer_dropout
def forward(self, z, n_quantizers: int = None):
"""Quantized the input tensor using a fixed set of `n` codebooks and returns
the corresponding codebook vectors
Parameters
----------
z : Tensor[B x D x T]
n_quantizers : int, optional
No. of quantizers to use
(n_quantizers < self.n_codebooks ex: for quantizer dropout)
Note: if `self.quantizer_dropout` is True, this argument is ignored
when in training mode, and a random number of quantizers is used.
Returns
-------
dict
A dictionary with the following keys:
"z" : Tensor[B x D x T]
Quantized continuous representation of input
"codes" : Tensor[B x N x T]
Codebook indices for each codebook
(quantized discrete representation of input)
"latents" : Tensor[B x N*D x T]
Projected latents (continuous representation of input before quantization)
"vq/commitment_loss" : Tensor[1]
Commitment loss to train encoder to predict vectors closer to codebook
entries
"vq/codebook_loss" : Tensor[1]
Codebook loss to update the codebook
"""
z_q = 0
residual = z
commitment_loss = 0
codebook_loss = 0
codebook_indices = []
latents = []
if n_quantizers is None:
n_quantizers = self.n_codebooks
if self.training:
n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
n_dropout = int(z.shape[0] * self.quantizer_dropout)
n_quantizers[:n_dropout] = dropout[:n_dropout]
n_quantizers = n_quantizers.to(z.device)
for i, quantizer in enumerate(self.quantizers):
if self.training is False and i >= n_quantizers:
break
z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
residual
)
# Create mask to apply quantizer dropout
mask = (
torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
)
z_q = z_q + z_q_i * mask[:, None, None]
residual = residual - z_q_i
# Sum losses
commitment_loss += (commitment_loss_i * mask).mean()
codebook_loss += (codebook_loss_i * mask).mean()
codebook_indices.append(indices_i)
latents.append(z_e_i)
codes = torch.stack(codebook_indices, dim=1)
latents = torch.cat(latents, dim=1)
return z_q, codes, latents, commitment_loss, codebook_loss
def from_codes(self, codes: torch.Tensor):
"""Given the quantized codes, reconstruct the continuous representation
Parameters
----------
codes : Tensor[B x N x T]
Quantized discrete representation of input
Returns
-------
Tensor[B x D x T]
Quantized continuous representation of input
"""
z_q = 0.0
z_p = []
n_codebooks = codes.shape[1]
for i in range(n_codebooks):
z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
z_p.append(z_p_i)
z_q_i = self.quantizers[i].out_proj(z_p_i)
z_q = z_q + z_q_i
return z_q, torch.cat(z_p, dim=1), codes
def from_latents(self, latents: torch.Tensor):
"""Given the unquantized latents, reconstruct the
continuous representation after quantization.
Parameters
----------
latents : Tensor[B x N x T]
Continuous representation of input after projection
Returns
-------
Tensor[B x D x T]
Quantized representation of full-projected space
Tensor[B x D x T]
Quantized representation of latent space
"""
z_q = 0
z_p = []
codes = []
dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
0
]
for i in range(n_codebooks):
j, k = dims[i], dims[i + 1]
z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
z_p.append(z_p_i)
codes.append(codes_i)
z_q_i = self.quantizers[i].out_proj(z_p_i)
z_q = z_q + z_q_i
return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
class AbstractDistribution:
def sample(self):
raise NotImplementedError()
def mode(self):
raise NotImplementedError()
class DiracDistribution(AbstractDistribution):
def __init__(self, value):
self.value = value
def sample(self):
return self.value
def mode(self):
return self.value
class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
def sample(self):
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.0])
else:
if other is None:
return 0.5 * torch.mean(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2],
)
else:
return 0.5 * torch.mean(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var
- 1.0
- self.logvar
+ other.logvar,
dim=[1, 2],
)
def nll(self, sample, dims=[1, 2]):
if self.deterministic:
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims,
)
def mode(self):
return self.mean
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp().
logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)]
return 0.5 * (
-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
)
def init_weights(m):
if isinstance(m, nn.Conv1d):
nn.init.trunc_normal_(m.weight, std=0.02)
nn.init.constant_(m.bias, 0)
class ResidualUnit(nn.Module):
def __init__(self, dim: int = 16, dilation: int = 1):
super().__init__()
pad = ((7 - 1) * dilation) // 2
self.block = nn.Sequential(
Snake1d(dim),
WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
Snake1d(dim),
WNConv1d(dim, dim, kernel_size=1),
)
def forward(self, x):
y = self.block(x)
pad = (x.shape[-1] - y.shape[-1]) // 2
if pad > 0:
x = x[..., pad:-pad]
return x + y
class EncoderBlock(nn.Module):
def __init__(self, dim: int = 16, stride: int = 1):
super().__init__()
self.block = nn.Sequential(
ResidualUnit(dim // 2, dilation=1),
ResidualUnit(dim // 2, dilation=3),
ResidualUnit(dim // 2, dilation=9),
Snake1d(dim // 2),
WNConv1d(
dim // 2,
dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
),
)
def forward(self, x):
return self.block(x)
class Encoder(nn.Module):
def __init__(
self,
d_model: int = 64,
strides: list = [2, 4, 8, 8],
d_latent: int = 64,
):
super().__init__()
# Create first convolution
self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
# Create EncoderBlocks that double channels as they downsample by `stride`
for stride in strides:
d_model *= 2
self.block += [EncoderBlock(d_model, stride=stride)]
# Create last convolution
self.block += [
Snake1d(d_model),
WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
]
# Wrap black into nn.Sequential
self.block = nn.Sequential(*self.block)
self.enc_dim = d_model
def forward(self, x):
return self.block(x)
class DecoderBlock(nn.Module):
def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
super().__init__()
self.block = nn.Sequential(
Snake1d(input_dim),
WNConvTranspose1d(
input_dim,
output_dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
output_padding=stride % 2,
),
ResidualUnit(output_dim, dilation=1),
ResidualUnit(output_dim, dilation=3),
ResidualUnit(output_dim, dilation=9),
)
def forward(self, x):
return self.block(x)
class Decoder(nn.Module):
def __init__(
self,
input_channel,
channels,
rates,
d_out: int = 1,
):
super().__init__()
# Add first conv layer
layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
# Add upsampling + MRF blocks
for i, stride in enumerate(rates):
input_dim = channels // 2**i
output_dim = channels // 2 ** (i + 1)
layers += [DecoderBlock(input_dim, output_dim, stride)]
# Add final conv layer
layers += [
Snake1d(output_dim),
WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
nn.Tanh(),
]
self.model = nn.Sequential(*layers)
def forward(self, x):
return self.model(x)
class DacVAE(nn.Module):
def __init__(
self,
encoder_dim: int = 128,
encoder_rates: List[int] = [2, 3, 4, 5, 8],
latent_dim: int = 128,
decoder_dim: int = 2048,
decoder_rates: List[int] = [8, 5, 4, 3, 2],
n_codebooks: int = 9,
codebook_size: int = 1024,
codebook_dim: Union[int, list] = 8,
quantizer_dropout: bool = False,
sample_rate: int = 48000,
continuous: bool = True,
use_weight_norm: bool = False,
):
super().__init__()
self.encoder_dim = encoder_dim
self.encoder_rates = encoder_rates
self.decoder_dim = decoder_dim
self.decoder_rates = decoder_rates
self.sample_rate = sample_rate
self.continuous = continuous
self.use_weight_norm = use_weight_norm
if latent_dim is None:
latent_dim = encoder_dim * (2 ** len(encoder_rates))
self.latent_dim = latent_dim
self.hop_length = np.prod(encoder_rates)
self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)
if not continuous:
self.n_codebooks = n_codebooks
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
self.quantizer = ResidualVectorQuantize(
input_dim=latent_dim,
n_codebooks=n_codebooks,
codebook_size=codebook_size,
codebook_dim=codebook_dim,
quantizer_dropout=quantizer_dropout,
)
else:
self.quant_conv = torch.nn.Conv1d(latent_dim, 2 * latent_dim, 1)
self.post_quant_conv = torch.nn.Conv1d(latent_dim, latent_dim, 1)
self.decoder = Decoder(
latent_dim,
decoder_dim,
decoder_rates,
)
self.sample_rate = sample_rate
self.apply(init_weights)
self.delay = self.get_delay()
if not self.use_weight_norm:
self.remove_weight_norm()
def get_delay(self):
# Any number works here, delay is invariant to input length
l_out = self.get_output_length(0)
L = l_out
layers = []
for layer in self.modules():
if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
layers.append(layer)
for layer in reversed(layers):
d = layer.dilation[0]
k = layer.kernel_size[0]
s = layer.stride[0]
if isinstance(layer, nn.ConvTranspose1d):
L = ((L - d * (k - 1) - 1) / s) + 1
elif isinstance(layer, nn.Conv1d):
L = (L - 1) * s + d * (k - 1) + 1
L = math.ceil(L)
l_in = L
return (l_in - l_out) // 2
def get_output_length(self, input_length):
L = input_length
# Calculate output length
for layer in self.modules():
if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
d = layer.dilation[0]
k = layer.kernel_size[0]
s = layer.stride[0]
if isinstance(layer, nn.Conv1d):
L = ((L - d * (k - 1) - 1) / s) + 1
elif isinstance(layer, nn.ConvTranspose1d):
L = (L - 1) * s + d * (k - 1) + 1
L = math.floor(L)
return L
@property
def dtype(self):
"""Get the dtype of the model parameters."""
# Return the dtype of the first parameter found
for param in self.parameters():
return param.dtype
return torch.float32 # fallback
@property
def device(self):
"""Get the device of the model parameters."""
# Return the device of the first parameter found
for param in self.parameters():
return param.device
return torch.device('cpu') # fallback
def preprocess(self, audio_data, sample_rate):
if sample_rate is None:
sample_rate = self.sample_rate
assert sample_rate == self.sample_rate
length = audio_data.shape[-1]
right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
audio_data = nn.functional.pad(audio_data, (0, right_pad))
return audio_data
def encode(
self,
audio_data: torch.Tensor,
n_quantizers: int = None,
):
"""Encode given audio data and return quantized latent codes
Parameters
----------
audio_data : Tensor[B x 1 x T]
Audio data to encode
n_quantizers : int, optional
Number of quantizers to use, by default None
If None, all quantizers are used.
Returns
-------
dict
A dictionary with the following keys:
"z" : Tensor[B x D x T]
Quantized continuous representation of input
"codes" : Tensor[B x N x T]
Codebook indices for each codebook
(quantized discrete representation of input)
"latents" : Tensor[B x N*D x T]
Projected latents (continuous representation of input before quantization)
"vq/commitment_loss" : Tensor[1]
Commitment loss to train encoder to predict vectors closer to codebook
entries
"vq/codebook_loss" : Tensor[1]
Codebook loss to update the codebook
"length" : int
Number of samples in input audio
"""
z = self.encoder(audio_data) # [B x D x T]
if not self.continuous:
z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers)
else:
z = self.quant_conv(z) # [B x 2D x T]
z = DiagonalGaussianDistribution(z)
codes, latents, commitment_loss, codebook_loss = None, None, 0, 0
return z, codes, latents, commitment_loss, codebook_loss
def decode(self, z: torch.Tensor):
"""Decode given latent codes and return audio data
Parameters
----------
z : Tensor[B x D x T]
Quantized continuous representation of input
length : int, optional
Number of samples in output audio, by default None
Returns
-------
dict
A dictionary with the following keys:
"audio" : Tensor[B x 1 x length]
Decoded audio data.
"""
if not self.continuous:
audio = self.decoder(z)
else:
z = self.post_quant_conv(z)
audio = self.decoder(z)
return audio
def forward(
self,
audio_data: torch.Tensor,
sample_rate: int = None,
n_quantizers: int = None,
):
"""Model forward pass
Parameters
----------
audio_data : Tensor[B x 1 x T]
Audio data to encode
sample_rate : int, optional
Sample rate of audio data in Hz, by default None
If None, defaults to `self.sample_rate`
n_quantizers : int, optional
Number of quantizers to use, by default None.
If None, all quantizers are used.
Returns
-------
dict
A dictionary with the following keys:
"z" : Tensor[B x D x T]
Quantized continuous representation of input
"codes" : Tensor[B x N x T]
Codebook indices for each codebook
(quantized discrete representation of input)
"latents" : Tensor[B x N*D x T]
Projected latents (continuous representation of input before quantization)
"vq/commitment_loss" : Tensor[1]
Commitment loss to train encoder to predict vectors closer to codebook
entries
"vq/codebook_loss" : Tensor[1]
Codebook loss to update the codebook
"length" : int
Number of samples in input audio
"audio" : Tensor[B x 1 x length]
Decoded audio data.
"""
length = audio_data.shape[-1]
audio_data = self.preprocess(audio_data, sample_rate)
if not self.continuous:
z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers)
x = self.decode(z)
return {
"audio": x[..., :length],
"z": z,
"codes": codes,
"latents": latents,
"vq/commitment_loss": commitment_loss,
"vq/codebook_loss": codebook_loss,
}
else:
posterior, _, _, _, _ = self.encode(audio_data, n_quantizers)
z = posterior.sample()
x = self.decode(z)
kl_loss = posterior.kl()
kl_loss = kl_loss.mean()
return {
"audio": x[..., :length],
"z": z,
"kl_loss": kl_loss,
}
def remove_weight_norm(self):
"""
Remove weight_norm from all modules in the model.
This fuses the weight_g and weight_v parameters into a single weight parameter.
Should be called before inference for better performance.
Returns:
self: The model with weight_norm removed
"""
from torch.nn.utils import remove_weight_norm
num_removed = 0
for name, module in list(self.named_modules()):
if hasattr(module, "_forward_pre_hooks"):
for hook_id, hook in list(module._forward_pre_hooks.items()):
if "WeightNorm" in str(type(hook)):
try:
remove_weight_norm(module)
num_removed += 1
# print(f"Removed weight_norm from: {name}")
except ValueError as e:
print(f"Failed to remove weight_norm from {name}: {e}")
if num_removed > 0:
# print(f"Successfully removed weight_norm from {num_removed} modules")
self.use_weight_norm = False
else:
print("No weight_norm found in the model")
return self

View File

@@ -0,0 +1,595 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Tuple, Optional
from einops import rearrange
from .wan_video_dit import AttentionModule, RMSNorm
from ..core import gradient_checkpoint_forward
class RotaryEmbedding(nn.Module):
inv_freq: torch.Tensor # fix linting for `register_buffer`
def __init__(self, base: float, dim: int, device=None):
super().__init__()
self.base = base
self.dim = dim
self.attention_scaling = 1.0
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
@torch.no_grad()
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
@torch.compile(fullgraph=True)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class PerFrameAttentionPooling(nn.Module):
"""
Per-frame multi-head attention pooling.
Given a flattened token sequence [B, L, D] and grid size (T, H, W), perform a
single-query attention pooling over the H*W tokens for each time frame, producing
[B, T, D].
Inspired by SigLIP's Multihead Attention Pooling head (without MLP/residual stack).
"""
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
super().__init__()
assert dim % num_heads == 0, "dim must be divisible by num_heads"
self.dim = dim
self.num_heads = num_heads
self.probe = nn.Parameter(torch.randn(1, 1, dim))
nn.init.normal_(self.probe, std=0.02)
self.attention = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)
self.layernorm = nn.LayerNorm(dim, eps=eps)
def forward(self, x: torch.Tensor, grid_size: Tuple[int, int, int]) -> torch.Tensor:
"""
Args:
x: [B, L, D], where L = T*H*W
grid_size: (T, H, W)
Returns:
pooled: [B, T, D]
"""
B, L, D = x.shape
T, H, W = grid_size
assert D == self.dim, f"Channel dimension mismatch: D={D} vs dim={self.dim}"
assert L == T * H * W, f"Flattened length mismatch: L={L} vs T*H*W={T*H*W}"
S = H * W
# Re-arrange tokens grouped by frame.
x_bt_s_d = x.view(B, T, S, D).contiguous().view(B * T, S, D) # [B*T, S, D]
# A learnable probe as the query (one query per frame).
probe = self.probe.expand(B * T, -1, -1) # [B*T, 1, D]
# Attention pooling: query=probe, key/value=H*W tokens within the frame.
pooled_bt_1_d = self.attention(probe, x_bt_s_d, x_bt_s_d, need_weights=False)[0] # [B*T, 1, D]
pooled_bt_d = pooled_bt_1_d.squeeze(1) # [B*T, D]
# Restore to [B, T, D].
pooled = pooled_bt_d.view(B, T, D)
pooled = self.layernorm(pooled)
return pooled
class CrossModalInteractionController:
"""
Strategy class that controls interactions between two towers.
Manages the interaction mapping between visual DiT (e.g. 30 layers) and audio DiT (e.g. 30 layers).
"""
def __init__(self, visual_layers: int = 30, audio_layers: int = 30):
self.visual_layers = visual_layers
self.audio_layers = audio_layers
self.min_layers = min(visual_layers, audio_layers)
def get_interaction_layers(self, strategy: str = "shallow_focus") -> Dict[str, List[Tuple[int, int]]]:
"""
Get interaction layer mappings.
Args:
strategy: interaction strategy
- "shallow_focus": emphasize shallow layers to avoid deep-layer asymmetry
- "distributed": distributed interactions across the network
- "progressive": dense shallow interactions, sparse deeper interactions
- "custom": custom interaction layers
Returns:
A dict containing mappings for 'v2a' (visual -> audio) and 'a2v' (audio -> visual).
"""
if strategy == "shallow_focus":
# Emphasize the first ~1/3 layers to avoid deep-layer asymmetry.
num_interact = min(10, self.min_layers // 3)
interact_layers = list(range(0, num_interact))
elif strategy == "distributed":
# Distribute interactions across the network (every few layers).
step = 3
interact_layers = list(range(0, self.min_layers, step))
elif strategy == "progressive":
# Progressive: dense shallow interactions, sparse deeper interactions.
shallow = list(range(0, min(8, self.min_layers))) # Dense for the first 8 layers.
if self.min_layers > 8:
deep = list(range(8, self.min_layers, 3)) # Every 3 layers afterwards.
interact_layers = shallow + deep
else:
interact_layers = shallow
elif strategy == "custom":
# Custom strategy: adjust as needed.
interact_layers = [0, 2, 4, 6, 8, 12, 16, 20] # Explicit layer indices.
interact_layers = [i for i in interact_layers if i < self.min_layers]
elif strategy == "full":
interact_layers = list(range(0, self.min_layers))
else:
raise ValueError(f"Unknown interaction strategy: {strategy}")
# Build bidirectional mapping.
mapping = {
'v2a': [(i, i) for i in interact_layers], # visual layer i -> audio layer i
'a2v': [(i, i) for i in interact_layers] # audio layer i -> visual layer i
}
return mapping
def should_interact(self, layer_idx: int, direction: str, interaction_mapping: Dict) -> bool:
"""
Check whether a given layer should interact.
Args:
layer_idx: current layer index
direction: interaction direction ('v2a' or 'a2v')
interaction_mapping: interaction mapping table
Returns:
bool: whether to interact
"""
if direction not in interaction_mapping:
return False
return any(src == layer_idx for src, _ in interaction_mapping[direction])
class ConditionalCrossAttention(nn.Module):
def __init__(self, dim: int, kv_dim: int, num_heads: int, eps: float = 1e-6):
super().__init__()
self.q_dim = dim
self.kv_dim = kv_dim
self.num_heads = num_heads
self.head_dim = self.q_dim // num_heads
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(kv_dim, dim)
self.v = nn.Linear(kv_dim, dim)
self.o = nn.Linear(dim, dim)
self.norm_q = RMSNorm(dim, eps=eps)
self.norm_k = RMSNorm(dim, eps=eps)
self.attn = AttentionModule(self.num_heads)
def forward(self, x: torch.Tensor, y: torch.Tensor, x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
ctx = y
q = self.norm_q(self.q(x))
k = self.norm_k(self.k(ctx))
v = self.v(ctx)
if x_freqs is not None:
x_cos, x_sin = x_freqs
B, L, _ = q.shape
q_view = rearrange(q, 'b l (h d) -> b l h d', d=self.head_dim)
x_cos = x_cos.to(q_view.dtype).to(q_view.device)
x_sin = x_sin.to(q_view.dtype).to(q_view.device)
# Expect x_cos/x_sin shape: [B or 1, L, head_dim]
q_view, _ = apply_rotary_pos_emb(q_view, q_view, x_cos, x_sin, unsqueeze_dim=2)
q = rearrange(q_view, 'b l h d -> b l (h d)')
if y_freqs is not None:
y_cos, y_sin = y_freqs
Bc, Lc, _ = k.shape
k_view = rearrange(k, 'b l (h d) -> b l h d', d=self.head_dim)
y_cos = y_cos.to(k_view.dtype).to(k_view.device)
y_sin = y_sin.to(k_view.dtype).to(k_view.device)
# Expect y_cos/y_sin shape: [B or 1, L, head_dim]
_, k_view = apply_rotary_pos_emb(k_view, k_view, y_cos, y_sin, unsqueeze_dim=2)
k = rearrange(k_view, 'b l h d -> b l (h d)')
x = self.attn(q, k, v)
return self.o(x)
# from diffusers.models.attention import AdaLayerNorm
class AdaLayerNorm(nn.Module):
r"""
Norm layer modified to incorporate timestep embeddings.
Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
output_dim (`int`, *optional*):
norm_elementwise_affine (`bool`, defaults to `False):
norm_eps (`bool`, defaults to `False`):
chunk_dim (`int`, defaults to `0`):
"""
def __init__(
self,
embedding_dim: int,
num_embeddings: Optional[int] = None,
output_dim: Optional[int] = None,
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-5,
chunk_dim: int = 0,
):
super().__init__()
self.chunk_dim = chunk_dim
output_dim = output_dim or embedding_dim * 2
if num_embeddings is not None:
self.emb = nn.Embedding(num_embeddings, embedding_dim)
else:
self.emb = None
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, output_dim)
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
def forward(
self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
) -> torch.Tensor:
if self.emb is not None:
temb = self.emb(timestep)
temb = self.linear(self.silu(temb))
if self.chunk_dim == 2:
scale, shift = temb.chunk(2, dim=2)
# print(f"{x.shape = }, {scale.shape = }, {shift.shape = }")
elif self.chunk_dim == 1:
# This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
# other if-branch. This branch is specific to CogVideoX and OmniGen for now.
shift, scale = temb.chunk(2, dim=1)
shift = shift[:, None, :]
scale = scale[:, None, :]
else:
scale, shift = temb.chunk(2, dim=0)
x = self.norm(x) * (1 + scale) + shift
return x
class ConditionalCrossAttentionBlock(nn.Module):
"""
A thin wrapper around ConditionalCrossAttention.
Applies LayerNorm to the conditioning input `y` before cross-attention.
"""
def __init__(self, dim: int, kv_dim: int, num_heads: int, eps: float = 1e-6, pooled_adaln: bool = False):
super().__init__()
self.y_norm = nn.LayerNorm(kv_dim, eps=eps)
self.inner = ConditionalCrossAttention(dim=dim, kv_dim=kv_dim, num_heads=num_heads, eps=eps)
self.pooled_adaln = pooled_adaln
if pooled_adaln:
self.per_frame_pooling = PerFrameAttentionPooling(kv_dim, num_heads=num_heads, eps=eps)
self.adaln = AdaLayerNorm(kv_dim, output_dim=dim*2, chunk_dim=2)
def forward(
self,
x: torch.Tensor,
y: torch.Tensor,
x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
video_grid_size: Optional[Tuple[int, int, int]] = None,
) -> torch.Tensor:
if self.pooled_adaln:
assert video_grid_size is not None, "video_grid_size must not be None"
pooled_y = self.per_frame_pooling(y, video_grid_size)
# Interpolate pooled_y along its temporal dimension to match x's sequence length.
if pooled_y.shape[1] != x.shape[1]:
pooled_y = F.interpolate(
pooled_y.permute(0, 2, 1), # [B, C, T]
size=x.shape[1],
mode='linear',
align_corners=False,
).permute(0, 2, 1) # [B, T, C]
x = self.adaln(x, temb=pooled_y)
y = self.y_norm(y)
return self.inner(x=x, y=y, x_freqs=x_freqs, y_freqs=y_freqs)
class DualTowerConditionalBridge(nn.Module):
"""
Dual-tower conditional bridge.
"""
def __init__(self,
visual_layers: int = 40,
audio_layers: int = 30,
visual_hidden_dim: int = 5120, # visual DiT hidden state dimension
audio_hidden_dim: int = 1536, # audio DiT hidden state dimension
audio_fps: float = 50.0,
head_dim: int = 128, # attention head dimension
interaction_strategy: str = "full",
apply_cross_rope: bool = True, # whether to apply RoPE in cross-attention
apply_first_frame_bias_in_rope: bool = False, # whether to account for 1/video_fps bias for the first frame in RoPE alignment
trainable_condition_scale: bool = False,
pooled_adaln: bool = False,
):
super().__init__()
self.visual_hidden_dim = visual_hidden_dim
self.audio_hidden_dim = audio_hidden_dim
self.audio_fps = audio_fps
self.head_dim = head_dim
self.apply_cross_rope = apply_cross_rope
self.apply_first_frame_bias_in_rope = apply_first_frame_bias_in_rope
self.trainable_condition_scale = trainable_condition_scale
self.pooled_adaln = pooled_adaln
if self.trainable_condition_scale:
self.condition_scale = nn.Parameter(torch.tensor([1.0], dtype=torch.float32))
else:
self.condition_scale = 1.0
self.controller = CrossModalInteractionController(visual_layers, audio_layers)
self.interaction_mapping = self.controller.get_interaction_layers(interaction_strategy)
# Conditional cross-attention modules operating at the DiT hidden-state level.
self.audio_to_video_conditioners = nn.ModuleDict() # audio hidden states -> visual DiT conditioning
self.video_to_audio_conditioners = nn.ModuleDict() # visual hidden states -> audio DiT conditioning
# Build conditioners for layers that should interact.
# audio hidden states condition the visual DiT
self.rotary = RotaryEmbedding(base=10000.0, dim=head_dim)
for v_layer, _ in self.interaction_mapping['a2v']:
self.audio_to_video_conditioners[str(v_layer)] = ConditionalCrossAttentionBlock(
dim=visual_hidden_dim, # 3072 (visual DiT hidden states)
kv_dim=audio_hidden_dim, # 1536 (audio DiT hidden states)
num_heads=visual_hidden_dim // head_dim, # derive number of heads from hidden dim
pooled_adaln=False # a2v typically does not need pooled AdaLN
)
# visual hidden states condition the audio DiT
for a_layer, _ in self.interaction_mapping['v2a']:
self.video_to_audio_conditioners[str(a_layer)] = ConditionalCrossAttentionBlock(
dim=audio_hidden_dim, # 1536 (audio DiT hidden states)
kv_dim=visual_hidden_dim, # 3072 (visual DiT hidden states)
num_heads=audio_hidden_dim // head_dim, # safe head count derivation
pooled_adaln=self.pooled_adaln
)
@torch.no_grad()
def build_aligned_freqs(self,
video_fps: float,
grid_size: Tuple[int, int, int],
audio_steps: int,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
"""
Build aligned RoPE (cos, sin) pairs based on video fps, video grid size (f_v, h, w),
and audio sequence length `audio_steps` (with fixed audio fps = 44100/2048).
Returns:
visual_freqs: (cos_v, sin_v), shape [1, f_v*h*w, head_dim]
audio_freqs: (cos_a, sin_a), shape [1, audio_steps, head_dim]
"""
f_v, h, w = grid_size
L_v = f_v * h * w
L_a = int(audio_steps)
device = device or next(self.parameters()).device
dtype = dtype or torch.float32
# Audio positions: 0,1,2,...,L_a-1 (audio as reference).
audio_pos = torch.arange(L_a, device=device, dtype=torch.float32).unsqueeze(0)
# Video positions: align video frames to audio-step units.
# FIXME(dhyu): hard-coded VAE temporal stride = 4
if self.apply_first_frame_bias_in_rope:
# Account for the "first frame lasts 1/video_fps" bias.
video_effective_fps = float(video_fps) / 4.0
if f_v > 0:
t_starts = torch.zeros((f_v,), device=device, dtype=torch.float32)
if f_v > 1:
t_starts[1:] = (1.0 / float(video_fps)) + torch.arange(f_v - 1, device=device, dtype=torch.float32) * (1.0 / video_effective_fps)
else:
t_starts = torch.zeros((0,), device=device, dtype=torch.float32)
# Convert to audio-step units.
video_pos_per_frame = t_starts * float(self.audio_fps)
else:
# No first-frame bias: uniform alignment.
scale = float(self.audio_fps) / float(video_fps / 4.0)
video_pos_per_frame = torch.arange(f_v, device=device, dtype=torch.float32) * scale
# Flatten to f*h*w; tokens within the same frame share the same time position.
video_pos = video_pos_per_frame.repeat_interleave(h * w).unsqueeze(0)
# print(f"video fps: {video_fps}, audio fps: {self.audio_fps}, scale: {scale}")
# print(f"video pos: {video_pos.shape}, audio pos: {audio_pos.shape}")
# Build dummy x to produce cos/sin, dim=head_dim.
dummy_v = torch.zeros((1, L_v, self.head_dim), device=device, dtype=dtype)
dummy_a = torch.zeros((1, L_a, self.head_dim), device=device, dtype=dtype)
cos_v, sin_v = self.rotary(dummy_v, position_ids=video_pos)
cos_a, sin_a = self.rotary(dummy_a, position_ids=audio_pos)
return (cos_v, sin_v), (cos_a, sin_a)
def should_interact(self, layer_idx: int, direction: str) -> bool:
return self.controller.should_interact(layer_idx, direction, self.interaction_mapping)
def apply_conditional_control(
self,
layer_idx: int,
direction: str,
primary_hidden_states: torch.Tensor,
condition_hidden_states: torch.Tensor,
x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
condition_scale: Optional[float] = None,
video_grid_size: Optional[Tuple[int, int, int]] = None,
use_gradient_checkpointing: Optional[bool] = False,
use_gradient_checkpointing_offload: Optional[bool] = False,
) -> torch.Tensor:
"""
Apply conditional control (at the DiT hidden-state level).
Args:
layer_idx: current layer index
direction: conditioning direction
- 'a2v': audio hidden states -> visual DiT
- 'v2a': visual hidden states -> audio DiT
primary_hidden_states: primary DiT hidden states [B, L, hidden_dim]
condition_hidden_states: condition DiT hidden states [B, L, hidden_dim]
condition_scale: conditioning strength (similar to CFG scale)
Returns:
Conditioned primary DiT hidden states [B, L, hidden_dim]
"""
if not self.controller.should_interact(layer_idx, direction, self.interaction_mapping):
return primary_hidden_states
if direction == 'a2v':
# audio hidden states condition the visual DiT
conditioner = self.audio_to_video_conditioners[str(layer_idx)]
elif direction == 'v2a':
# visual hidden states condition the audio DiT
conditioner = self.video_to_audio_conditioners[str(layer_idx)]
else:
raise ValueError(f"Invalid direction: {direction}")
conditioned_features = gradient_checkpoint_forward(
conditioner,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
x=primary_hidden_states,
y=condition_hidden_states,
x_freqs=x_freqs,
y_freqs=y_freqs,
video_grid_size=video_grid_size,
)
if self.trainable_condition_scale and condition_scale is not None:
print(
"[WARN] This model has a trainable condition_scale, but an external "
f"condition_scale={condition_scale} was provided. The trainable condition_scale "
"will be ignored in favor of the external value."
)
scale = condition_scale if condition_scale is not None else self.condition_scale
primary_hidden_states = primary_hidden_states + conditioned_features * scale
return primary_hidden_states
def forward(
self,
layer_idx: int,
visual_hidden_states: torch.Tensor,
audio_hidden_states: torch.Tensor,
*,
x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
a2v_condition_scale: Optional[float] = None,
v2a_condition_scale: Optional[float] = None,
condition_scale: Optional[float] = None,
video_grid_size: Optional[Tuple[int, int, int]] = None,
use_gradient_checkpointing: Optional[bool] = False,
use_gradient_checkpointing_offload: Optional[bool] = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply bidirectional conditional control to both visual/audio towers.
Args:
layer_idx: current layer index
visual_hidden_states: visual DiT hidden states
audio_hidden_states: audio DiT hidden states
x_freqs / y_freqs: cross-modal RoPE (cos, sin) pairs.
If provided, x_freqs is assumed to correspond to the primary tower and y_freqs
to the conditioning tower.
a2v_condition_scale: audio->visual conditioning strength (overrides global condition_scale)
v2a_condition_scale: visual->audio conditioning strength (overrides global condition_scale)
condition_scale: fallback conditioning strength when per-direction scale is None
video_grid_size: (F, H, W), used on the audio side when pooled_adaln is enabled
Returns:
(visual_hidden_states, audio_hidden_states), both conditioned in their respective directions.
"""
visual_conditioned = self.apply_conditional_control(
layer_idx=layer_idx,
direction="a2v",
primary_hidden_states=visual_hidden_states,
condition_hidden_states=audio_hidden_states,
x_freqs=x_freqs,
y_freqs=y_freqs,
condition_scale=a2v_condition_scale if a2v_condition_scale is not None else condition_scale,
video_grid_size=video_grid_size,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
audio_conditioned = self.apply_conditional_control(
layer_idx=layer_idx,
direction="v2a",
primary_hidden_states=audio_hidden_states,
condition_hidden_states=visual_hidden_states,
x_freqs=y_freqs,
y_freqs=x_freqs,
condition_scale=v2a_condition_scale if v2a_condition_scale is not None else condition_scale,
video_grid_size=video_grid_size,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
return visual_conditioned, audio_conditioned

View File

@@ -99,18 +99,30 @@ def rope_apply(x, freqs, num_heads):
return x_out.to(x.dtype)
def set_to_torch_norm(models):
for model in models:
for module in model.modules():
if isinstance(module, RMSNorm):
module.use_torch_norm = True
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
self.use_torch_norm = False
self.normalized_shape = (dim,)
def norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
def forward(self, x):
dtype = x.dtype
return self.norm(x.float()).to(dtype) * self.weight
if self.use_torch_norm:
return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)
else:
return self.norm(x.float()).to(dtype) * self.weight
class AttentionModule(nn.Module):