mirror of https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-24 15:06:17 +00:00
ace-step train
@@ -864,20 +864,13 @@ class AceStepDiTModel(nn.Module):
                 layer_kwargs = flash_attn_kwargs
 
-            # Use gradient checkpointing if enabled
-            if use_gradient_checkpointing or use_gradient_checkpointing_offload:
-                layer_outputs = gradient_checkpoint_forward(
-                    layer_module,
-                    use_gradient_checkpointing,
-                    use_gradient_checkpointing_offload,
-                    *layer_args,
-                    **layer_kwargs,
-                )
-            else:
-                layer_outputs = layer_module(
-                    *layer_args,
-                    **layer_kwargs,
-                )
+            layer_outputs = gradient_checkpoint_forward(
+                layer_module,
+                use_gradient_checkpointing,
+                use_gradient_checkpointing_offload,
+                *layer_args,
+                **layer_kwargs,
+            )
 
             hidden_states = layer_outputs[0]
 
             if output_attentions and self.layers[index_block].use_cross_attention:
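This hunk collapses the old enable/disable branch at the call site: gradient_checkpoint_forward is now called unconditionally, so the helper itself must dispatch on the two flags and fall back to a plain layer_module(...) call when both are off. A minimal sketch of such a dispatcher built on torch.utils.checkpoint (an illustration of the pattern, not the repository's actual implementation; the offload path via save_on_cpu is an assumption):

    import torch
    from torch.utils.checkpoint import checkpoint

    def gradient_checkpoint_forward(module, use_gradient_checkpointing,
                                    use_gradient_checkpointing_offload,
                                    *args, **kwargs):
        if not (use_gradient_checkpointing or use_gradient_checkpointing_offload):
            # Checkpointing disabled: behave exactly like a direct module call
            # (the removed `else` branch above).
            return module(*args, **kwargs)
        if use_gradient_checkpointing_offload:
            # Assumed offload strategy: stash checkpointed activations in CPU
            # memory between forward and backward.
            with torch.autograd.graph.save_on_cpu(pin_memory=True):
                return checkpoint(module, *args, use_reentrant=False, **kwargs)
        # Recompute activations during backward instead of storing them.
        return checkpoint(module, *args, use_reentrant=False, **kwargs)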
@@ -191,6 +191,43 @@ class OobleckDecoder(nn.Module):
         return self.conv2(hidden_state)
 
 
+class OobleckDiagonalGaussianDistribution(object):
+    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
+        self.parameters = parameters
+        self.mean, self.scale = parameters.chunk(2, dim=1)
+        self.std = nn.functional.softplus(self.scale) + 1e-4
+        self.var = self.std * self.std
+        self.logvar = torch.log(self.var)
+        self.deterministic = deterministic
+
+    def sample(self, generator: torch.Generator | None = None) -> torch.Tensor:
+        # make sure sample is on the same device as the parameters and has same dtype
+        sample = torch.randn(
+            self.mean.shape,
+            generator=generator,
+            device=self.parameters.device,
+            dtype=self.parameters.dtype,
+        )
+        x = self.mean + self.std * sample
+        return x
+
+    def kl(self, other: "OobleckDiagonalGaussianDistribution" = None) -> torch.Tensor:
+        if self.deterministic:
+            return torch.Tensor([0.0])
+        else:
+            if other is None:
+                return (self.mean * self.mean + self.var - self.logvar - 1.0).sum(1).mean()
+            else:
+                normalized_diff = torch.pow(self.mean - other.mean, 2) / other.var
+                var_ratio = self.var / other.var
+                logvar_diff = self.logvar - other.logvar
+
+                kl = normalized_diff + var_ratio + logvar_diff - 1
+
+                kl = kl.sum(1).mean()
+                return kl
+
+
 class AceStepVAE(nn.Module):
     """Audio VAE for ACE-Step (AutoencoderOobleck architecture).
 
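The new OobleckDiagonalGaussianDistribution splits the encoder output along the channel dimension into a mean and a scale, maps the scale to a strictly positive standard deviation via softplus(scale) + 1e-4, and exposes the reparameterized sample mean + std * eps. For the prior-matching case (other is None), kl() returns the familiar closed form for a diagonal Gaussian against the standard normal,

    KL(N(mu, sigma^2) || N(0, I)) = 1/2 * sum_i (mu_i^2 + sigma_i^2 - log sigma_i^2 - 1),

except without the 1/2 factor, summed over channels (dim 1) and averaged over the remaining batch and time dimensions. A minimal usage sketch (assumes the class is importable; shapes are illustrative, not the model's real dimensions):

    import torch

    params = torch.randn(2, 128, 50)  # [B, 2 * latent_channels, T'] from the encoder
    dist = OobleckDiagonalGaussianDistribution(params)
    z = dist.sample()                 # [2, 64, 50]: mean + softplus(scale) * noise
    reg = dist.kl()                   # scalar regularizer toward the N(0, I) prior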
@@ -229,17 +266,19 @@ class AceStepVAE(nn.Module):
         self.sampling_rate = sampling_rate
 
     def encode(self, x: torch.Tensor) -> torch.Tensor:
-        """Audio waveform [B, audio_channels, T] → latent [B, encoder_hidden_size, T']."""
-        return self.encoder(x)
+        """Audio waveform [B, audio_channels, T] → latent [B, decoder_input_channels, T']."""
+        h = self.encoder(x)
+        output = OobleckDiagonalGaussianDistribution(h).sample()
+        return output
 
     def decode(self, z: torch.Tensor) -> torch.Tensor:
-        """Latent [B, encoder_hidden_size, T] → audio waveform [B, audio_channels, T']."""
+        """Latent [B, decoder_input_channels, T] → audio waveform [B, audio_channels, T']."""
         return self.decoder(z)
 
     def forward(self, sample: torch.Tensor) -> torch.Tensor:
         """Full round-trip: encode → decode."""
         z = self.encode(sample)
-        return self.decoder(z)
+        return self.decode(z)
 
     def remove_weight_norm(self):
         """Remove weight normalization from all conv layers (for export/inference)."""
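With this hunk, encode() becomes stochastic: each call draws a fresh sample from the diagonal Gaussian posterior instead of returning the raw encoder features, matching the standard VAE training recipe. A round-trip sketch (assumes `vae` is a constructed AceStepVAE instance; the waveform shape is an illustrative assumption):

    import torch

    x = torch.randn(1, 2, 44100)  # [B, audio_channels, T]
    z1 = vae.encode(x)            # one sample from the posterior
    z2 = vae.encode(x)            # a different sample for the same input
    audio = vae.decode(z1)        # reconstruction, [B, audio_channels, T']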