ace-step train

This commit is contained in:
mi804
2026-04-22 17:58:10 +08:00
parent b0680ef711
commit c53c813c12
42 changed files with 1235 additions and 30 deletions

View File

@@ -3,6 +3,7 @@ import torch, torchvision, imageio, os
import imageio.v3 as iio
from PIL import Image
import torchaudio
from diffsynth.utils.data.audio import read_audio
class DataProcessingPipeline:
@@ -276,3 +277,27 @@ class LoadAudioWithTorchaudio(DataProcessingOperator, FrameSamplerByRateMixin):
except:
warnings.warn(f"Cannot load audio in {data}. The audio will be `None`.")
return None
class LoadPureAudioWithTorchaudio(DataProcessingOperator):
    """Load a raw audio waveform from a file path, optionally resampling and
    cropping/zero-padding it to a fixed duration.

    Args:
        target_sample_rate: Resample to this rate when given; ``None`` keeps
            the file's native sample rate.
        target_duration: Target length in seconds; the waveform is cropped or
            right-padded with zeros to match. ``None`` keeps the full length.
    """

    def __init__(self, target_sample_rate=None, target_duration=None):
        self.target_sample_rate = target_sample_rate
        self.target_duration = target_duration
        # Only ask read_audio to resample when a target rate was requested.
        self.resample = target_sample_rate is not None

    def __call__(self, data: str):
        """Return ``(waveform, sample_rate)``, or ``None`` if loading fails.

        Loading is best-effort: any exception is converted to a warning so a
        single bad sample does not abort the whole data pipeline.
        """
        try:
            waveform, sample_rate = read_audio(data, resample=self.resample, resample_rate=self.target_sample_rate)
            if self.target_duration is not None:
                target_samples = int(self.target_duration * sample_rate)
                current_samples = waveform.shape[-1]
                if current_samples > target_samples:
                    # Crop excess samples from the tail.
                    waveform = waveform[..., :target_samples]
                elif current_samples < target_samples:
                    # Zero-pad the tail up to the target length.
                    padding = target_samples - current_samples
                    waveform = torch.nn.functional.pad(waveform, (0, padding))
            return waveform, sample_rate
        except Exception as e:
            warnings.warn(f"Cannot load audio in '{data}' due to '{e}'. The audio will be `None`.")
            return None

View File

