This commit is contained in:
mi804
2026-04-21 19:42:57 +08:00
parent 9d09e0431c
commit 95cfb77881
6 changed files with 135 additions and 261 deletions

View File

@@ -42,10 +42,10 @@ class AceStepPipeline(BasePipeline):
AceStepUnit_PromptEmbedder(),
AceStepUnit_ReferenceAudioEmbedder(),
AceStepUnit_ConditionEmbedder(),
AceStepUnit_AudioCodeDecoder(),
AceStepUnit_ContextLatentBuilder(),
AceStepUnit_NoiseInitializer(),
AceStepUnit_InputAudioEmbedder(),
AceStepUnit_AudioCodeDecoder(),
]
self.model_fn = model_fn_ace_step
self.compilable_models = ["dit"]
@@ -92,27 +92,27 @@ class AceStepPipeline(BasePipeline):
cfg_scale: float = 1.0,
# Lyrics
lyrics: str = "",
# Reference audio (optional, for timbre conditioning)
# Reference audio
reference_audios: List[torch.Tensor] = None,
# Src audio
src_audio: torch.Tensor = None,
denoising_strength: float = 1.0,
# Simple Mode: LLM-generated audio codes (optional)
audio_codes: str = None,
# Audio codes
audio_code_string: Optional[str] = None,
# Shape
duration: int = 60,
# Audio Meta
bpm: Optional[int] = 100,
keyscale: Optional[str] = "B minor",
timesignature: Optional[str] = "4",
vocal_language: Optional[str] = 'zh',
vocal_language: Optional[str] = 'unknown',
# Randomness
seed: int = None,
rand_device: str = "cpu",
# Steps
num_inference_steps: int = 8,
# Scheduler-specific parameters
shift: float = 3.0,
shift: float = 1.0,
# Progress
progress_bar_cmd=tqdm,
):
@@ -120,14 +120,14 @@ class AceStepPipeline(BasePipeline):
self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, denoising_strength=1.0, shift=shift)
# 2. 三字典输入
inputs_posi = {"prompt": prompt}
inputs_nega = {"negative_prompt": negative_prompt}
inputs_posi = {"prompt": prompt, "positive": True}
inputs_nega = {"positive": False}
inputs_shared = {
"cfg_scale": cfg_scale,
"lyrics": lyrics,
"reference_audios": reference_audios,
"src_audio": src_audio,
"audio_codes": audio_codes,
"audio_code_string": audio_code_string,
"duration": duration,
"bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "vocal_language": vocal_language,
"seed": seed,
@@ -145,12 +145,13 @@ class AceStepPipeline(BasePipeline):
# 4. Denoise loop
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
self.momentum_buffer = MomentumBuffer()
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
**models, timestep=timestep, progress_id=progress_id,
)
inputs_shared["latents"] = self.step(
self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared
@@ -163,39 +164,12 @@ class AceStepPipeline(BasePipeline):
vae_output = self.vae.decode(latents)
# VAE returns OobleckDecoderOutput with .sample attribute
audio_output = vae_output.sample if hasattr(vae_output, 'sample') else vae_output
audio_output = self.normalize_audio(audio_output, target_db=-1.0)
audio = self.output_audio_format_check(audio_output)
# Peak normalization to match target library behavior
audio = self.normalize_audio(audio, target_db=-1.0)
self.load_models_to_device([])
return audio
def output_audio_format_check(self, audio_output):
    """Convert VAE output to the standard audio format [C, T], float32.

    VAE decode yields a [B, C, T] waveform; the (size-1) batch dimension
    is squeezed away so callers always receive a 2-D [C, T] tensor.
    """
    # Only a 3-D tensor carries a batch dim worth dropping; 2-D passes through.
    formatted = audio_output.squeeze(0) if audio_output.ndim == 3 else audio_output
    return formatted.float()
def normalize_audio(self, audio: torch.Tensor, target_db: float = -1.0) -> torch.Tensor:
    """Apply peak normalization to audio data, matching target library behavior.

    Target library reference: `acestep/audio_utils.py:normalize_audio()`
        peak = max(abs(audio))
        gain = 10^(target_db/20) / peak
        audio = audio * gain

    Args:
        audio: Audio tensor [C, T]
        target_db: Target peak level in dB (default: -1.0)

    Returns:
        Normalized audio tensor; near-silent input is returned unchanged
        to avoid dividing by a ~zero peak.
    """
    peak = torch.max(torch.abs(audio))
    # Guard: boosting near-silence would only amplify noise (and divide by ~0).
    if peak < 1e-6:
        return audio
    # Convert the dB target to linear amplitude, then scale the peak onto it.
    # (This line was lost to a diff-hunk artifact; reconstructed from the
    # docstring formula above.)
    target_amp = 10.0 ** (target_db / 20.0)
    gain = target_amp / peak
    return audio * gain
class AceStepUnit_TaskTypeChecker(PipelineUnit):
    """Decide the pipeline task type from the provided conditioning inputs.

    Emits ``"cover"`` when an LLM-generated audio-code string is supplied,
    otherwise ``"text2music"``.
    """

    def __init__(self):
        super().__init__(
            input_params=("src_audio", "audio_code_string"),
            output_params=("task_type",),
        )

    def process(self, pipe, src_audio, audio_code_string):
        """Return ``{"task_type": ...}`` based on which inputs are present.

        NOTE(review): ``src_audio`` is declared as an input but does not
        influence the decision here — confirm whether a src-audio task
        type should be selected when it is provided.
        """
        if audio_code_string is not None:
            print("audio_code_string detected, setting task_type to 'cover'")
            task_type = "cover"
        else:
            task_type = "text2music"
        return {"task_type": task_type}
class AceStepUnit_PromptEmbedder(PipelineUnit):
SFT_GEN_PROMPT = "# Instruction\n{}\n\n# Caption\n{}\n\n# Metas\n{}<|endoftext|>\n"
INSTRUCTION = "Fill the audio semantic mask based on the given conditions:"
INSTRUCTION_MAP = {
"text2music": "Fill the audio semantic mask based on the given conditions:",
"cover": "Generate audio semantic tokens based on the given conditions:",
"repaint": "Repaint the mask area based on the given conditions:",
"extract": "Extract the {TRACK_NAME} track from the audio:",
"extract_default": "Extract the track from the audio:",
"lego": "Generate the {TRACK_NAME} track based on the audio context:",
"lego_default": "Generate the track based on the audio context:",
"complete": "Complete the input track with {TRACK_CLASSES}:",
"complete_default": "Complete the input track:",
}
LYRIC_PROMPT = "# Languages\n{}\n\n# Lyric\n{}<|endoftext|>"
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "prompt"},
input_params=("lyrics", "duration", "bpm", "keyscale", "timesignature", "vocal_language"),
input_params_posi={"prompt": "prompt", "positive": "positive"},
input_params_nega={"prompt": "prompt", "positive": "positive"},
input_params=("lyrics", "duration", "bpm", "keyscale", "timesignature", "vocal_language", "task_type"),
output_params=("text_hidden_states", "text_attention_mask", "lyric_hidden_states", "lyric_attention_mask"),
onload_model_names=("text_encoder",)
)
@@ -256,10 +259,13 @@ class AceStepUnit_PromptEmbedder(PipelineUnit):
f"- duration: {duration}\n"
)
def process(self, pipe, prompt, lyrics, duration, bpm, keyscale, timesignature, vocal_language):
def process(self, pipe, prompt, positive, lyrics, duration, bpm, keyscale, timesignature, vocal_language, task_type):
if not positive:
return {}
pipe.load_models_to_device(['text_encoder'])
meta_dict = {"bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "duration": duration}
prompt = self.SFT_GEN_PROMPT.format(self.INSTRUCTION, prompt, self._dict_to_meta_string(meta_dict))
INSTRUCTION = self.INSTRUCTION_MAP.get(task_type, self.INSTRUCTION_MAP["text2music"])
prompt = self.SFT_GEN_PROMPT.format(INSTRUCTION, prompt, self._dict_to_meta_string(meta_dict))
text_hidden_states, text_attention_mask = self._encode_text(pipe, prompt, max_length=256)
lyric_text = self.LYRIC_PROMPT.format(vocal_language, lyrics)
@@ -350,31 +356,32 @@ class AceStepUnit_ConditionEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"text_hidden_states": "text_hidden_states", "text_attention_mask": "text_attention_mask", "lyric_hidden_states": "lyric_hidden_states", "lyric_attention_mask": "lyric_attention_mask"},
input_params_nega={"text_hidden_states": "text_hidden_states", "text_attention_mask": "text_attention_mask", "lyric_hidden_states": "lyric_hidden_states", "lyric_attention_mask": "lyric_attention_mask"},
input_params=("reference_latents", "refer_audio_order_mask"),
take_over=True,
output_params=("encoder_hidden_states", "encoder_attention_mask"),
onload_model_names=("conditioner",),
)
def process(self, pipe, text_hidden_states, text_attention_mask, lyric_hidden_states, lyric_attention_mask, reference_latents, refer_audio_order_mask):
def process(self, pipe, inputs_shared, inputs_posi, inputs_nega):
pipe.load_models_to_device(['conditioner'])
encoder_hidden_states, encoder_attention_mask = pipe.conditioner(
text_hidden_states=text_hidden_states,
text_attention_mask=text_attention_mask,
lyric_hidden_states=lyric_hidden_states,
lyric_attention_mask=lyric_attention_mask,
reference_latents=reference_latents,
refer_audio_order_mask=refer_audio_order_mask,
text_hidden_states=inputs_posi.get("text_hidden_states", None),
text_attention_mask=inputs_posi.get("text_attention_mask", None),
lyric_hidden_states=inputs_posi.get("lyric_hidden_states", None),
lyric_attention_mask=inputs_posi.get("lyric_attention_mask", None),
reference_latents=inputs_shared.get("reference_latents", None),
refer_audio_order_mask=inputs_shared.get("refer_audio_order_mask", None),
)
return {"encoder_hidden_states": encoder_hidden_states, "encoder_attention_mask": encoder_attention_mask}
inputs_posi["encoder_hidden_states"] = encoder_hidden_states
inputs_posi["encoder_attention_mask"] = encoder_attention_mask
inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states)
inputs_nega["encoder_attention_mask"] = encoder_attention_mask
return inputs_shared, inputs_posi, inputs_nega
class AceStepUnit_ContextLatentBuilder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("duration", "src_audio"),
input_params=("duration", "src_audio", "lm_hints"),
output_params=("context_latents", "src_latents", "chunk_masks", "attention_mask"),
)
@@ -386,9 +393,15 @@ class AceStepUnit_ContextLatentBuilder(PipelineUnit):
tiled = pipe.silence_latent[0].repeat(repeats, 1)
return tiled[:length, :]
def process(self, pipe, duration, src_audio):
if src_audio is not None:
raise NotImplementedError("Src audio conditioning is not implemented yet. Please set src_audio to None.")
def process(self, pipe, duration, src_audio, lm_hints):
if lm_hints is not None:
max_latent_length = lm_hints.shape[1]
src_latents = lm_hints.clone()
chunk_masks = torch.ones((1, max_latent_length, src_latents.shape[-1]), dtype=torch.bool, device=pipe.device)
attention_mask = torch.ones((1, max_latent_length), device=src_latents.device, dtype=pipe.torch_dtype)
context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
elif src_audio is not None:
raise NotImplementedError("src_audio conditioning is not implemented yet. Please set lm_hints to None.")
else:
max_latent_length = duration * pipe.sample_rate // 1920
src_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
@@ -410,6 +423,7 @@ class AceStepUnit_NoiseInitializer(PipelineUnit):
noise = pipe.generate_noise(src_latents_shape, seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
return {"noise": noise}
class AceStepUnit_InputAudioEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
@@ -427,8 +441,8 @@ class AceStepUnit_InputAudioEmbedder(PipelineUnit):
class AceStepUnit_AudioCodeDecoder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("audio_codes", "seq_len", "silence_latent"),
output_params=("lm_hints_25Hz",),
input_params=("audio_code_string",),
output_params=("lm_hints",),
onload_model_names=("tokenizer_model",),
)
@@ -437,45 +451,29 @@ class AceStepUnit_AudioCodeDecoder(PipelineUnit):
"""Extract integer audio codes from tokens like <|audio_code_123|>."""
if not code_str:
return []
codes = []
max_audio_code = 63999
for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str):
code_value = int(x)
codes.append(max(0, min(code_value, max_audio_code)))
try:
codes = []
max_audio_code = 63999
for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str):
code_value = int(x)
codes.append(max(0, min(code_value, max_audio_code)))
except Exception as e:
raise ValueError(f"Invalid audio_code_string format: {e}")
return codes
def process(self, pipe, audio_codes, seq_len, silence_latent):
if audio_codes is None or not audio_codes.strip():
return {"lm_hints_25Hz": None}
code_ids = self._parse_audio_code_string(audio_codes)
def process(self, pipe, audio_code_string):
if audio_code_string is None or not audio_code_string.strip():
return {"lm_hints": None}
code_ids = self._parse_audio_code_string(audio_code_string)
if len(code_ids) == 0:
return {"lm_hints_25Hz": None}
return {"lm_hints": None}
pipe.load_models_to_device(["tokenizer_model"])
quantizer = pipe.tokenizer_model.tokenizer.quantizer
detokenizer = pipe.tokenizer_model.detokenizer
indices = torch.tensor(code_ids, device=pipe.device, dtype=torch.long)
indices = indices.unsqueeze(0).unsqueeze(-1) # [1, N, 1]
quantized = quantizer.get_output_from_indices(indices) # [1, N, 2048]
if quantized.dtype != pipe.torch_dtype:
quantized = quantized.to(pipe.torch_dtype)
lm_hints = detokenizer(quantized) # [1, N*5, 64]
# Pad or truncate to seq_len
current_len = lm_hints.shape[1]
if current_len < seq_len:
pad_len = seq_len - current_len
pad = silence_latent[:, :pad_len, :]
lm_hints = torch.cat([lm_hints, pad], dim=1)
elif current_len > seq_len:
lm_hints = lm_hints[:, :seq_len, :]
return {"lm_hints_25Hz": lm_hints}
quantized = pipe.tokenizer_model.tokenizer.quantizer.get_output_from_indices(indices).to(pipe.torch_dtype) # [1, N, 2048]
lm_hints = pipe.tokenizer_model.detokenizer(quantized) # [1, N*5, 64]
return {"lm_hints": lm_hints}
def model_fn_ace_step(
@@ -499,5 +497,7 @@ def model_fn_ace_step(
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
context_latents=context_latents,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)[0]
return decoder_outputs