ace-step train

This commit is contained in:
mi804
2026-04-22 17:58:10 +08:00
parent b0680ef711
commit c53c813c12
42 changed files with 1235 additions and 30 deletions

View File

@@ -864,20 +864,13 @@ class AceStepDiTModel(nn.Module):
layer_kwargs = flash_attn_kwargs
# Use gradient checkpointing if enabled
if use_gradient_checkpointing or use_gradient_checkpointing_offload:
layer_outputs = gradient_checkpoint_forward(
layer_module,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
*layer_args,
**layer_kwargs,
)
else:
layer_outputs = layer_module(
*layer_args,
**layer_kwargs,
)
layer_outputs = gradient_checkpoint_forward(
layer_module,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
*layer_args,
**layer_kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions and self.layers[index_block].use_cross_attention:

View File

@@ -191,6 +191,43 @@ class OobleckDecoder(nn.Module):
return self.conv2(hidden_state)
class OobleckDiagonalGaussianDistribution(object):
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
self.parameters = parameters
self.mean, self.scale = parameters.chunk(2, dim=1)
self.std = nn.functional.softplus(self.scale) + 1e-4
self.var = self.std * self.std
self.logvar = torch.log(self.var)
self.deterministic = deterministic
def sample(self, generator: torch.Generator | None = None) -> torch.Tensor:
# make sure sample is on the same device as the parameters and has same dtype
sample = torch.randn(
self.mean.shape,
generator=generator,
device=self.parameters.device,
dtype=self.parameters.dtype,
)
x = self.mean + self.std * sample
return x
def kl(self, other: "OobleckDiagonalGaussianDistribution" = None) -> torch.Tensor:
if self.deterministic:
return torch.Tensor([0.0])
else:
if other is None:
return (self.mean * self.mean + self.var - self.logvar - 1.0).sum(1).mean()
else:
normalized_diff = torch.pow(self.mean - other.mean, 2) / other.var
var_ratio = self.var / other.var
logvar_diff = self.logvar - other.logvar
kl = normalized_diff + var_ratio + logvar_diff - 1
kl = kl.sum(1).mean()
return kl
class AceStepVAE(nn.Module):
"""Audio VAE for ACE-Step (AutoencoderOobleck architecture).
@@ -229,17 +266,19 @@ class AceStepVAE(nn.Module):
self.sampling_rate = sampling_rate
def encode(self, x: torch.Tensor) -> torch.Tensor:
"""Audio waveform [B, audio_channels, T] → latent [B, encoder_hidden_size, T']."""
return self.encoder(x)
"""Audio waveform [B, audio_channels, T] → latent [B, decoder_input_channels, T']."""
h = self.encoder(x)
output = OobleckDiagonalGaussianDistribution(h).sample()
return output
def decode(self, z: torch.Tensor) -> torch.Tensor:
"""Latent [B, encoder_hidden_size, T] → audio waveform [B, audio_channels, T']."""
"""Latent [B, decoder_input_channels, T] → audio waveform [B, audio_channels, T']."""
return self.decoder(z)
def forward(self, sample: torch.Tensor) -> torch.Tensor:
"""Full round-trip: encode → decode."""
z = self.encode(sample)
return self.decoder(z)
return self.decode(z)
def remove_weight_norm(self):
"""Remove weight normalization from all conv layers (for export/inference)."""