mi804
2026-04-22 21:36:30 +08:00
parent f2e3427566
commit 1186379139
2 changed files with 150 additions and 89 deletions


@@ -9,6 +9,8 @@ from typing import Optional, Dict, Any, List, Tuple
from tqdm import tqdm
import random
import math
import torch.nn.functional as F
from einops import rearrange
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
@@ -41,11 +43,11 @@ class AceStepPipeline(BasePipeline):
self.in_iteration_models = ("dit",)
self.units = [
AceStepUnit_TaskTypeChecker(),
AceStepUnit_PromptEmbedder(),
AceStepUnit_ReferenceAudioEmbedder(),
AceStepUnit_ConditionEmbedder(),
AceStepUnit_AudioCodeDecoder(),
AceStepUnit_ContextLatentBuilder(),
AceStepUnit_ConditionEmbedder(),
AceStepUnit_NoiseInitializer(),
AceStepUnit_InputAudioEmbedder(),
]
@@ -100,7 +102,8 @@ class AceStepPipeline(BasePipeline):
reference_audios: List[torch.Tensor] = None,
# Source audio
src_audio: torch.Tensor = None,
denoising_strength: float = 1.0,
denoising_strength: float = 1.0, # denoising_strength = 1 - cover_noise_strength
audio_cover_strength: float = 1.0,
# Audio codes
audio_code_string: Optional[str] = None,
# Shape
@@ -121,7 +124,7 @@ class AceStepPipeline(BasePipeline):
progress_bar_cmd=tqdm,
):
# 1. Scheduler
self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, denoising_strength=1.0, shift=shift)
self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, denoising_strength=denoising_strength, shift=shift)
# 2. Build the three input dictionaries (shared, positive, negative)
inputs_posi = {"prompt": prompt, "positive": True}
@@ -132,6 +135,7 @@ class AceStepPipeline(BasePipeline):
"task_type": task_type,
"reference_audios": reference_audios,
"src_audio": src_audio,
"audio_cover_strength": audio_cover_strength,
"audio_code_string": audio_code_string,
"duration": duration,
"bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "vocal_language": vocal_language,
@@ -152,6 +156,7 @@ class AceStepPipeline(BasePipeline):
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
self.switch_noncover_condition(inputs_shared, inputs_posi, inputs_nega, progress_id)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
@@ -181,23 +186,28 @@ class AceStepPipeline(BasePipeline):
gain = target_amp / peak
return audio * gain
def switch_noncover_condition(self, inputs_shared, inputs_posi, inputs_nega, progress_id):
if inputs_shared["task_type"] != "cover" or inputs_shared["audio_cover_strength"] >= 1.0 or inputs_shared.get("shared_noncover", None) is None:
return
cover_steps = int(len(self.scheduler.timesteps) * inputs_shared["audio_cover_strength"])
if progress_id >= cover_steps:
inputs_shared.update(inputs_shared.pop("shared_noncover", {}))
inputs_posi.update(inputs_shared.pop("posi_noncover", {}))
if inputs_shared["cfg_scale"] != 1.0:
inputs_nega.update(inputs_shared.pop("nega_noncover", {}))
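A quick check of the switch arithmetic above (illustrative numbers, not values hard-coded in the pipeline): with 30 inference steps and audio_cover_strength=0.6, the cover conditioning drives the first 18 steps, after which the precomputed text2music ("noncover") conditioning takes over.
# Minimal sketch of the switch point, assuming the loop indexing used above.
num_inference_steps = 30
audio_cover_strength = 0.6
cover_steps = int(num_inference_steps * audio_cover_strength)  # 18
for progress_id in range(num_inference_steps):
    uses_cover_condition = progress_id < cover_steps
    # progress_id 0-17: cover (audio-code) conditioning
    # progress_id 18-29: text2music conditioning from shared_noncover / posi_noncover / nega_noncover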
class AceStepUnit_TaskTypeChecker(PipelineUnit):
"""Check and compute sequence length from duration."""
def __init__(self):
super().__init__(
input_params=("audio_code_string"),
input_params=("task_type",),
output_params=("task_type",),
)
def process(self, pipe, audio_code_string):
if pipe.scheduler.training:
return {"task_type": "text2music"}
if audio_code_string is not None:
task_type = "cover"
else:
task_type = "text2music"
return {"task_type": task_type}
def process(self, pipe, task_type):
assert task_type in ["text2music", "cover", "repaint"], f"Unsupported task_type: {task_type}"
return {}
class AceStepUnit_PromptEmbedder(PipelineUnit):
@@ -364,14 +374,34 @@ class AceStepUnit_ConditionEmbedder(PipelineUnit):
if inputs_shared["cfg_scale"] != 1.0:
inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states).to(dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device)
inputs_nega["encoder_attention_mask"] = encoder_attention_mask
if inputs_shared["task_type"] == "cover" and inputs_shared["audio_cover_strength"] < 1.0:
hidden_states_noncover = AceStepUnit_PromptEmbedder().process(
pipe, inputs_posi["prompt"], True, inputs_shared["lyrics"], inputs_shared["duration"],
inputs_shared["bpm"], inputs_shared["keyscale"], inputs_shared["timesignature"],
inputs_shared["vocal_language"], "text2music")
encoder_hidden_states_noncover, encoder_attention_mask_noncover = pipe.conditioner(
**hidden_states_noncover,
reference_latents=inputs_shared.get("reference_latents", None),
refer_audio_order_mask=inputs_shared.get("refer_audio_order_mask", None),
)
duration = inputs_shared["context_latents"].shape[1] * 1920 / pipe.vae.sampling_rate
context_latents_noncover = AceStepUnit_ContextLatentBuilder().process(pipe, duration, None, None)["context_latents"]
inputs_shared["shared_noncover"] = {"context_latents": context_latents_noncover}
inputs_shared["posi_noncover"] = {"encoder_hidden_states": encoder_hidden_states_noncover, "encoder_attention_mask": encoder_attention_mask_noncover}
if inputs_shared["cfg_scale"] != 1.0:
inputs_shared["nega_noncover"] = {
"encoder_hidden_states": pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states_noncover).to(dtype=encoder_hidden_states_noncover.dtype, device=encoder_hidden_states_noncover.device),
"encoder_attention_mask": encoder_attention_mask_noncover,
}
return inputs_shared, inputs_posi, inputs_nega
class AceStepUnit_ContextLatentBuilder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("duration", "src_audio", "lm_hints"),
input_params=("duration", "src_audio", "audio_code_string"),
output_params=("context_latents", "src_latents", "chunk_masks", "attention_mask"),
onload_model_names=("vae", "tokenizer_model",),
)
def _get_silence_latent_slice(self, pipe, length: int) -> torch.Tensor:
@@ -382,66 +412,13 @@ class AceStepUnit_ContextLatentBuilder(PipelineUnit):
tiled = pipe.silence_latent[0].repeat(repeats, 1)
return tiled[:length, :]
def process(self, pipe, duration, src_audio, lm_hints):
if lm_hints is not None:
max_latent_length = lm_hints.shape[1]
src_latents = lm_hints.clone()
chunk_masks = torch.ones((1, max_latent_length, src_latents.shape[-1]), dtype=torch.bool, device=pipe.device)
attention_mask = torch.ones((1, max_latent_length), device=src_latents.device, dtype=pipe.torch_dtype)
context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
# elif src_audio is not None:
# raise NotImplementedError("src_audio conditioning is not implemented yet. Please set lm_hints to None.")
else:
max_latent_length = duration * pipe.sample_rate // 1920
src_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
chunk_masks = torch.ones((1, max_latent_length, src_latents.shape[-1]), dtype=torch.bool, device=pipe.device)
attention_mask = torch.ones((1, max_latent_length), device=src_latents.device, dtype=pipe.torch_dtype)
context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
return {"context_latents": context_latents, "attention_mask": attention_mask}
class AceStepUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("context_latents", "seed", "rand_device"),
output_params=("noise",),
)
def process(self, pipe, context_latents, seed, rand_device):
src_latents_shape = (context_latents.shape[0], context_latents.shape[1], context_latents.shape[-1] // 2)
noise = pipe.generate_noise(src_latents_shape, seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
return {"noise": noise}
class AceStepUnit_InputAudioEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("noise", "input_audio"),
output_params=("latents", "input_latents"),
)
def process(self, pipe, noise, input_audio):
if input_audio is None:
return {"latents": noise}
if pipe.scheduler.training:
pipe.load_models_to_device(['vae'])
input_audio, sample_rate = input_audio
input_audio = torch.clamp(input_audio, -1.0, 1.0)
if input_audio.dim() == 2:
input_audio = input_audio.unsqueeze(0)
input_latents = pipe.vae.encode(input_audio.to(dtype=pipe.torch_dtype, device=pipe.device)).transpose(1, 2)
# prevent potential size mismatch between context_latents and input_latents by cropping input_latents to the same temporal length as noise
input_latents = input_latents[:, :noise.shape[1]]
return {"input_latents": input_latents}
class AceStepUnit_AudioCodeDecoder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("audio_code_string",),
output_params=("lm_hints",),
onload_model_names=("tokenizer_model",),
)
def tokenize(self, tokenizer, x, silence_latent, pool_window_size):
if x.shape[1] % pool_window_size != 0:
pad_len = pool_window_size - (x.shape[1] % pool_window_size)
x = torch.cat([x, silence_latent[:1,:pad_len].repeat(x.shape[0],1,1)], dim=1)
x = rearrange(x, 'n (t_patch p) d -> n t_patch p d', p=pool_window_size)
quantized, indices = tokenizer(x)
return quantized
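The padding in tokenize keeps the latent length divisible by the tokenizer's pool window before the einops rearrange into patches. A small worked example (pool_window_size=5 and the frame count are illustrative values, not read from the checkpoint):
# Hypothetical shapes: 998 latent frames, pool window of 5.
pool_window_size = 5
latent_length = 998
pad_len = pool_window_size - (latent_length % pool_window_size)  # 2 frames of silence latent appended
patches = (latent_length + pad_len) // pool_window_size          # 1000 / 5 = 200 patches fed to the quantizer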
@staticmethod
def _parse_audio_code_string(code_str: str) -> list:
@@ -458,24 +435,72 @@ class AceStepUnit_AudioCodeDecoder(PipelineUnit):
raise ValueError(f"Invalid audio_code_string format: {e}")
return codes
def process(self, pipe, audio_code_string):
if audio_code_string is None or not audio_code_string.strip():
return {"lm_hints": None}
code_ids = self._parse_audio_code_string(audio_code_string)
if len(code_ids) == 0:
return {"lm_hints": None}
def process(self, pipe, duration, src_audio, audio_code_string):
# Build src_latents, in priority order: audio_code_string > src_audio > silence
if audio_code_string is not None:
pipe.load_models_to_device(self.onload_model_names)
code_ids = self._parse_audio_code_string(audio_code_string)
quantizer = pipe.tokenizer_model.tokenizer.quantizer
indices = torch.tensor(code_ids, device=quantizer.codebooks.device, dtype=torch.long).unsqueeze(0).unsqueeze(-1)
codes = quantizer.get_codes_from_indices(indices)
quantized = codes.sum(dim=0).to(pipe.torch_dtype).to(pipe.device)
quantized = quantizer.project_out(quantized)
src_latents = pipe.tokenizer_model.detokenizer(quantized).to(pipe.device)
max_latent_length = src_latents.shape[1]
elif src_audio is not None:
pipe.load_models_to_device(self.onload_model_names)
src_audio = src_audio.unsqueeze(0) if src_audio.dim() == 2 else src_audio
src_audio = torch.clamp(src_audio, -1.0, 1.0)
src_latents = pipe.vae.encode(src_audio.to(dtype=pipe.torch_dtype, device=pipe.device)).transpose(1, 2)
lm_hints_5Hz = self.tokenize(pipe.tokenizer_model.tokenizer, src_latents, pipe.silence_latent, pipe.tokenizer_model.tokenizer.pool_window_size)
src_latents = pipe.tokenizer_model.detokenizer(lm_hints_5Hz)
max_latent_length = src_latents.shape[1]
else:
max_latent_length = int(duration * pipe.sample_rate // 1920)
src_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
chunk_masks = torch.ones((1, max_latent_length, src_latents.shape[-1]), dtype=torch.bool, device=pipe.device)
attention_mask = torch.ones((1, max_latent_length), device=src_latents.device, dtype=pipe.torch_dtype)
context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
return {"context_latents": context_latents, "attention_mask": attention_mask}
pipe.load_models_to_device(["tokenizer_model"])
quantizer = pipe.tokenizer_model.tokenizer.quantizer
detokenizer = pipe.tokenizer_model.detokenizer
indices = torch.tensor(code_ids, device=quantizer.codebooks.device, dtype=torch.long).unsqueeze(0).unsqueeze(-1)
codes = quantizer.get_codes_from_indices(indices)
quantized = codes.sum(dim=0).to(pipe.torch_dtype).to(pipe.device)
quantized = quantizer.project_out(quantized)
class AceStepUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("context_latents", "seed", "rand_device"),
output_params=("noise",),
)
def process(self, pipe, context_latents, seed, rand_device):
src_latents_shape = (context_latents.shape[0], context_latents.shape[1], context_latents.shape[-1] // 2)
noise = pipe.generate_noise(src_latents_shape, seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
noise = pipe.scheduler.add_noise(context_latents[:, :, :src_latents_shape[-1]], noise, timestep=pipe.scheduler.timesteps[0])
return {"noise": noise}
class AceStepUnit_InputAudioEmbedder(PipelineUnit):
"""Only for training."""
def __init__(self):
super().__init__(
input_params=("noise", "input_audio"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",),
)
def process(self, pipe, noise, input_audio):
if input_audio is None:
return {"latents": noise}
if pipe.scheduler.training:
pipe.load_models_to_device(self.onload_model_names)
input_audio, sample_rate = input_audio
input_audio = torch.clamp(input_audio, -1.0, 1.0)
if input_audio.dim() == 2:
input_audio = input_audio.unsqueeze(0)
input_latents = pipe.vae.encode(input_audio.to(dtype=pipe.torch_dtype, device=pipe.device)).transpose(1, 2)
# prevent potential size mismatch between context_latents and input_latents by cropping input_latents to the same temporal length as noise
input_latents = input_latents[:, :noise.shape[1]]
return {"input_latents": input_latents}
lm_hints = detokenizer(quantized).to(pipe.device)
return {"lm_hints": lm_hints}
def model_fn_ace_step(


@@ -0,0 +1,36 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio, read_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
audio = pipe(
prompt=prompt,
lyrics=lyrics,
task_type="cover",
src_audio=src_audio,
audio_cover_strength=0.6,
denoising_strength=0.9,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base-cover.wav")
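For comparison, a plain text2music call with the same pipeline (a hedged sketch: it reuses pipe, prompt, and lyrics from above, drops the cover-specific arguments, and assumes the remaining parameters can take their defaults):
# Hypothetical text2music variant of the same script: no source audio, no cover settings.
audio_t2m = pipe(
    prompt=prompt,
    lyrics=lyrics,
    task_type="text2music",
    duration=160,
    bpm=100,
    keyscale="B minor",
    timesignature="4",
    vocal_language="zh",
    seed=42,
    num_inference_steps=30,
    cfg_scale=4.0,
)
save_audio(audio_t2m, pipe.vae.sampling_rate, "acestep-v15-base-text2music.wav")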