|
|
|
|
@@ -3,8 +3,9 @@ ACE-Step Pipeline for DiffSynth-Studio.
|
|
|
|
|
|
|
|
|
|
Text-to-Music generation pipeline using ACE-Step 1.5 model.
|
|
|
|
|
"""
|
|
|
|
|
import re
|
|
|
|
|
import torch
|
|
|
|
|
from typing import Optional
|
|
|
|
|
from typing import Optional, Dict, Any, List, Tuple
|
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
|
|
|
|
from ..core.device.npu_compatible_device import get_device_type
|
|
|
|
|
@@ -16,6 +17,7 @@ from ..models.ace_step_dit import AceStepDiTModel
|
|
|
|
|
from ..models.ace_step_conditioner import AceStepConditionEncoder
|
|
|
|
|
from ..models.ace_step_text_encoder import AceStepTextEncoder
|
|
|
|
|
from ..models.ace_step_vae import AceStepVAE
|
|
|
|
|
from ..models.ace_step_tokenizer import AceStepTokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AceStepPipeline(BasePipeline):
|
|
|
|
|
@@ -32,29 +34,18 @@ class AceStepPipeline(BasePipeline):
|
|
|
|
|
self.text_encoder: AceStepTextEncoder = None
|
|
|
|
|
self.conditioner: AceStepConditionEncoder = None
|
|
|
|
|
self.dit: AceStepDiTModel = None
|
|
|
|
|
self.vae = None # AutoencoderOobleck (diffusers) or AceStepVAE
|
|
|
|
|
self.vae: AceStepVAE = None
|
|
|
|
|
self.tokenizer_model: AceStepTokenizer = None # AceStepTokenizer (tokenizer + detokenizer)
|
|
|
|
|
|
|
|
|
|
# Unit chain order — 7 units total
|
|
|
|
|
#
|
|
|
|
|
# 1. ShapeChecker: duration → seq_len
|
|
|
|
|
# 2. PromptEmbedder: prompt/lyrics → text/lyric embeddings (shared for CFG)
|
|
|
|
|
# 3. SilenceLatentInitializer: seq_len → src_latents + chunk_masks
|
|
|
|
|
# 4. ContextLatentBuilder: src_latents + chunk_masks → context_latents (shared, same for CFG+)
|
|
|
|
|
# 5. ConditionEmbedder: text/lyric → encoder_hidden_states (separate for CFG+/-)
|
|
|
|
|
# 6. NoiseInitializer: context_latents → noise
|
|
|
|
|
# 7. InputAudioEmbedder: noise → latents
|
|
|
|
|
#
|
|
|
|
|
# ContextLatentBuilder runs before ConditionEmbedder so that
|
|
|
|
|
# context_latents is available for noise shape computation.
|
|
|
|
|
self.in_iteration_models = ("dit",)
|
|
|
|
|
self.units = [
|
|
|
|
|
AceStepUnit_ShapeChecker(),
|
|
|
|
|
AceStepUnit_PromptEmbedder(),
|
|
|
|
|
AceStepUnit_SilenceLatentInitializer(),
|
|
|
|
|
AceStepUnit_ContextLatentBuilder(),
|
|
|
|
|
AceStepUnit_ReferenceAudioEmbedder(),
|
|
|
|
|
AceStepUnit_ConditionEmbedder(),
|
|
|
|
|
AceStepUnit_ContextLatentBuilder(),
|
|
|
|
|
AceStepUnit_NoiseInitializer(),
|
|
|
|
|
AceStepUnit_InputAudioEmbedder(),
|
|
|
|
|
AceStepUnit_AudioCodeDecoder(),
|
|
|
|
|
]
|
|
|
|
|
self.model_fn = model_fn_ace_step
|
|
|
|
|
self.compilable_models = ["dit"]
|
|
|
|
|
@@ -66,7 +57,8 @@ class AceStepPipeline(BasePipeline):
|
|
|
|
|
torch_dtype: torch.dtype = torch.bfloat16,
|
|
|
|
|
device: str = get_device_type(),
|
|
|
|
|
model_configs: list[ModelConfig] = [],
|
|
|
|
|
text_tokenizer_config: ModelConfig = None,
|
|
|
|
|
text_tokenizer_config: ModelConfig = ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
|
|
|
|
silence_latent_config: ModelConfig = ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
|
|
|
|
|
vram_limit: float = None,
|
|
|
|
|
):
|
|
|
|
|
"""Load pipeline from pretrained checkpoints."""
|
|
|
|
|
@@ -77,11 +69,15 @@ class AceStepPipeline(BasePipeline):
|
|
|
|
|
pipe.conditioner = model_pool.fetch_model("ace_step_conditioner")
|
|
|
|
|
pipe.dit = model_pool.fetch_model("ace_step_dit")
|
|
|
|
|
pipe.vae = model_pool.fetch_model("ace_step_vae")
|
|
|
|
|
pipe.tokenizer_model = model_pool.fetch_model("ace_step_tokenizer")
|
|
|
|
|
|
|
|
|
|
if text_tokenizer_config is not None:
|
|
|
|
|
text_tokenizer_config.download_if_necessary()
|
|
|
|
|
from transformers import AutoTokenizer
|
|
|
|
|
pipe.tokenizer = AutoTokenizer.from_pretrained(text_tokenizer_config.path)
|
|
|
|
|
if silence_latent_config is not None:
|
|
|
|
|
silence_latent_config.download_if_necessary()
|
|
|
|
|
pipe.silence_latent = torch.load(silence_latent_config.path, weights_only=True).transpose(1, 2).to(dtype=pipe.torch_dtype, device=pipe.device)
|
|
|
|
|
|
|
|
|
|
# VRAM Management
|
|
|
|
|
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
|
|
|
|
@@ -97,9 +93,19 @@ class AceStepPipeline(BasePipeline):
|
|
|
|
|
# Lyrics
|
|
|
|
|
lyrics: str = "",
|
|
|
|
|
# Reference audio (optional, for timbre conditioning)
|
|
|
|
|
reference_audio = None,
|
|
|
|
|
reference_audios: List[torch.Tensor] = None,
|
|
|
|
|
# Src audio
|
|
|
|
|
src_audio: torch.Tensor = None,
|
|
|
|
|
denoising_strength: float = 1.0,
|
|
|
|
|
# Simple Mode: LLM-generated audio codes (optional)
|
|
|
|
|
audio_codes: str = None,
|
|
|
|
|
# Shape
|
|
|
|
|
duration: float = 60.0,
|
|
|
|
|
duration: int = 60,
|
|
|
|
|
# Audio Meta
|
|
|
|
|
bpm: Optional[int] = 100,
|
|
|
|
|
keyscale: Optional[str] = "B minor",
|
|
|
|
|
timesignature: Optional[str] = "4",
|
|
|
|
|
vocal_language: Optional[str] = 'zh',
|
|
|
|
|
# Randomness
|
|
|
|
|
seed: int = None,
|
|
|
|
|
rand_device: str = "cpu",
|
|
|
|
|
@@ -111,11 +117,7 @@ class AceStepPipeline(BasePipeline):
|
|
|
|
|
progress_bar_cmd=tqdm,
|
|
|
|
|
):
|
|
|
|
|
# 1. Scheduler
|
|
|
|
|
self.scheduler.set_timesteps(
|
|
|
|
|
num_inference_steps=num_inference_steps,
|
|
|
|
|
denoising_strength=1.0,
|
|
|
|
|
shift=shift,
|
|
|
|
|
)
|
|
|
|
|
self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, denoising_strength=1.0, shift=shift)
|
|
|
|
|
|
|
|
|
|
# 2. Three input dicts (positive / negative / shared)
|
|
|
|
|
inputs_posi = {"prompt": prompt}
|
|
|
|
|
@@ -123,8 +125,11 @@ class AceStepPipeline(BasePipeline):
|
|
|
|
|
inputs_shared = {
|
|
|
|
|
"cfg_scale": cfg_scale,
|
|
|
|
|
"lyrics": lyrics,
|
|
|
|
|
"reference_audio": reference_audio,
|
|
|
|
|
"reference_audios": reference_audios,
|
|
|
|
|
"src_audio": src_audio,
|
|
|
|
|
"audio_codes": audio_codes,
|
|
|
|
|
"duration": duration,
|
|
|
|
|
"bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "vocal_language": vocal_language,
|
|
|
|
|
"seed": seed,
|
|
|
|
|
"rand_device": rand_device,
|
|
|
|
|
"num_inference_steps": num_inference_steps,
|
|
|
|
|
@@ -159,6 +164,10 @@ class AceStepPipeline(BasePipeline):
|
|
|
|
|
# VAE returns OobleckDecoderOutput with .sample attribute
|
|
|
|
|
audio_output = vae_output.sample if hasattr(vae_output, 'sample') else vae_output
|
|
|
|
|
audio = self.output_audio_format_check(audio_output)
|
|
|
|
|
|
|
|
|
|
# Peak normalization to match target library behavior
|
|
|
|
|
audio = self.normalize_audio(audio, target_db=-1.0)
|
|
|
|
|
|
|
|
|
|
self.load_models_to_device([])
|
|
|
|
|
return audio
|
|
|
|
|
|
|
|
|
|
@@ -172,294 +181,303 @@ class AceStepPipeline(BasePipeline):
|
|
|
|
|
audio_output = audio_output.squeeze(0)
|
|
|
|
|
return audio_output.float()
|
|
|
|
|
|
|
|
|
|
def normalize_audio(self, audio: torch.Tensor, target_db: float = -1.0) -> torch.Tensor:
|
|
|
|
|
"""Apply peak normalization to audio data, matching target library behavior.
|
|
|
|
|
|
|
|
|
|
class AceStepUnit_ShapeChecker(PipelineUnit):
|
|
|
|
|
"""Check and compute sequence length from duration."""
|
|
|
|
|
def __init__(self):
    # Declare the unit's dataflow contract for the pipeline runner:
    # consumes `duration` from the shared inputs and republishes it
    # together with the derived latent sequence length `seq_len`.
    super().__init__(
        input_params=("duration",),
        output_params=("duration", "seq_len"),
    )
