This commit is contained in:
mi804
2026-04-22 21:36:30 +08:00
parent f2e3427566
commit 1186379139
2 changed files with 150 additions and 89 deletions

View File

@@ -9,6 +9,8 @@ from typing import Optional, Dict, Any, List, Tuple
from tqdm import tqdm
import random
import math
import torch.nn.functional as F
from einops import rearrange
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
@@ -41,11 +43,11 @@ class AceStepPipeline(BasePipeline):
self.in_iteration_models = ("dit",)
self.units = [
AceStepUnit_TaskTypeChecker(),
AceStepUnit_PromptEmbedder(),
AceStepUnit_ReferenceAudioEmbedder(),
AceStepUnit_ConditionEmbedder(),
AceStepUnit_AudioCodeDecoder(),
AceStepUnit_ContextLatentBuilder(),
AceStepUnit_ConditionEmbedder(),
AceStepUnit_NoiseInitializer(),
AceStepUnit_InputAudioEmbedder(),
]
@@ -100,7 +102,8 @@ class AceStepPipeline(BasePipeline):
reference_audios: List[torch.Tensor] = None,
# Source audio
src_audio: torch.Tensor = None,
denoising_strength: float = 1.0,
denoising_strength: float = 1.0, # denoising_strength = 1 - cover_noise_strength
audio_cover_strength: float = 1.0,
# Audio codes
audio_code_string: Optional[str] = None,
# Shape
@@ -121,7 +124,7 @@ class AceStepPipeline(BasePipeline):
progress_bar_cmd=tqdm,
):
# 1. Scheduler
self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, denoising_strength=1.0, shift=shift)
self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, denoising_strength=denoising_strength, shift=shift)
# 2. 三字典输入
inputs_posi = {"prompt": prompt, "positive": True}
@@ -132,6 +135,7 @@ class AceStepPipeline(BasePipeline):
"task_type": task_type,
"reference_audios": reference_audios,
"src_audio": src_audio,
"audio_cover_strength": audio_cover_strength,
"audio_code_string": audio_code_string,
"duration": duration,
"bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "vocal_language": vocal_language,
@@ -152,6 +156,7 @@ class AceStepPipeline(BasePipeline):
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
self.switch_noncover_condition(inputs_shared, inputs_posi, inputs_nega, progress_id)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
@@ -181,23 +186,28 @@ class AceStepPipeline(BasePipeline):
gain = target_amp / peak
return audio * gain
def switch_noncover_condition(self, inputs_shared, inputs_posi, inputs_nega, progress_id):
if inputs_shared["task_type"] != "cover" or inputs_shared["audio_cover_strength"] >= 1.0 or inputs_shared.get("shared_noncover", None) is None:
return
cover_steps = int(len(self.scheduler.timesteps) * inputs_shared["audio_cover_strength"])
if progress_id >= cover_steps:
inputs_shared.update(inputs_shared.pop("shared_noncover", {}))
inputs_posi.update(inputs_shared.pop("posi_noncover", {}))
if inputs_shared["cfg_scale"] != 1.0:
inputs_nega.update(inputs_shared.pop("nega_noncover", {}))
class AceStepUnit_TaskTypeChecker(PipelineUnit):
"""Check and compute sequence length from duration."""
def __init__(self):
super().__init__(
input_params=("audio_code_string"),
input_params=("task_type",),
output_params=("task_type",),
)
def process(self, pipe, audio_code_string):
if pipe.scheduler.training:
return {"task_type": "text2music"}
if audio_code_string is not None:
task_type = "cover"
else:
task_type = "text2music"
return {"task_type": task_type}
def process(self, pipe, task_type):
assert task_type in ["text2music", "cover", "repaint"], f"Unsupported task_type: {task_type}"
return {}
class AceStepUnit_PromptEmbedder(PipelineUnit):
@@ -364,14 +374,34 @@ class AceStepUnit_ConditionEmbedder(PipelineUnit):
if inputs_shared["cfg_scale"] != 1.0:
inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states).to(dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device)
inputs_nega["encoder_attention_mask"] = encoder_attention_mask
if inputs_shared["task_type"] == "cover" and inputs_shared["audio_cover_strength"] < 1.0:
hidden_states_noncover = AceStepUnit_PromptEmbedder().process(
pipe, inputs_posi["prompt"], True, inputs_shared["lyrics"], inputs_shared["duration"],
inputs_shared["bpm"], inputs_shared["keyscale"], inputs_shared["timesignature"],
inputs_shared["vocal_language"], "text2music")
encoder_hidden_states_noncover, encoder_attention_mask_noncover = pipe.conditioner(
**hidden_states_noncover,
reference_latents=inputs_shared.get("reference_latents", None),
refer_audio_order_mask=inputs_shared.get("refer_audio_order_mask", None),
)
duration = inputs_shared["context_latents"].shape[1] * 1920 / pipe.vae.sampling_rate
context_latents_noncover = AceStepUnit_ContextLatentBuilder().process(pipe, duration, None, None)["context_latents"]
inputs_shared["shared_noncover"] = {"context_latents": context_latents_noncover}
inputs_shared["posi_noncover"] = {"encoder_hidden_states": encoder_hidden_states_noncover, "encoder_attention_mask": encoder_attention_mask_noncover}
if inputs_shared["cfg_scale"] != 1.0:
inputs_shared["nega_noncover"] = {
"encoder_hidden_states": pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states_noncover).to(dtype=encoder_hidden_states_noncover.dtype, device=encoder_hidden_states_noncover.device),
"encoder_attention_mask": encoder_attention_mask_noncover,
}
return inputs_shared, inputs_posi, inputs_nega
class AceStepUnit_ContextLatentBuilder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("duration", "src_audio", "lm_hints"),
input_params=("duration", "src_audio", "audio_code_string"),
output_params=("context_latents", "src_latents", "chunk_masks", "attention_mask"),
onload_model_names=("vae", "tokenizer_model",),
)
def _get_silence_latent_slice(self, pipe, length: int) -> torch.Tensor:
@@ -382,66 +412,13 @@ class AceStepUnit_ContextLatentBuilder(PipelineUnit):
tiled = pipe.silence_latent[0].repeat(repeats, 1)
return tiled[:length, :]
def process(self, pipe, duration, src_audio, lm_hints):
if lm_hints is not None:
max_latent_length = lm_hints.shape[1]
src_latents = lm_hints.clone()
chunk_masks = torch.ones((1, max_latent_length, src_latents.shape[-1]), dtype=torch.bool, device=pipe.device)
attention_mask = torch.ones((1, max_latent_length), device=src_latents.device, dtype=pipe.torch_dtype)
context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
# elif src_audio is not None:
# raise NotImplementedError("src_audio conditioning is not implemented yet. Please set lm_hints to None.")
else:
max_latent_length = duration * pipe.sample_rate // 1920
src_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
chunk_masks = torch.ones((1, max_latent_length, src_latents.shape[-1]), dtype=torch.bool, device=pipe.device)
attention_mask = torch.ones((1, max_latent_length), device=src_latents.device, dtype=pipe.torch_dtype)
context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
return {"context_latents": context_latents, "attention_mask": attention_mask}
class AceStepUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("context_latents", "seed", "rand_device"),
output_params=("noise",),
)
def process(self, pipe, context_latents, seed, rand_device):
src_latents_shape = (context_latents.shape[0], context_latents.shape[1], context_latents.shape[-1] // 2)
noise = pipe.generate_noise(src_latents_shape, seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
return {"noise": noise}
class AceStepUnit_InputAudioEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("noise", "input_audio"),
output_params=("latents", "input_latents"),
)
def process(self, pipe, noise, input_audio):
if input_audio is None:
return {"latents": noise}
if pipe.scheduler.training:
pipe.load_models_to_device(['vae'])
input_audio, sample_rate = input_audio
input_audio = torch.clamp(input_audio, -1.0, 1.0)
if input_audio.dim() == 2:
input_audio = input_audio.unsqueeze(0)
input_latents = pipe.vae.encode(input_audio.to(dtype=pipe.torch_dtype, device=pipe.device)).transpose(1, 2)
# prevent potential size mismatch between context_latents and input_latents by cropping input_latents to the same temporal length as noise
input_latents = input_latents[:, :noise.shape[1]]
return {"input_latents": input_latents}
class AceStepUnit_AudioCodeDecoder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("audio_code_string",),
output_params=("lm_hints",),
onload_model_names=("tokenizer_model",),
)
def tokenize(self, tokenizer, x, silence_latent, pool_window_size):
if x.shape[1] % pool_window_size != 0:
pad_len = pool_window_size - (x.shape[1] % pool_window_size)
x = torch.cat([x, silence_latent[:1,:pad_len].repeat(x.shape[0],1,1)], dim=1)
x = rearrange(x, 'n (t_patch p) d -> n t_patch p d', p=pool_window_size)
quantized, indices = tokenizer(x)
return quantized
@staticmethod
def _parse_audio_code_string(code_str: str) -> list:
@@ -458,24 +435,72 @@ class AceStepUnit_AudioCodeDecoder(PipelineUnit):
raise ValueError(f"Invalid audio_code_string format: {e}")
return codes
def process(self, pipe, audio_code_string):
if audio_code_string is None or not audio_code_string.strip():
return {"lm_hints": None}
code_ids = self._parse_audio_code_string(audio_code_string)
if len(code_ids) == 0:
return {"lm_hints": None}
def process(self, pipe, duration, src_audio, audio_code_string):
# get src_latents from audio_code_string > src_audio > silence
if audio_code_string is not None:
pipe.load_models_to_device(self.onload_model_names)
code_ids = self._parse_audio_code_string(audio_code_string)
quantizer = pipe.tokenizer_model.tokenizer.quantizer
indices = torch.tensor(code_ids, device=quantizer.codebooks.device, dtype=torch.long).unsqueeze(0).unsqueeze(-1)
codes = quantizer.get_codes_from_indices(indices)
quantized = codes.sum(dim=0).to(pipe.torch_dtype).to(pipe.device)
quantized = quantizer.project_out(quantized)
src_latents = pipe.tokenizer_model.detokenizer(quantized).to(pipe.device)
max_latent_length = src_latents.shape[1]
elif src_audio is not None:
pipe.load_models_to_device(self.onload_model_names)
src_audio = src_audio.unsqueeze(0) if src_audio.dim() == 2 else src_audio
src_audio = torch.clamp(src_audio, -1.0, 1.0)
src_latents = pipe.vae.encode(src_audio.to(dtype=pipe.torch_dtype, device=pipe.device)).transpose(1, 2)
lm_hints_5Hz = self.tokenize(pipe.tokenizer_model.tokenizer, src_latents, pipe.silence_latent, pipe.tokenizer_model.tokenizer.pool_window_size)
src_latents = pipe.tokenizer_model.detokenizer(lm_hints_5Hz)
max_latent_length = src_latents.shape[1]
else:
max_latent_length = int(duration * pipe.sample_rate // 1920)
src_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
chunk_masks = torch.ones((1, max_latent_length, src_latents.shape[-1]), dtype=torch.bool, device=pipe.device)
attention_mask = torch.ones((1, max_latent_length), device=src_latents.device, dtype=pipe.torch_dtype)
context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
return {"context_latents": context_latents, "attention_mask": attention_mask}
pipe.load_models_to_device(["tokenizer_model"])
quantizer = pipe.tokenizer_model.tokenizer.quantizer
detokenizer = pipe.tokenizer_model.detokenizer
indices = torch.tensor(code_ids, device=quantizer.codebooks.device, dtype=torch.long).unsqueeze(0).unsqueeze(-1)
codes = quantizer.get_codes_from_indices(indices)
quantized = codes.sum(dim=0).to(pipe.torch_dtype).to(pipe.device)
quantized = quantizer.project_out(quantized)
class AceStepUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("context_latents", "seed", "rand_device"),
output_params=("noise",),
)
def process(self, pipe, context_latents, seed, rand_device):
src_latents_shape = (context_latents.shape[0], context_latents.shape[1], context_latents.shape[-1] // 2)
noise = pipe.generate_noise(src_latents_shape, seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
noise = pipe.scheduler.add_noise(context_latents[:, :, :src_latents_shape[-1]], noise, timestep=pipe.scheduler.timesteps[0])
return {"noise": noise}
class AceStepUnit_InputAudioEmbedder(PipelineUnit):
"""Only for training."""
def __init__(self):
super().__init__(
input_params=("noise", "input_audio"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",),
)
def process(self, pipe, noise, input_audio):
if input_audio is None:
return {"latents": noise}
if pipe.scheduler.training:
pipe.load_models_to_device(self.onload_model_names)
input_audio, sample_rate = input_audio
input_audio = torch.clamp(input_audio, -1.0, 1.0)
if input_audio.dim() == 2:
input_audio = input_audio.unsqueeze(0)
input_latents = pipe.vae.encode(input_audio.to(dtype=pipe.torch_dtype, device=pipe.device)).transpose(1, 2)
# prevent potential size mismatch between context_latents and input_latents by cropping input_latents to the same temporal length as noise
input_latents = input_latents[:, :noise.shape[1]]
return {"input_latents": input_latents}
lm_hints = detokenizer(quantized).to(pipe.device)
return {"lm_hints": lm_hints}
def model_fn_ace_step(