ACE-Step: add training support

This commit is contained in:
mi804
2026-04-22 17:58:10 +08:00
parent b0680ef711
commit c53c813c12
42 changed files with 1235 additions and 30 deletions

View File

@@ -7,6 +7,7 @@ import re
import torch
from typing import Optional, Dict, Any, List, Tuple
from tqdm import tqdm
import random
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
@@ -89,13 +90,14 @@ class AceStepPipeline(BasePipeline):
self,
# Prompt
prompt: str,
negative_prompt: str = "",
cfg_scale: float = 1.0,
# Lyrics
lyrics: str = "",
# Task type
task_type: Optional[str] = "text2music",
# Reference audio
reference_audios: List[torch.Tensor] = None,
# Src audio
# Source audio
src_audio: torch.Tensor = None,
denoising_strength: float = 1.0,
# Audio codes
@@ -126,6 +128,7 @@ class AceStepPipeline(BasePipeline):
inputs_shared = {
"cfg_scale": cfg_scale,
"lyrics": lyrics,
"task_type": task_type,
"reference_audios": reference_audios,
"src_audio": src_audio,
"audio_code_string": audio_code_string,
@@ -147,7 +150,7 @@ class AceStepPipeline(BasePipeline):
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.to(dtype=self.torch_dtype, device=self.device)
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
@@ -182,13 +185,14 @@ class AceStepUnit_TaskTypeChecker(PipelineUnit):
"""Determine the task type (e.g. "text2music" vs. "cover") from the given inputs."""
def __init__(self):
super().__init__(
input_params=("src_audio", "audio_code_string"),
input_params=("audio_code_string"),
output_params=("task_type",),
)
def process(self, pipe, src_audio, audio_code_string):
def process(self, pipe, audio_code_string):
if pipe.scheduler.training:
return {"task_type": "text2music"}
if audio_code_string is not None:
print("audio_code_string detected, setting task_type to 'cover'")
task_type = "cover"
else:
task_type = "text2music"
@@ -200,7 +204,6 @@ class AceStepUnit_PromptEmbedder(PipelineUnit):
INSTRUCTION_MAP = {
"text2music": "Fill the audio semantic mask based on the given conditions:",
"cover": "Generate audio semantic tokens based on the given conditions:",
"repaint": "Repaint the mask area based on the given conditions:",
"extract": "Extract the {TRACK_NAME} track from the audio:",
"extract_default": "Extract the track from the audio:",
@@ -292,6 +295,7 @@ class AceStepUnit_ReferenceAudioEmbedder(PipelineUnit):
def process(self, pipe, reference_audios):
pipe.load_models_to_device(['vae'])
if reference_audios is not None and len(reference_audios) > 0:
raise NotImplementedError("Reference audio embedding is not implemented yet.")
# TODO: implement reference audio embedding using VAE encode, and generate refer_audio_order_mask
pass
else:
@@ -299,6 +303,49 @@ class AceStepUnit_ReferenceAudioEmbedder(PipelineUnit):
reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, reference_audios)
return {"reference_latents": reference_latents, "refer_audio_order_mask": refer_audio_order_mask}
# def process_reference_audio(self, reference_audios) -> Optional[torch.Tensor]:
# try:
# audio_np, sr = _read_audio_file(audio_file)
# audio = self._numpy_to_channels_first(audio_np)
# logger.debug(
# f"[process_reference_audio] Reference audio shape: {audio.shape}"
# )
# logger.debug(f"[process_reference_audio] Reference audio sample rate: {sr}")
# logger.debug(
# f"[process_reference_audio] Reference audio duration: {audio.shape[-1] / sr:.6f} seconds"
# )
# audio = self._normalize_audio_to_stereo_48k(audio, sr)
# if self.is_silence(audio):
# return None
# target_frames = 30 * 48000
# segment_frames = 10 * 48000
# if audio.shape[-1] < target_frames:
# repeat_times = math.ceil(target_frames / audio.shape[-1])
# audio = audio.repeat(1, repeat_times)
# total_frames = audio.shape[-1]
# segment_size = total_frames // 3
# front_start = random.randint(0, max(0, segment_size - segment_frames))
# front_audio = audio[:, front_start : front_start + segment_frames]
# middle_start = segment_size + random.randint(
# 0, max(0, segment_size - segment_frames)
# )
# middle_audio = audio[:, middle_start : middle_start + segment_frames]
# back_start = 2 * segment_size + random.randint(
# 0, max(0, (total_frames - 2 * segment_size) - segment_frames)
# )
# back_audio = audio[:, back_start : back_start + segment_frames]
# return torch.cat([front_audio, middle_audio, back_audio], dim=-1)
def infer_refer_latent(self, pipe, refer_audioss: List[List[torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
"""Infer packed reference-audio latents and order mask."""
refer_audio_order_mask = []
@@ -401,8 +448,8 @@ class AceStepUnit_ContextLatentBuilder(PipelineUnit):
chunk_masks = torch.ones((1, max_latent_length, src_latents.shape[-1]), dtype=torch.bool, device=pipe.device)
attention_mask = torch.ones((1, max_latent_length), device=src_latents.device, dtype=pipe.torch_dtype)
context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
elif src_audio is not None:
raise NotImplementedError("src_audio conditioning is not implemented yet. Please set lm_hints to None.")
# elif src_audio is not None:
# raise NotImplementedError("src_audio conditioning is not implemented yet. Please set lm_hints to None.")
else:
max_latent_length = duration * pipe.sample_rate // 1920
src_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
@@ -435,8 +482,16 @@ class AceStepUnit_InputAudioEmbedder(PipelineUnit):
def process(self, pipe, noise, input_audio):
if input_audio is None:
return {"latents": noise}
# TODO: support for train
return {"latents": noise, "input_latents": None}
if pipe.scheduler.training:
pipe.load_models_to_device(['vae'])
input_audio, sample_rate = input_audio
input_audio = torch.clamp(input_audio, -1.0, 1.0)
if input_audio.dim() == 2:
input_audio = input_audio.unsqueeze(0)
input_latents = pipe.vae.encode(input_audio.to(dtype=pipe.torch_dtype, device=pipe.device)).transpose(1, 2)
# prevent potential size mismatch between context_latents and input_latents by cropping input_latents to the same temporal length as noise
input_latents = input_latents[:, :noise.shape[1]]
return {"input_latents": input_latents}
class AceStepUnit_AudioCodeDecoder(PipelineUnit):
@@ -494,7 +549,6 @@ def model_fn_ace_step(
use_gradient_checkpointing_offload=False,
**kwargs,
):
timestep = timestep.unsqueeze(0)
decoder_outputs = dit(
hidden_states=latents,
timestep=timestep,