|
|
|
|
|
Target library reference: `acestep/audio_utils.py:normalize_audio()`
|
|
|
|
|
peak = max(abs(audio))
|
|
|
|
|
gain = 10^(target_db/20) / peak
|
|
|
|
|
audio = audio * gain
|
|
|
|
|
|
|
|
|
|
def process(self, pipe, duration):
    """Translate an audio duration (seconds) into a latent sequence length."""
    # ACE-Step latents run at 25 frames per second of audio.
    frames_per_second = 25
    return {"duration": duration, "seq_len": int(duration * frames_per_second)}
|
|
|
|
|
Args:
|
|
|
|
|
audio: Audio tensor [C, T]
|
|
|
|
|
target_db: Target peak level in dB (default: -1.0)
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Normalized audio tensor
|
|
|
|
|
"""
|
|
|
|
|
peak = torch.max(torch.abs(audio))
|
|
|
|
|
if peak < 1e-6:
|
|
|
|
|
return audio
|
|
|
|
|
target_amp = 10 ** (target_db / 20.0)
|
|
|
|
|
gain = target_amp / peak
|
|
|
|
|
return audio * gain
|
|
|
|
|
|
|
|
|
|
class AceStepUnit_PromptEmbedder(PipelineUnit):
|
|
|
|
|
"""Encode prompt and lyrics using Qwen3-Embedding.
|
|
|
|
|
SFT_GEN_PROMPT = "# Instruction\n{}\n\n# Caption\n{}\n\n# Metas\n{}<|endoftext|>\n"
|
|
|
|
|
INSTRUCTION = "Fill the audio semantic mask based on the given conditions:"
|
|
|
|
|
LYRIC_PROMPT = "# Languages\n{}\n\n# Lyric\n{}<|endoftext|>"
|
|
|
|
|
|
|
|
|
|
Uses seperate_cfg=True to read prompt from inputs_posi (not inputs_shared).
|
|
|
|
|
The negative condition uses null_condition_emb (handled by ConditionEmbedder),
|
|
|
|
|
so negative text encoding is not needed here.
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super().__init__(
|
|
|
|
|
seperate_cfg=True,
|
|
|
|
|
input_params_posi={"prompt": "prompt"},
|
|
|
|
|
input_params_nega={},
|
|
|
|
|
input_params=("lyrics",),
|
|
|
|
|
input_params_nega={"prompt": "prompt"},
|
|
|
|
|
input_params=("lyrics", "duration", "bpm", "keyscale", "timesignature", "vocal_language"),
|
|
|
|
|
output_params=("text_hidden_states", "text_attention_mask", "lyric_hidden_states", "lyric_attention_mask"),
|
|
|
|
|
onload_model_names=("text_encoder",)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def _encode_text(self, pipe, text):
|
|
|
|
|
def _encode_text(self, pipe, text, max_length=256):
|
|
|
|
|
"""Encode text using Qwen3-Embedding → [B, T, 1024]."""
|
|
|
|
|
if pipe.tokenizer is None:
|
|
|
|
|
return None, None
|
|
|
|
|
text_inputs = pipe.tokenizer(
|
|
|
|
|
text,
|
|
|
|
|
padding="max_length",
|
|
|
|
|
max_length=512,
|
|
|
|
|
max_length=max_length,
|
|
|
|
|
truncation=True,
|
|
|
|
|
return_tensors="pt",
|
|
|
|
|
)
|
|
|
|
|
input_ids = text_inputs.input_ids.to(pipe.device)
|
|
|
|
|
attention_mask = text_inputs.attention_mask.to(pipe.device)
|
|
|
|
|
attention_mask = text_inputs.attention_mask.bool().to(pipe.device)
|
|
|
|
|
hidden_states = pipe.text_encoder(input_ids, attention_mask)
|
|
|
|
|
return hidden_states, attention_mask
|
|
|
|
|
|
|
|
|
|
def process(self, pipe, prompt, lyrics, negative_prompt=None):
|
|
|
|
|
def _encode_lyrics(self, pipe, lyric_text, max_length=2048):
    """Tokenize lyrics and embed them with the text encoder's embedding table.

    Unlike `_encode_text`, no padding is applied and the tokens are mapped
    through `embed_tokens` directly rather than a full encoder forward.
    """
    encoded = pipe.tokenizer(
        lyric_text,
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    ids = encoded.input_ids.to(pipe.device)
    mask = encoded.attention_mask.bool().to(pipe.device)
    embeddings = pipe.text_encoder.model.embed_tokens(ids)
    return embeddings, mask
|
|
|
|
|
|
|
|
|
|
def _dict_to_meta_string(self, meta_dict: Dict[str, Any]) -> str:
    """Render audio metadata as the bullet-list block used in SFT_GEN_PROMPT.

    Missing keys fall back to "N/A" (duration falls back to 30 seconds).
    """
    seconds = int(meta_dict.get("duration", 30))
    lines = [
        f"- bpm: {meta_dict.get('bpm', 'N/A')}",
        f"- timesignature: {meta_dict.get('timesignature', 'N/A')}",
        f"- keyscale: {meta_dict.get('keyscale', 'N/A')}",
        f"- duration: {seconds} seconds",
    ]
    return "\n".join(lines) + "\n"
|
|
|
|
|
|
|
|
|
|
def process(self, pipe, prompt, lyrics, duration, bpm, keyscale, timesignature, vocal_language):
|
|
|
|
|
pipe.load_models_to_device(['text_encoder'])
|
|
|
|
|
meta_dict = {"bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "duration": duration}
|
|
|
|
|
prompt = self.SFT_GEN_PROMPT.format(self.INSTRUCTION, prompt, self._dict_to_meta_string(meta_dict))
|
|
|
|
|
text_hidden_states, text_attention_mask = self._encode_text(pipe, prompt, max_length=256)
|
|
|
|
|
|
|
|
|
|
text_hidden_states, text_attention_mask = self._encode_text(pipe, prompt)
|
|
|
|
|
lyric_text = self.LYRIC_PROMPT.format(vocal_language, lyrics)
|
|
|
|
|
lyric_hidden_states, lyric_attention_mask = self._encode_lyrics(pipe, lyric_text, max_length=2048)
|
|
|
|
|
|
|
|
|
|
# Lyrics encoding — use empty string if not provided
|
|
|
|
|
lyric_text = lyrics if lyrics else ""
|
|
|
|
|
lyric_hidden_states, lyric_attention_mask = self._encode_text(pipe, lyric_text)
|
|
|
|
|
|
|
|
|
|
if text_hidden_states is not None and lyric_hidden_states is not None:
|
|
|
|
|
return {
|
|
|
|
|
"text_hidden_states": text_hidden_states,
|
|
|
|
|
"text_attention_mask": text_attention_mask,
|
|
|
|
|
"lyric_hidden_states": lyric_hidden_states,
|
|
|
|
|
"lyric_attention_mask": lyric_attention_mask,
|
|
|
|
|
}
|
|
|
|
|
return {}
|
|
|
|
|
# TODO: remove this
|
|
|
|
|
newtext = prompt + "\n\n" + lyric_text
|
|
|
|
|
return {
|
|
|
|
|
"text_hidden_states": text_hidden_states,
|
|
|
|
|
"text_attention_mask": text_attention_mask,
|
|
|
|
|
"lyric_hidden_states": lyric_hidden_states,
|
|
|
|
|
"lyric_attention_mask": lyric_attention_mask,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AceStepUnit_SilenceLatentInitializer(PipelineUnit):
|
|
|
|
|
"""Generate silence latent (all zeros) and chunk_masks for text2music.
|
|
|
|
|
|
|
|
|
|
Target library reference: `prepare_condition()` line 1698-1699:
|
|
|
|
|
context_latents = torch.cat([src_latents, chunk_masks.to(dtype)], dim=-1)
|
|
|
|
|
|
|
|
|
|
For text2music mode:
|
|
|
|
|
- src_latents = zeros [B, T, 64] (VAE latent dimension)
|
|
|
|
|
- chunk_masks = ones [B, T, 64] (full visibility mask for text2music)
|
|
|
|
|
- context_latents = [B, T, 128] (concat of src_latents + chunk_masks)
|
|
|
|
|
"""
|
|
|
|
|
class AceStepUnit_ReferenceAudioEmbedder(PipelineUnit):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super().__init__(
|
|
|
|
|
input_params=("seq_len",),
|
|
|
|
|
output_params=("silence_latent", "src_latents", "chunk_masks"),
|
|
|
|
|
input_params=("reference_audios",),
|
|
|
|
|
output_params=("reference_latents", "refer_audio_order_mask"),
|
|
|
|
|
onload_model_names=("vae",)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def process(self, pipe, seq_len):
|
|
|
|
|
# silence_latent shape: [B, T, 64] — 64 is the VAE latent dimension
|
|
|
|
|
silence_latent = torch.zeros(1, seq_len, 64, device=pipe.device, dtype=pipe.torch_dtype)
|
|
|
|
|
# For text2music: src_latents = silence_latent
|
|
|
|
|
src_latents = silence_latent.clone()
|
|
|
|
|
def process(self, pipe, reference_audios):
|
|
|
|
|
pipe.load_models_to_device(['vae'])
|
|
|
|
|
if reference_audios is not None and len(reference_audios) > 0:
|
|
|
|
|
# TODO: implement reference audio embedding using VAE encode, and generate refer_audio_order_mask
|
|
|
|
|
pass
|
|
|
|
|
else:
|
|
|
|
|
reference_audios = [[torch.zeros(2, 30 * pipe.vae.sampling_rate).to(dtype=pipe.torch_dtype, device=pipe.device)]]
|
|
|
|
|
reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, reference_audios)
|
|
|
|
|
return {"reference_latents": reference_latents, "refer_audio_order_mask": refer_audio_order_mask}
|
|
|
|
|
|
|
|
|
|
# chunk_masks: [B, T, 64] of ones (same shape as src_latents)
|
|
|
|
|
# In text2music mode (is_covers=0), chunk_masks are all 1.0
|
|
|
|
|
# This matches the target library's behavior at line 1699
|
|
|
|
|
chunk_masks = torch.ones(1, seq_len, 64, device=pipe.device, dtype=pipe.torch_dtype)
|
|
|
|
|
def infer_refer_latent(self, pipe, refer_audioss: List[List[torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
|
"""Infer packed reference-audio latents and order mask."""
|
|
|
|
|
refer_audio_order_mask = []
|
|
|
|
|
refer_audio_latents = []
|
|
|
|
|
|
|
|
|
|
return {"silence_latent": silence_latent, "src_latents": src_latents, "chunk_masks": chunk_masks}
|
|
|
|
|
def _normalize_audio_2d(a: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
if not isinstance(a, torch.Tensor):
|
|
|
|
|
raise TypeError(f"refer_audio must be a torch.Tensor, got {type(a)!r}")
|
|
|
|
|
if a.dim() == 3 and a.shape[0] == 1:
|
|
|
|
|
a = a.squeeze(0)
|
|
|
|
|
if a.dim() == 1:
|
|
|
|
|
a = a.unsqueeze(0)
|
|
|
|
|
if a.dim() != 2:
|
|
|
|
|
raise ValueError(f"refer_audio must be 1D/2D/3D(1,2,T); got shape={tuple(a.shape)}")
|
|
|
|
|
if a.shape[0] == 1:
|
|
|
|
|
a = torch.cat([a, a], dim=0)
|
|
|
|
|
return a[:2]
|
|
|
|
|
|
|
|
|
|
def _ensure_latent_3d(z: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
if z.dim() == 4 and z.shape[0] == 1:
|
|
|
|
|
z = z.squeeze(0)
|
|
|
|
|
if z.dim() == 2:
|
|
|
|
|
z = z.unsqueeze(0)
|
|
|
|
|
return z
|
|
|
|
|
|
|
|
|
|
class AceStepUnit_ContextLatentBuilder(PipelineUnit):
|
|
|
|
|
"""Build context_latents from src_latents and chunk_masks.
|
|
|
|
|
refer_encode_cache: Dict[int, torch.Tensor] = {}
|
|
|
|
|
for batch_idx, refer_audios in enumerate(refer_audioss):
|
|
|
|
|
if len(refer_audios) == 1 and torch.all(refer_audios[0] == 0.0):
|
|
|
|
|
refer_audio_latent = _ensure_latent_3d(pipe.silence_latent[:, :750, :])
|
|
|
|
|
refer_audio_latents.append(refer_audio_latent)
|
|
|
|
|
refer_audio_order_mask.append(batch_idx)
|
|
|
|
|
else:
|
|
|
|
|
# TODO: check
|
|
|
|
|
for refer_audio in refer_audios:
|
|
|
|
|
cache_key = refer_audio.data_ptr()
|
|
|
|
|
if cache_key in refer_encode_cache:
|
|
|
|
|
refer_audio_latent = refer_encode_cache[cache_key].clone()
|
|
|
|
|
else:
|
|
|
|
|
refer_audio = _normalize_audio_2d(refer_audio)
|
|
|
|
|
refer_audio_latent = pipe.vae.encode(refer_audio)
|
|
|
|
|
refer_audio_latent = refer_audio_latent.to(dtype=pipe.torch_dtype, device=pipe.device)
|
|
|
|
|
if refer_audio_latent.dim() == 2:
|
|
|
|
|
refer_audio_latent = refer_audio_latent.unsqueeze(0)
|
|
|
|
|
refer_audio_latent = _ensure_latent_3d(refer_audio_latent.transpose(1, 2))
|
|
|
|
|
refer_encode_cache[cache_key] = refer_audio_latent
|
|
|
|
|
refer_audio_latents.append(refer_audio_latent)
|
|
|
|
|
refer_audio_order_mask.append(batch_idx)
|
|
|
|
|
|
|
|
|
|
Target library reference: `prepare_condition()` line 1699:
|
|
|
|
|
context_latents = torch.cat([src_latents, chunk_masks.to(dtype)], dim=-1)
|
|
|
|
|
|
|
|
|
|
context_latents is the SAME for positive and negative CFG paths
|
|
|
|
|
(it comes from src_latents + chunk_masks, not from text encoding).
|
|
|
|
|
So this is a normal-mode (non-CFG-separated) Unit — outputs go to inputs_shared.
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super().__init__(
|
|
|
|
|
input_params=("src_latents", "chunk_masks"),
|
|
|
|
|
output_params=("context_latents", "attention_mask"),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def process(self, pipe, src_latents, chunk_masks):
    """Fuse src latents with their chunk masks and build the DiT attention mask."""
    batch, steps = src_latents.shape[0], src_latents.shape[1]
    # [B, T, 64] ++ [B, T, 64] → [B, T, 128], mirroring the reference
    # implementation's cat([src_latents, chunk_masks], dim=-1).
    context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
    # Full-visibility [B, T] mask used by the DiT over context_latents.
    attention_mask = torch.ones(batch, steps, device=pipe.device, dtype=pipe.torch_dtype)
    return {"context_latents": context_latents, "attention_mask": attention_mask}
|
|
|
|
|
refer_audio_latents = torch.cat(refer_audio_latents, dim=0)
|
|
|
|
|
refer_audio_order_mask = torch.tensor(refer_audio_order_mask, device=pipe.device, dtype=torch.long)
|
|
|
|
|
return refer_audio_latents, refer_audio_order_mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AceStepUnit_ConditionEmbedder(PipelineUnit):
|
|
|
|
|
"""Generate encoder_hidden_states via ACEStepConditioner.
|
|
|
|
|
|
|
|
|
|
Target library reference: `prepare_condition()` line 1674-1681:
|
|
|
|
|
encoder_hidden_states, encoder_attention_mask = self.encoder(...)
|
|
|
|
|
|
|
|
|
|
Uses seperate_cfg mode:
|
|
|
|
|
- Positive: encode with full condition (text + lyrics + reference audio)
|
|
|
|
|
- Negative: replace text with null_condition_emb, keep lyrics/timbre same
|
|
|
|
|
|
|
|
|
|
context_latents is handled by ContextLatentBuilder (normal mode), not here.
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super().__init__(
|
|
|
|
|
seperate_cfg=True,
|
|
|
|
|
input_params_posi={
|
|
|
|
|
"text_hidden_states": "text_hidden_states",
|
|
|
|
|
"text_attention_mask": "text_attention_mask",
|
|
|
|
|
"lyric_hidden_states": "lyric_hidden_states",
|
|
|
|
|
"lyric_attention_mask": "lyric_attention_mask",
|
|
|
|
|
"reference_audio": "reference_audio",
|
|
|
|
|
"refer_audio_order_mask": "refer_audio_order_mask",
|
|
|
|
|
},
|
|
|
|
|
input_params_nega={},
|
|
|
|
|
input_params=("cfg_scale",),
|
|
|
|
|
output_params=(
|
|
|
|
|
"encoder_hidden_states", "encoder_attention_mask",
|
|
|
|
|
"negative_encoder_hidden_states", "negative_encoder_attention_mask",
|
|
|
|
|
),
|
|
|
|
|
onload_model_names=("conditioner",)
|
|
|
|
|
input_params_posi={"text_hidden_states": "text_hidden_states", "text_attention_mask": "text_attention_mask", "lyric_hidden_states": "lyric_hidden_states", "lyric_attention_mask": "lyric_attention_mask"},
|
|
|
|
|
input_params_nega={"text_hidden_states": "text_hidden_states", "text_attention_mask": "text_attention_mask", "lyric_hidden_states": "lyric_hidden_states", "lyric_attention_mask": "lyric_attention_mask"},
|
|
|
|
|
input_params=("reference_latents", "refer_audio_order_mask"),
|
|
|
|
|
output_params=("encoder_hidden_states", "encoder_attention_mask"),
|
|
|
|
|
onload_model_names=("conditioner",),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def _prepare_condition(self, pipe, text_hidden_states, text_attention_mask,
|
|
|
|
|
lyric_hidden_states, lyric_attention_mask,
|
|
|
|
|
refer_audio_acoustic_hidden_states_packed=None,
|
|
|
|
|
refer_audio_order_mask=None):
|
|
|
|
|
"""Call ACEStepConditioner forward to produce encoder_hidden_states."""
|
|
|
|
|
def process(self, pipe, text_hidden_states, text_attention_mask, lyric_hidden_states, lyric_attention_mask, reference_latents, refer_audio_order_mask):
|
|
|
|
|
pipe.load_models_to_device(['conditioner'])
|
|
|
|
|
|
|
|
|
|
# Handle reference audio
|
|
|
|
|
if refer_audio_acoustic_hidden_states_packed is None:
|
|
|
|
|
# No reference audio: create 2D packed zeros [N=1, d=64]
|
|
|
|
|
# TimbreEncoder.unpack expects [N, d], not [B, T, d]
|
|
|
|
|
refer_audio_acoustic_hidden_states_packed = torch.zeros(
|
|
|
|
|
1, 64, device=pipe.device, dtype=pipe.torch_dtype
|
|
|
|
|
)
|
|
|
|
|
refer_audio_order_mask = torch.LongTensor([0]).to(pipe.device)
|
|
|
|
|
|
|
|
|
|
encoder_hidden_states, encoder_attention_mask = pipe.conditioner(
|
|
|
|
|
text_hidden_states=text_hidden_states,
|
|
|
|
|
text_attention_mask=text_attention_mask,
|
|
|
|
|
lyric_hidden_states=lyric_hidden_states,
|
|
|
|
|
lyric_attention_mask=lyric_attention_mask,
|
|
|
|
|
refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed,
|
|
|
|
|
reference_latents=reference_latents,
|
|
|
|
|
refer_audio_order_mask=refer_audio_order_mask,
|
|
|
|
|
)
|
|
|
|
|
return {"encoder_hidden_states": encoder_hidden_states, "encoder_attention_mask": encoder_attention_mask}
|
|
|
|
|
|
|
|
|
|
return encoder_hidden_states, encoder_attention_mask
|
|
|
|
|
|
|
|
|
|
def _prepare_negative_condition(self, pipe, lyric_hidden_states, lyric_attention_mask,
|
|
|
|
|
refer_audio_acoustic_hidden_states_packed=None,
|
|
|
|
|
refer_audio_order_mask=None):
|
|
|
|
|
"""Generate negative condition using null_condition_emb."""
|
|
|
|
|
if pipe.conditioner is None or not hasattr(pipe.conditioner, 'null_condition_emb'):
|
|
|
|
|
return None, None
|
|
|
|
|
|
|
|
|
|
null_emb = pipe.conditioner.null_condition_emb # [1, 1, hidden_size]
|
|
|
|
|
bsz = 1
|
|
|
|
|
if lyric_hidden_states is not None:
|
|
|
|
|
bsz = lyric_hidden_states.shape[0]
|
|
|
|
|
null_hidden_states = null_emb.expand(bsz, -1, -1)
|
|
|
|
|
null_attn_mask = torch.ones(bsz, 1, device=pipe.device, dtype=pipe.torch_dtype)
|
|
|
|
|
|
|
|
|
|
# For negative: use null_condition_emb as text, keep lyrics and timbre
|
|
|
|
|
neg_encoder_hidden_states, neg_encoder_attention_mask = pipe.conditioner(
|
|
|
|
|
text_hidden_states=null_hidden_states,
|
|
|
|
|
text_attention_mask=null_attn_mask,
|
|
|
|
|
lyric_hidden_states=lyric_hidden_states,
|
|
|
|
|
lyric_attention_mask=lyric_attention_mask,
|
|
|
|
|
refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed,
|
|
|
|
|
refer_audio_order_mask=refer_audio_order_mask,
|
|
|
|
|
class AceStepUnit_ContextLatentBuilder(PipelineUnit):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super().__init__(
|
|
|
|
|
input_params=("duration", "src_audio"),
|
|
|
|
|
output_params=("context_latents", "src_latents", "chunk_masks", "attention_mask"),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return neg_encoder_hidden_states, neg_encoder_attention_mask
|
|
|
|
|
def _get_silence_latent_slice(self, pipe, length: int) -> torch.Tensor:
    """Return `length` frames of the cached silence latent, tiling it along
    time when the requested span exceeds what is stored."""
    bank = pipe.silence_latent[0]  # [T, d]
    available = bank.shape[0]
    if length <= available:
        return bank[:length, :]
    repeats = -(-length // available)  # ceiling division
    return bank.repeat(repeats, 1)[:length, :]
|
|
|
|
|
|
|
|
|
|
def process(self, pipe, text_hidden_states, text_attention_mask,
|
|
|
|
|
lyric_hidden_states, lyric_attention_mask,
|
|
|
|
|
reference_audio=None, refer_audio_order_mask=None,
|
|
|
|
|
negative_prompt=None, cfg_scale=1.0):
|
|
|
|
|
|
|
|
|
|
# Positive condition
|
|
|
|
|
pos_enc_hs, pos_enc_mask = self._prepare_condition(
|
|
|
|
|
pipe, text_hidden_states, text_attention_mask,
|
|
|
|
|
lyric_hidden_states, lyric_attention_mask,
|
|
|
|
|
None, refer_audio_order_mask,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Negative condition: only needed when CFG is active (cfg_scale > 1.0)
|
|
|
|
|
# For cfg_scale=1.0 (turbo), skip to avoid null_condition_emb dimension mismatch
|
|
|
|
|
result = {
|
|
|
|
|
"encoder_hidden_states": pos_enc_hs,
|
|
|
|
|
"encoder_attention_mask": pos_enc_mask,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if cfg_scale > 1.0:
|
|
|
|
|
neg_enc_hs, neg_enc_mask = self._prepare_negative_condition(
|
|
|
|
|
pipe, lyric_hidden_states, lyric_attention_mask,
|
|
|
|
|
None, refer_audio_order_mask,
|
|
|
|
|
)
|
|
|
|
|
if neg_enc_hs is not None:
|
|
|
|
|
result["negative_encoder_hidden_states"] = neg_enc_hs
|
|
|
|
|
result["negative_encoder_attention_mask"] = neg_enc_mask
|
|
|
|
|
|
|
|
|
|
return result
|
|
|
|
|
def process(self, pipe, duration, src_audio):
    """Build context latents for text2music generation.

    Args:
        pipe: the pipeline; provides `sample_rate`, `silence_latent`,
            `device` and `torch_dtype`.
        duration: requested audio length in seconds (int; floats are
            coerced — the previous code crashed on float durations because
            tensor sizes and slice bounds must be ints).
        src_audio: unsupported for now; must be None.

    Returns:
        dict with `context_latents` [1, T, 2*d] (silence src latents
        concatenated with an all-ones visibility mask, cast to the latent
        dtype as in the reference implementation) and `attention_mask` [1, T].
        NOTE(review): `output_params` also declares src_latents/chunk_masks
        but they were never returned here — preserved as-is; confirm against
        the unit runner.

    Raises:
        NotImplementedError: if `src_audio` is provided.
    """
    if src_audio is not None:
        raise NotImplementedError("Src audio conditioning is not implemented yet. Please set src_audio to None.")
    # 1920 audio samples per latent frame; int() makes float durations valid.
    max_latent_length = int(duration * pipe.sample_rate) // 1920
    src_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
    chunk_masks = torch.ones((1, max_latent_length, src_latents.shape[-1]), dtype=torch.bool, device=pipe.device)
    attention_mask = torch.ones((1, max_latent_length), device=src_latents.device, dtype=pipe.torch_dtype)
    # Explicitly cast the bool mask to the latent dtype before concatenation
    # (matches the target library's chunk_masks.to(dtype); values unchanged).
    context_latents = torch.cat([src_latents, chunk_masks.to(src_latents.dtype)], dim=-1)
    return {"context_latents": context_latents, "attention_mask": attention_mask}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AceStepUnit_NoiseInitializer(PipelineUnit):
|
|
|
|
|
"""Generate initial noise tensor.
|
|
|
|
|
|
|
|
|
|
Target library reference: `prepare_noise()` line 1781-1818:
|
|
|
|
|
src_latents_shape = (bsz, context_latents.shape[1], context_latents.shape[-1] // 2)
|
|
|
|
|
|
|
|
|
|
Noise shape = [B, T, context_latents.shape[-1] // 2] = [B, T, 128 // 2] = [B, T, 64]
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super().__init__(
|
|
|
|
|
input_params=("seed", "seq_len", "rand_device", "context_latents"),
|
|
|
|
|
input_params=("context_latents", "seed", "rand_device"),
|
|
|
|
|
output_params=("noise",),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def process(self, pipe, seed, seq_len, rand_device, context_latents):
|
|
|
|
|
# Noise shape: [B, T, context_latents.shape[-1] // 2]
|
|
|
|
|
# context_latents = [B, T, 128] → noise = [B, T, 64]
|
|
|
|
|
# This matches the target library's prepare_noise() at line 1796
|
|
|
|
|
noise_shape = (context_latents.shape[0], context_latents.shape[1],
|
|
|
|
|
context_latents.shape[-1] // 2)
|
|
|
|
|
noise = pipe.generate_noise(
|
|
|
|
|
noise_shape,
|
|
|
|
|
seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype
|
|
|
|
|
)
|
|
|
|
|
def process(self, pipe, context_latents, seed, rand_device):
    """Sample initial noise shaped like the src-latent half of context_latents."""
    batch = context_latents.shape[0]
    frames = context_latents.shape[1]
    # context_latents is src_latents ++ chunk_masks, so half the channels.
    channels = context_latents.shape[-1] // 2
    noise = pipe.generate_noise(
        (batch, frames, channels),
        seed=seed,
        rand_device=rand_device,
        rand_torch_dtype=pipe.torch_dtype,
    )
    return {"noise": noise}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AceStepUnit_InputAudioEmbedder(PipelineUnit):
|
|
|
|
|
"""Set up latents for denoise loop.
|
|
|
|
|
|
|
|
|
|
For text2music (no input audio): latents = noise, input_latents = None.
|
|
|
|
|
|
|
|
|
|
Target library reference: `generate_audio()` line 1972:
|
|
|
|
|
xt = noise (when cover_noise_strength == 0)
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super().__init__(
|
|
|
|
|
input_params=("noise",),
|
|
|
|
|
input_params=("noise", "input_audio"),
|
|
|
|
|
output_params=("latents", "input_latents"),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def process(self, pipe, noise):
|
|
|
|
|
# For text2music: start from pure noise
|
|
|
|
|
def process(self, pipe, noise, input_audio):
|
|
|
|
|
if input_audio is None:
|
|
|
|
|
return {"latents": noise}
|
|
|
|
|
# TODO: support for train
|
|
|
|
|
return {"latents": noise, "input_latents": None}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AceStepUnit_AudioCodeDecoder(PipelineUnit):
    """Decode LLM-generated audio-code tokens into 25 Hz latent hints.

    Consumes the optional `audio_codes` string (Simple Mode), dequantizes the
    codes through the tokenizer model's quantizer, runs the detokenizer, and
    pads/truncates the result to `seq_len` frames using `silence_latent`.
    """

    def __init__(self):
        super().__init__(
            input_params=("audio_codes", "seq_len", "silence_latent"),
            output_params=("lm_hints_25Hz",),
            onload_model_names=("tokenizer_model",),
        )

    @staticmethod
    def _parse_audio_code_string(code_str: str) -> list:
        """Extract integer audio codes from tokens like <|audio_code_123|>."""
        if not code_str:
            return []
        codes = []
        # Codes are clamped into the quantizer's valid index range [0, 63999].
        max_audio_code = 63999
        for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str):
            code_value = int(x)
            codes.append(max(0, min(code_value, max_audio_code)))
        return codes

    def process(self, pipe, audio_codes, seq_len, silence_latent):
        """Return {"lm_hints_25Hz": tensor or None}.

        None is returned when no usable codes are supplied, signalling
        downstream units that Simple Mode hints are absent.
        """
        if audio_codes is None or not audio_codes.strip():
            return {"lm_hints_25Hz": None}

        code_ids = self._parse_audio_code_string(audio_codes)
        if len(code_ids) == 0:
            return {"lm_hints_25Hz": None}

        pipe.load_models_to_device(["tokenizer_model"])

        quantizer = pipe.tokenizer_model.tokenizer.quantizer
        detokenizer = pipe.tokenizer_model.detokenizer

        indices = torch.tensor(code_ids, device=pipe.device, dtype=torch.long)
        indices = indices.unsqueeze(0).unsqueeze(-1)  # [1, N, 1]

        # Dequantize indices back to continuous embeddings.
        quantized = quantizer.get_output_from_indices(indices)  # [1, N, 2048]
        if quantized.dtype != pipe.torch_dtype:
            quantized = quantized.to(pipe.torch_dtype)

        # Detokenizer upsamples along time (comment suggests 5x — confirm).
        lm_hints = detokenizer(quantized)  # [1, N*5, 64]

        # Pad with silence latent or truncate so hints cover exactly seq_len frames.
        current_len = lm_hints.shape[1]
        if current_len < seq_len:
            pad_len = seq_len - current_len
            pad = silence_latent[:, :pad_len, :]
            lm_hints = torch.cat([lm_hints, pad], dim=1)
        elif current_len > seq_len:
            lm_hints = lm_hints[:, :seq_len, :]

        return {"lm_hints_25Hz": lm_hints}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def model_fn_ace_step(
|
|
|
|
|
dit: AceStepDiTModel,
|
|
|
|
|
latents=None,
|
|
|
|
|
@@ -468,49 +486,11 @@ def model_fn_ace_step(
|
|
|
|
|
encoder_attention_mask=None,
|
|
|
|
|
context_latents=None,
|
|
|
|
|
attention_mask=None,
|
|
|
|
|
past_key_values=None,
|
|
|
|
|
negative_encoder_hidden_states=None,
|
|
|
|
|
negative_encoder_attention_mask=None,
|
|
|
|
|
negative_context_latents=None,
|
|
|
|
|
use_gradient_checkpointing=False,
|
|
|
|
|
use_gradient_checkpointing_offload=False,
|
|
|
|
|
**kwargs,
|
|
|
|
|
):
|
|
|
|
|
"""Model function for ACE-Step DiT forward.
|
|
|
|
|
|
|
|
|
|
Timestep is already in [0, 1] range — no scaling needed.
|
|
|
|
|
|
|
|
|
|
Target library reference: `generate_audio()` line 2009-2020:
|
|
|
|
|
decoder_outputs = self.decoder(
|
|
|
|
|
hidden_states=x, timestep=t_curr_tensor, timestep_r=t_curr_tensor,
|
|
|
|
|
attention_mask=attention_mask,
|
|
|
|
|
encoder_hidden_states=encoder_hidden_states,
|
|
|
|
|
encoder_attention_mask=encoder_attention_mask,
|
|
|
|
|
context_latents=context_latents,
|
|
|
|
|
use_cache=True, past_key_values=past_key_values,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
dit: AceStepDiTModel
|
|
|
|
|
latents: [B, T, 64] noise/latent tensor (same shape as src_latents)
|
|
|
|
|
timestep: scalar tensor in [0, 1]
|
|
|
|
|
encoder_hidden_states: [B, T_text, 2048] condition from Conditioner
|
|
|
|
|
(positive or negative depending on CFG pass — the cfg_guided_model_fn
|
|
|
|
|
passes inputs_posi for positive, inputs_nega for negative)
|
|
|
|
|
encoder_attention_mask: [B, T_text]
|
|
|
|
|
context_latents: [B, T, 128] = cat([src_latents, chunk_masks], dim=-1)
|
|
|
|
|
(same for both CFG+/- paths in text2music mode)
|
|
|
|
|
attention_mask: [B, T] ones mask for DiT
|
|
|
|
|
past_key_values: EncoderDecoderCache for KV caching
|
|
|
|
|
|
|
|
|
|
The DiT internally concatenates: cat([context_latents, latents], dim=-1) = [B, T, 192]
|
|
|
|
|
as the actual input (128 + 64 = 192 channels).
|
|
|
|
|
"""
|
|
|
|
|
# ACE-Step uses timestep directly in [0, 1] range — no /1000 scaling
|
|
|
|
|
timestep = timestep.squeeze()
|
|
|
|
|
|
|
|
|
|
# Expand timestep to match batch size
|
|
|
|
|
bsz = latents.shape[0]
|
|
|
|
|
timestep = timestep.expand(bsz)
|
|
|
|
|
|
|
|
|
|
timestep = timestep.unsqueeze(0)
|
|
|
|
|
decoder_outputs = dit(
|
|
|
|
|
hidden_states=latents,
|
|
|
|
|
timestep=timestep,
|
|
|
|
|
@@ -519,9 +499,5 @@ def model_fn_ace_step(
|
|
|
|
|
encoder_hidden_states=encoder_hidden_states,
|
|
|
|
|
encoder_attention_mask=encoder_attention_mask,
|
|
|
|
|
context_latents=context_latents,
|
|
|
|
|
use_cache=True,
|
|
|
|
|
past_key_values=past_key_values,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Return velocity prediction (first element of decoder_outputs)
|
|
|
|
|
return decoder_outputs[0]
|
|
|
|
|
)[0]
|
|
|
|
|
return decoder_outputs
|
|
|
|
|
|