@@ -864,20 +864,13 @@ class AceStepDiTModel(nn.Module):
layer_kwargs = flash_attn_kwargs
# Use gradient checkpointing if enabled
if use_gradient_checkpointing or use_gradient_checkpointing_offload:
layer_outputs = gradient_checkpoint_forward(
layer_module,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
*layer_args,
**layer_kwargs,
)
else:
layer_outputs = layer_module(
*layer_args,
**layer_kwargs,
)
layer_outputs = gradient_checkpoint_forward(
layer_module,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
*layer_args,
**layer_kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions and self.layers[index_block].use_cross_attention:

View File

@@ -191,6 +191,43 @@ class OobleckDecoder(nn.Module):
return self.conv2(hidden_state)
class OobleckDiagonalGaussianDistribution(object):
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
self.parameters = parameters
self.mean, self.scale = parameters.chunk(2, dim=1)
self.std = nn.functional.softplus(self.scale) + 1e-4
self.var = self.std * self.std
self.logvar = torch.log(self.var)
self.deterministic = deterministic
def sample(self, generator: torch.Generator | None = None) -> torch.Tensor:
# make sure sample is on the same device as the parameters and has same dtype
sample = torch.randn(
self.mean.shape,
generator=generator,
device=self.parameters.device,
dtype=self.parameters.dtype,
)
x = self.mean + self.std * sample
return x
def kl(self, other: "OobleckDiagonalGaussianDistribution" = None) -> torch.Tensor:
if self.deterministic:
return torch.Tensor([0.0])
else:
if other is None:
return (self.mean * self.mean + self.var - self.logvar - 1.0).sum(1).mean()
else:
normalized_diff = torch.pow(self.mean - other.mean, 2) / other.var
var_ratio = self.var / other.var
logvar_diff = self.logvar - other.logvar
kl = normalized_diff + var_ratio + logvar_diff - 1
kl = kl.sum(1).mean()
return kl
class AceStepVAE(nn.Module):
"""Audio VAE for ACE-Step (AutoencoderOobleck architecture).
@@ -229,17 +266,19 @@ class AceStepVAE(nn.Module):
self.sampling_rate = sampling_rate
def encode(self, x: torch.Tensor) -> torch.Tensor:
"""Audio waveform [B, audio_channels, T] → latent [B, encoder_hidden_size, T']."""
return self.encoder(x)
"""Audio waveform [B, audio_channels, T] → latent [B, decoder_input_channels, T']."""
h = self.encoder(x)
output = OobleckDiagonalGaussianDistribution(h).sample()
return output
def decode(self, z: torch.Tensor) -> torch.Tensor:
"""Latent [B, encoder_hidden_size, T] → audio waveform [B, audio_channels, T']."""
"""Latent [B, decoder_input_channels, T] → audio waveform [B, audio_channels, T']."""
return self.decoder(z)
def forward(self, sample: torch.Tensor) -> torch.Tensor:
"""Full round-trip: encode → decode."""
z = self.encode(sample)
return self.decoder(z)
return self.decode(z)
def remove_weight_norm(self):
"""Remove weight normalization from all conv layers (for export/inference)."""

View File

@@ -7,6 +7,7 @@ import re
import torch
from typing import Optional, Dict, Any, List, Tuple
from tqdm import tqdm
import random
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
@@ -89,13 +90,14 @@ class AceStepPipeline(BasePipeline):
self,
# Prompt
prompt: str,
negative_prompt: str = "",
cfg_scale: float = 1.0,
# Lyrics
lyrics: str = "",
# Task type
task_type: Optional[str] = "text2music",
# Reference audio
reference_audios: List[torch.Tensor] = None,
# Src audio
# Source audio
src_audio: torch.Tensor = None,
denoising_strength: float = 1.0,
# Audio codes
@@ -126,6 +128,7 @@ class AceStepPipeline(BasePipeline):
inputs_shared = {
"cfg_scale": cfg_scale,
"lyrics": lyrics,
"task_type": task_type,
"reference_audios": reference_audios,
"src_audio": src_audio,
"audio_code_string": audio_code_string,
@@ -147,7 +150,7 @@ class AceStepPipeline(BasePipeline):
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.to(dtype=self.torch_dtype, device=self.device)
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
@@ -182,13 +185,14 @@ class AceStepUnit_TaskTypeChecker(PipelineUnit):
"""Check and compute sequence length from duration."""
def __init__(self):
super().__init__(
input_params=("src_audio", "audio_code_string"),
input_params=("audio_code_string"),
output_params=("task_type",),
)
def process(self, pipe, src_audio, audio_code_string):
def process(self, pipe, audio_code_string):
if pipe.scheduler.training:
return {"task_type": "text2music"}
if audio_code_string is not None:
print("audio_code_string detected, setting task_type to 'cover'")
task_type = "cover"
else:
task_type = "text2music"
@@ -200,7 +204,6 @@ class AceStepUnit_PromptEmbedder(PipelineUnit):
INSTRUCTION_MAP = {
"text2music": "Fill the audio semantic mask based on the given conditions:",
"cover": "Generate audio semantic tokens based on the given conditions:",
"repaint": "Repaint the mask area based on the given conditions:",
"extract": "Extract the {TRACK_NAME} track from the audio:",
"extract_default": "Extract the track from the audio:",
@@ -292,6 +295,7 @@ class AceStepUnit_ReferenceAudioEmbedder(PipelineUnit):
def process(self, pipe, reference_audios):
pipe.load_models_to_device(['vae'])
if reference_audios is not None and len(reference_audios) > 0:
raise NotImplementedError("Reference audio embedding is not implemented yet.")
# TODO: implement reference audio embedding using VAE encode, and generate refer_audio_order_mask
pass
else:
@@ -299,6 +303,49 @@ class AceStepUnit_ReferenceAudioEmbedder(PipelineUnit):
reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, reference_audios)
return {"reference_latents": reference_latents, "refer_audio_order_mask": refer_audio_order_mask}
# def process_reference_audio(self, reference_audios) -> Optional[torch.Tensor]:
# try:
# audio_np, sr = _read_audio_file(audio_file)
# audio = self._numpy_to_channels_first(audio_np)
# logger.debug(
# f"[process_reference_audio] Reference audio shape: {audio.shape}"
# )
# logger.debug(f"[process_reference_audio] Reference audio sample rate: {sr}")
# logger.debug(
# f"[process_reference_audio] Reference audio duration: {audio.shape[-1] / sr:.6f} seconds"
# )
# audio = self._normalize_audio_to_stereo_48k(audio, sr)
# if self.is_silence(audio):
# return None
# target_frames = 30 * 48000
# segment_frames = 10 * 48000
# if audio.shape[-1] < target_frames:
# repeat_times = math.ceil(target_frames / audio.shape[-1])
# audio = audio.repeat(1, repeat_times)
# total_frames = audio.shape[-1]
# segment_size = total_frames // 3
# front_start = random.randint(0, max(0, segment_size - segment_frames))
# front_audio = audio[:, front_start : front_start + segment_frames]
# middle_start = segment_size + random.randint(
# 0, max(0, segment_size - segment_frames)
# )
# middle_audio = audio[:, middle_start : middle_start + segment_frames]
# back_start = 2 * segment_size + random.randint(
# 0, max(0, (total_frames - 2 * segment_size) - segment_frames)
# )
# back_audio = audio[:, back_start : back_start + segment_frames]
# return torch.cat([front_audio, middle_audio, back_audio], dim=-1)
def infer_refer_latent(self, pipe, refer_audioss: List[List[torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
"""Infer packed reference-audio latents and order mask."""
refer_audio_order_mask = []
@@ -401,8 +448,8 @@ class AceStepUnit_ContextLatentBuilder(PipelineUnit):
chunk_masks = torch.ones((1, max_latent_length, src_latents.shape[-1]), dtype=torch.bool, device=pipe.device)
attention_mask = torch.ones((1, max_latent_length), device=src_latents.device, dtype=pipe.torch_dtype)
context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
elif src_audio is not None:
raise NotImplementedError("src_audio conditioning is not implemented yet. Please set lm_hints to None.")
# elif src_audio is not None:
# raise NotImplementedError("src_audio conditioning is not implemented yet. Please set lm_hints to None.")
else:
max_latent_length = duration * pipe.sample_rate // 1920
src_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
@@ -435,8 +482,16 @@ class AceStepUnit_InputAudioEmbedder(PipelineUnit):
def process(self, pipe, noise, input_audio):
if input_audio is None:
return {"latents": noise}
# TODO: support for train
return {"latents": noise, "input_latents": None}
if pipe.scheduler.training:
pipe.load_models_to_device(['vae'])
input_audio, sample_rate = input_audio
input_audio = torch.clamp(input_audio, -1.0, 1.0)
if input_audio.dim() == 2:
input_audio = input_audio.unsqueeze(0)
input_latents = pipe.vae.encode(input_audio.to(dtype=pipe.torch_dtype, device=pipe.device)).transpose(1, 2)
# prevent potential size mismatch between context_latents and input_latents by cropping input_latents to the same temporal length as noise
input_latents = input_latents[:, :noise.shape[1]]
return {"input_latents": input_latents}
class AceStepUnit_AudioCodeDecoder(PipelineUnit):
@@ -494,7 +549,6 @@ def model_fn_ace_step(
use_gradient_checkpointing_offload=False,
**kwargs,
):
timestep = timestep.unsqueeze(0)
decoder_outputs = dit(
hidden_states=latents,
timestep=timestep,