acestep t2m

This commit is contained in:
mi804
2026-04-21 13:16:15 +08:00
parent a604d76339
commit 9d09e0431c
9 changed files with 300 additions and 377 deletions

View File

@@ -157,10 +157,10 @@ class FlowMatchScheduler():
"""
num_train_timesteps = 1000
sigma_start = denoising_strength
sigmas = torch.linspace(sigma_start, 0.0, num_inference_steps)
sigmas = torch.linspace(sigma_start, 0.0, num_inference_steps + 1)[:-1]
if shift is not None and shift != 1.0:
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
timesteps = sigmas # ACE-Step uses [0, 1] range directly
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@staticmethod

View File

@@ -540,17 +540,9 @@ class AceStepTimbreEncoder(nn.Module):
) -> BaseModelOutput:
inputs_embeds = refer_audio_acoustic_hidden_states_packed
inputs_embeds = self.embed_tokens(inputs_embeds)
# Handle 2D (packed) or 3D (batched) input
is_packed = inputs_embeds.dim() == 2
if is_packed:
seq_len = inputs_embeds.shape[0]
cache_position = torch.arange(0, seq_len, device=inputs_embeds.device)
position_ids = cache_position.unsqueeze(0)
inputs_embeds = inputs_embeds.unsqueeze(0)
else:
seq_len = inputs_embeds.shape[1]
cache_position = torch.arange(0, seq_len, device=inputs_embeds.device)
position_ids = cache_position.unsqueeze(0)
seq_len = inputs_embeds.shape[1]
cache_position = torch.arange(0, seq_len, device=inputs_embeds.device)
position_ids = cache_position.unsqueeze(0)
dtype = inputs_embeds.dtype
device = inputs_embeds.device
@@ -586,9 +578,8 @@ class AceStepTimbreEncoder(nn.Module):
hidden_states = layer_outputs[0]
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states[:, 0, :]
# For packed input: reshape [1, T, D] -> [T, D] for unpacking
if is_packed:
hidden_states = hidden_states.squeeze(0)
timbre_embs_unpack, timbre_embs_mask = self.unpack_timbre_embeddings(hidden_states, refer_audio_order_mask)
return timbre_embs_unpack, timbre_embs_mask
@@ -686,7 +677,7 @@ class AceStepConditionEncoder(nn.Module):
text_attention_mask: Optional[torch.Tensor] = None,
lyric_hidden_states: Optional[torch.LongTensor] = None,
lyric_attention_mask: Optional[torch.Tensor] = None,
refer_audio_acoustic_hidden_states_packed: Optional[torch.Tensor] = None,
reference_latents: Optional[torch.Tensor] = None,
refer_audio_order_mask: Optional[torch.LongTensor] = None,
):
text_hidden_states = self.text_projector(text_hidden_states)
@@ -695,11 +686,7 @@ class AceStepConditionEncoder(nn.Module):
attention_mask=lyric_attention_mask,
)
lyric_hidden_states = lyric_encoder_outputs.last_hidden_state
timbre_embs_unpack, timbre_embs_mask = self.timbre_encoder(
refer_audio_acoustic_hidden_states_packed,
refer_audio_order_mask
)
timbre_embs_unpack, timbre_embs_mask = self.timbre_encoder(reference_latents, refer_audio_order_mask)
encoder_hidden_states, encoder_attention_mask = pack_sequences(
lyric_hidden_states, timbre_embs_unpack, lyric_attention_mask, timbre_embs_mask
)

View File

@@ -165,7 +165,7 @@ class TimestepEmbedding(nn.Module):
self,
in_channels: int,
time_embed_dim: int,
scale: float = 1000,
scale: float = 1,
):
super().__init__()
@@ -711,7 +711,7 @@ class AceStepDiTModel(nn.Module):
encoder_hidden_states: torch.Tensor,
encoder_attention_mask: torch.Tensor,
context_latents: torch.Tensor,
use_cache: Optional[bool] = None,
use_cache: Optional[bool] = False,
past_key_values: Optional[EncoderDecoderCache] = None,
cache_position: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,

View File

@@ -2,17 +2,6 @@ import torch
class AceStepTextEncoder(torch.nn.Module):
"""
Text encoder for ACE-Step using Qwen3-Embedding-0.6B.
Converts text/lyric tokens to hidden state embeddings that are
further processed by the ACE-Step ConditionEncoder.
Wraps a Qwen3Model transformers model. Config is manually
constructed, and model weights are loaded via DiffSynth's
standard mechanism from safetensors files.
"""
def __init__(
self,
):
@@ -49,8 +38,6 @@ class AceStepTextEncoder(torch.nn.Module):
)
self.model = Qwen3Model(config)
self.config = config
self.hidden_size = config.hidden_size
@torch.no_grad()
def forward(
@@ -58,23 +45,9 @@ class AceStepTextEncoder(torch.nn.Module):
input_ids: torch.LongTensor,
attention_mask: torch.Tensor,
):
"""
Encode text/lyric tokens to hidden states.
Args:
input_ids: [B, T] token IDs
attention_mask: [B, T] attention mask
Returns:
last_hidden_state: [B, T, hidden_size]
"""
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True,
)
return outputs.last_hidden_state
def to(self, *args, **kwargs):
self.model.to(*args, **kwargs)
return self

View File

@@ -226,6 +226,7 @@ class AceStepVAE(nn.Module):
upsampling_ratios=upsampling_ratios,
channel_multiples=channel_multiples,
)
self.sampling_rate = sampling_rate
def encode(self, x: torch.Tensor) -> torch.Tensor:
"""Audio waveform [B, audio_channels, T] → latent [B, encoder_hidden_size, T']."""

View File

@@ -3,8 +3,9 @@ ACE-Step Pipeline for DiffSynth-Studio.
Text-to-Music generation pipeline using ACE-Step 1.5 model.
"""
import re
import torch
from typing import Optional
from typing import Optional, Dict, Any, List, Tuple
from tqdm import tqdm
from ..core.device.npu_compatible_device import get_device_type
@@ -16,6 +17,7 @@ from ..models.ace_step_dit import AceStepDiTModel
from ..models.ace_step_conditioner import AceStepConditionEncoder
from ..models.ace_step_text_encoder import AceStepTextEncoder
from ..models.ace_step_vae import AceStepVAE
from ..models.ace_step_tokenizer import AceStepTokenizer
class AceStepPipeline(BasePipeline):
@@ -32,29 +34,18 @@ class AceStepPipeline(BasePipeline):
self.text_encoder: AceStepTextEncoder = None
self.conditioner: AceStepConditionEncoder = None
self.dit: AceStepDiTModel = None
self.vae = None # AutoencoderOobleck (diffusers) or AceStepVAE
self.vae: AceStepVAE = None
self.tokenizer_model: AceStepTokenizer = None # AceStepTokenizer (tokenizer + detokenizer)
# Unit chain order — 7 units total
#
# 1. ShapeChecker: duration → seq_len
# 2. PromptEmbedder: prompt/lyrics → text/lyric embeddings (shared for CFG)
# 3. SilenceLatentInitializer: seq_len → src_latents + chunk_masks
# 4. ContextLatentBuilder: src_latents + chunk_masks → context_latents (shared, same for CFG+)
# 5. ConditionEmbedder: text/lyric → encoder_hidden_states (separate for CFG+/-)
# 6. NoiseInitializer: context_latents → noise
# 7. InputAudioEmbedder: noise → latents
#
# ContextLatentBuilder runs before ConditionEmbedder so that
# context_latents is available for noise shape computation.
self.in_iteration_models = ("dit",)
self.units = [
AceStepUnit_ShapeChecker(),
AceStepUnit_PromptEmbedder(),
AceStepUnit_SilenceLatentInitializer(),
AceStepUnit_ContextLatentBuilder(),
AceStepUnit_ReferenceAudioEmbedder(),
AceStepUnit_ConditionEmbedder(),
AceStepUnit_ContextLatentBuilder(),
AceStepUnit_NoiseInitializer(),
AceStepUnit_InputAudioEmbedder(),
AceStepUnit_AudioCodeDecoder(),
]
self.model_fn = model_fn_ace_step
self.compilable_models = ["dit"]
@@ -66,7 +57,8 @@ class AceStepPipeline(BasePipeline):
torch_dtype: torch.dtype = torch.bfloat16,
device: str = get_device_type(),
model_configs: list[ModelConfig] = [],
text_tokenizer_config: ModelConfig = None,
text_tokenizer_config: ModelConfig = ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config: ModelConfig = ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
vram_limit: float = None,
):
"""Load pipeline from pretrained checkpoints."""
@@ -77,11 +69,15 @@ class AceStepPipeline(BasePipeline):
pipe.conditioner = model_pool.fetch_model("ace_step_conditioner")
pipe.dit = model_pool.fetch_model("ace_step_dit")
pipe.vae = model_pool.fetch_model("ace_step_vae")
pipe.tokenizer_model = model_pool.fetch_model("ace_step_tokenizer")
if text_tokenizer_config is not None:
text_tokenizer_config.download_if_necessary()
from transformers import AutoTokenizer
pipe.tokenizer = AutoTokenizer.from_pretrained(text_tokenizer_config.path)
if silence_latent_config is not None:
silence_latent_config.download_if_necessary()
pipe.silence_latent = torch.load(silence_latent_config.path, weights_only=True).transpose(1, 2).to(dtype=pipe.torch_dtype, device=pipe.device)
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
@@ -97,9 +93,19 @@ class AceStepPipeline(BasePipeline):
# Lyrics
lyrics: str = "",
# Reference audio (optional, for timbre conditioning)
reference_audio = None,
reference_audios: List[torch.Tensor] = None,
# Src audio
src_audio: torch.Tensor = None,
denoising_strength: float = 1.0,
# Simple Mode: LLM-generated audio codes (optional)
audio_codes: str = None,
# Shape
duration: float = 60.0,
duration: int = 60,
# Audio Meta
bpm: Optional[int] = 100,
keyscale: Optional[str] = "B minor",
timesignature: Optional[str] = "4",
vocal_language: Optional[str] = 'zh',
# Randomness
seed: int = None,
rand_device: str = "cpu",
@@ -111,11 +117,7 @@ class AceStepPipeline(BasePipeline):
progress_bar_cmd=tqdm,
):
# 1. Scheduler
self.scheduler.set_timesteps(
num_inference_steps=num_inference_steps,
denoising_strength=1.0,
shift=shift,
)
self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, denoising_strength=1.0, shift=shift)
# 2. 三字典输入
inputs_posi = {"prompt": prompt}
@@ -123,8 +125,11 @@ class AceStepPipeline(BasePipeline):
inputs_shared = {
"cfg_scale": cfg_scale,
"lyrics": lyrics,
"reference_audio": reference_audio,
"reference_audios": reference_audios,
"src_audio": src_audio,
"audio_codes": audio_codes,
"duration": duration,
"bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "vocal_language": vocal_language,
"seed": seed,
"rand_device": rand_device,
"num_inference_steps": num_inference_steps,
@@ -159,6 +164,10 @@ class AceStepPipeline(BasePipeline):
# VAE returns OobleckDecoderOutput with .sample attribute
audio_output = vae_output.sample if hasattr(vae_output, 'sample') else vae_output
audio = self.output_audio_format_check(audio_output)
# Peak normalization to match target library behavior
audio = self.normalize_audio(audio, target_db=-1.0)
self.load_models_to_device([])
return audio
@@ -172,294 +181,303 @@ class AceStepPipeline(BasePipeline):
audio_output = audio_output.squeeze(0)
return audio_output.float()
def normalize_audio(self, audio: torch.Tensor, target_db: float = -1.0) -> torch.Tensor:
"""Apply peak normalization to audio data, matching target library behavior.
class AceStepUnit_ShapeChecker(PipelineUnit):
"""Check and compute sequence length from duration."""
def __init__(self):
super().__init__(
input_params=("duration",),
output_params=("duration", "seq_len"),
)
Target library reference: `acestep/audio_utils.py:normalize_audio()`
peak = max(abs(audio))
gain = 10^(target_db/20) / peak
audio = audio * gain
def process(self, pipe, duration):
# ACE-Step: 25 Hz latent rate
seq_len = int(duration * 25)
return {"duration": duration, "seq_len": seq_len}
Args:
audio: Audio tensor [C, T]
target_db: Target peak level in dB (default: -1.0)
Returns:
Normalized audio tensor
"""
peak = torch.max(torch.abs(audio))
if peak < 1e-6:
return audio
target_amp = 10 ** (target_db / 20.0)
gain = target_amp / peak
return audio * gain
class AceStepUnit_PromptEmbedder(PipelineUnit):
"""Encode prompt and lyrics using Qwen3-Embedding.
SFT_GEN_PROMPT = "# Instruction\n{}\n\n# Caption\n{}\n\n# Metas\n{}<|endoftext|>\n"
INSTRUCTION = "Fill the audio semantic mask based on the given conditions:"
LYRIC_PROMPT = "# Languages\n{}\n\n# Lyric\n{}<|endoftext|>"
Uses seperate_cfg=True to read prompt from inputs_posi (not inputs_shared).
The negative condition uses null_condition_emb (handled by ConditionEmbedder),
so negative text encoding is not needed here.
"""
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
input_params_nega={},
input_params=("lyrics",),
input_params_nega={"prompt": "prompt"},
input_params=("lyrics", "duration", "bpm", "keyscale", "timesignature", "vocal_language"),
output_params=("text_hidden_states", "text_attention_mask", "lyric_hidden_states", "lyric_attention_mask"),
onload_model_names=("text_encoder",)
)
def _encode_text(self, pipe, text):
def _encode_text(self, pipe, text, max_length=256):
"""Encode text using Qwen3-Embedding → [B, T, 1024]."""
if pipe.tokenizer is None:
return None, None
text_inputs = pipe.tokenizer(
text,
padding="max_length",
max_length=512,
max_length=max_length,
truncation=True,
return_tensors="pt",
)
input_ids = text_inputs.input_ids.to(pipe.device)
attention_mask = text_inputs.attention_mask.to(pipe.device)
attention_mask = text_inputs.attention_mask.bool().to(pipe.device)
hidden_states = pipe.text_encoder(input_ids, attention_mask)
return hidden_states, attention_mask
def process(self, pipe, prompt, lyrics, negative_prompt=None):
def _encode_lyrics(self, pipe, lyric_text, max_length=2048):
text_inputs = pipe.tokenizer(
lyric_text,
max_length=max_length,
truncation=True,
return_tensors="pt",
)
input_ids = text_inputs.input_ids.to(pipe.device)
attention_mask = text_inputs.attention_mask.bool().to(pipe.device)
hidden_states = pipe.text_encoder.model.embed_tokens(input_ids)
return hidden_states, attention_mask
def _dict_to_meta_string(self, meta_dict: Dict[str, Any]) -> str:
bpm = meta_dict.get("bpm", "N/A")
timesignature = meta_dict.get("timesignature", "N/A")
keyscale = meta_dict.get("keyscale", "N/A")
duration = meta_dict.get("duration", 30)
duration = f"{int(duration)} seconds"
return (
f"- bpm: {bpm}\n"
f"- timesignature: {timesignature}\n"
f"- keyscale: {keyscale}\n"
f"- duration: {duration}\n"
)
def process(self, pipe, prompt, lyrics, duration, bpm, keyscale, timesignature, vocal_language):
pipe.load_models_to_device(['text_encoder'])
meta_dict = {"bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "duration": duration}
prompt = self.SFT_GEN_PROMPT.format(self.INSTRUCTION, prompt, self._dict_to_meta_string(meta_dict))
text_hidden_states, text_attention_mask = self._encode_text(pipe, prompt, max_length=256)
text_hidden_states, text_attention_mask = self._encode_text(pipe, prompt)
lyric_text = self.LYRIC_PROMPT.format(vocal_language, lyrics)
lyric_hidden_states, lyric_attention_mask = self._encode_lyrics(pipe, lyric_text, max_length=2048)
# Lyrics encoding — use empty string if not provided
lyric_text = lyrics if lyrics else ""
lyric_hidden_states, lyric_attention_mask = self._encode_text(pipe, lyric_text)
if text_hidden_states is not None and lyric_hidden_states is not None:
return {
"text_hidden_states": text_hidden_states,
"text_attention_mask": text_attention_mask,
"lyric_hidden_states": lyric_hidden_states,
"lyric_attention_mask": lyric_attention_mask,
}
return {}
# TODO: remove this
newtext = prompt + "\n\n" + lyric_text
return {
"text_hidden_states": text_hidden_states,
"text_attention_mask": text_attention_mask,
"lyric_hidden_states": lyric_hidden_states,
"lyric_attention_mask": lyric_attention_mask,
}
class AceStepUnit_SilenceLatentInitializer(PipelineUnit):
"""Generate silence latent (all zeros) and chunk_masks for text2music.
Target library reference: `prepare_condition()` line 1698-1699:
context_latents = torch.cat([src_latents, chunk_masks.to(dtype)], dim=-1)
For text2music mode:
- src_latents = zeros [B, T, 64] (VAE latent dimension)
- chunk_masks = ones [B, T, 64] (full visibility mask for text2music)
- context_latents = [B, T, 128] (concat of src_latents + chunk_masks)
"""
class AceStepUnit_ReferenceAudioEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("seq_len",),
output_params=("silence_latent", "src_latents", "chunk_masks"),
input_params=("reference_audios",),
output_params=("reference_latents", "refer_audio_order_mask"),
onload_model_names=("vae",)
)
def process(self, pipe, seq_len):
# silence_latent shape: [B, T, 64] — 64 is the VAE latent dimension
silence_latent = torch.zeros(1, seq_len, 64, device=pipe.device, dtype=pipe.torch_dtype)
# For text2music: src_latents = silence_latent
src_latents = silence_latent.clone()
def process(self, pipe, reference_audios):
pipe.load_models_to_device(['vae'])
if reference_audios is not None and len(reference_audios) > 0:
# TODO: implement reference audio embedding using VAE encode, and generate refer_audio_order_mask
pass
else:
reference_audios = [[torch.zeros(2, 30 * pipe.vae.sampling_rate).to(dtype=pipe.torch_dtype, device=pipe.device)]]
reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, reference_audios)
return {"reference_latents": reference_latents, "refer_audio_order_mask": refer_audio_order_mask}
# chunk_masks: [B, T, 64] of ones (same shape as src_latents)
# In text2music mode (is_covers=0), chunk_masks are all 1.0
# This matches the target library's behavior at line 1699
chunk_masks = torch.ones(1, seq_len, 64, device=pipe.device, dtype=pipe.torch_dtype)
def infer_refer_latent(self, pipe, refer_audioss: List[List[torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
"""Infer packed reference-audio latents and order mask."""
refer_audio_order_mask = []
refer_audio_latents = []
return {"silence_latent": silence_latent, "src_latents": src_latents, "chunk_masks": chunk_masks}
def _normalize_audio_2d(a: torch.Tensor) -> torch.Tensor:
if not isinstance(a, torch.Tensor):
raise TypeError(f"refer_audio must be a torch.Tensor, got {type(a)!r}")
if a.dim() == 3 and a.shape[0] == 1:
a = a.squeeze(0)
if a.dim() == 1:
a = a.unsqueeze(0)
if a.dim() != 2:
raise ValueError(f"refer_audio must be 1D/2D/3D(1,2,T); got shape={tuple(a.shape)}")
if a.shape[0] == 1:
a = torch.cat([a, a], dim=0)
return a[:2]
def _ensure_latent_3d(z: torch.Tensor) -> torch.Tensor:
if z.dim() == 4 and z.shape[0] == 1:
z = z.squeeze(0)
if z.dim() == 2:
z = z.unsqueeze(0)
return z
class AceStepUnit_ContextLatentBuilder(PipelineUnit):
"""Build context_latents from src_latents and chunk_masks.
refer_encode_cache: Dict[int, torch.Tensor] = {}
for batch_idx, refer_audios in enumerate(refer_audioss):
if len(refer_audios) == 1 and torch.all(refer_audios[0] == 0.0):
refer_audio_latent = _ensure_latent_3d(pipe.silence_latent[:, :750, :])
refer_audio_latents.append(refer_audio_latent)
refer_audio_order_mask.append(batch_idx)
else:
# TODO: check
for refer_audio in refer_audios:
cache_key = refer_audio.data_ptr()
if cache_key in refer_encode_cache:
refer_audio_latent = refer_encode_cache[cache_key].clone()
else:
refer_audio = _normalize_audio_2d(refer_audio)
refer_audio_latent = pipe.vae.encode(refer_audio)
refer_audio_latent = refer_audio_latent.to(dtype=pipe.torch_dtype, device=pipe.device)
if refer_audio_latent.dim() == 2:
refer_audio_latent = refer_audio_latent.unsqueeze(0)
refer_audio_latent = _ensure_latent_3d(refer_audio_latent.transpose(1, 2))
refer_encode_cache[cache_key] = refer_audio_latent
refer_audio_latents.append(refer_audio_latent)
refer_audio_order_mask.append(batch_idx)
Target library reference: `prepare_condition()` line 1699:
context_latents = torch.cat([src_latents, chunk_masks.to(dtype)], dim=-1)
context_latents is the SAME for positive and negative CFG paths
(it comes from src_latents + chunk_masks, not from text encoding).
So this is a普通模式 Unit — outputs go to inputs_shared.
"""
def __init__(self):
super().__init__(
input_params=("src_latents", "chunk_masks"),
output_params=("context_latents", "attention_mask"),
)
def process(self, pipe, src_latents, chunk_masks):
# context_latents: cat([src_latents, chunk_masks], dim=-1) → [B, T, 128]
context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
# attention_mask for the DiT: ones [B, T]
# The target library uses this for cross-attention with context_latents
attention_mask = torch.ones(src_latents.shape[0], src_latents.shape[1],
device=pipe.device, dtype=pipe.torch_dtype)
return {"context_latents": context_latents, "attention_mask": attention_mask}
refer_audio_latents = torch.cat(refer_audio_latents, dim=0)
refer_audio_order_mask = torch.tensor(refer_audio_order_mask, device=pipe.device, dtype=torch.long)
return refer_audio_latents, refer_audio_order_mask
class AceStepUnit_ConditionEmbedder(PipelineUnit):
"""Generate encoder_hidden_states via ACEStepConditioner.
Target library reference: `prepare_condition()` line 1674-1681:
encoder_hidden_states, encoder_attention_mask = self.encoder(...)
Uses seperate_cfg mode:
- Positive: encode with full condition (text + lyrics + reference audio)
- Negative: replace text with null_condition_emb, keep lyrics/timbre same
context_latents is handled by ContextLatentBuilder (普通模式), not here.
"""
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={
"text_hidden_states": "text_hidden_states",
"text_attention_mask": "text_attention_mask",
"lyric_hidden_states": "lyric_hidden_states",
"lyric_attention_mask": "lyric_attention_mask",
"reference_audio": "reference_audio",
"refer_audio_order_mask": "refer_audio_order_mask",
},
input_params_nega={},
input_params=("cfg_scale",),
output_params=(
"encoder_hidden_states", "encoder_attention_mask",
"negative_encoder_hidden_states", "negative_encoder_attention_mask",
),
onload_model_names=("conditioner",)
input_params_posi={"text_hidden_states": "text_hidden_states", "text_attention_mask": "text_attention_mask", "lyric_hidden_states": "lyric_hidden_states", "lyric_attention_mask": "lyric_attention_mask"},
input_params_nega={"text_hidden_states": "text_hidden_states", "text_attention_mask": "text_attention_mask", "lyric_hidden_states": "lyric_hidden_states", "lyric_attention_mask": "lyric_attention_mask"},
input_params=("reference_latents", "refer_audio_order_mask"),
output_params=("encoder_hidden_states", "encoder_attention_mask"),
onload_model_names=("conditioner",),
)
def _prepare_condition(self, pipe, text_hidden_states, text_attention_mask,
lyric_hidden_states, lyric_attention_mask,
refer_audio_acoustic_hidden_states_packed=None,
refer_audio_order_mask=None):
"""Call ACEStepConditioner forward to produce encoder_hidden_states."""
def process(self, pipe, text_hidden_states, text_attention_mask, lyric_hidden_states, lyric_attention_mask, reference_latents, refer_audio_order_mask):
pipe.load_models_to_device(['conditioner'])
# Handle reference audio
if refer_audio_acoustic_hidden_states_packed is None:
# No reference audio: create 2D packed zeros [N=1, d=64]
# TimbreEncoder.unpack expects [N, d], not [B, T, d]
refer_audio_acoustic_hidden_states_packed = torch.zeros(
1, 64, device=pipe.device, dtype=pipe.torch_dtype
)
refer_audio_order_mask = torch.LongTensor([0]).to(pipe.device)
encoder_hidden_states, encoder_attention_mask = pipe.conditioner(
text_hidden_states=text_hidden_states,
text_attention_mask=text_attention_mask,
lyric_hidden_states=lyric_hidden_states,
lyric_attention_mask=lyric_attention_mask,
refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed,
reference_latents=reference_latents,
refer_audio_order_mask=refer_audio_order_mask,
)
return {"encoder_hidden_states": encoder_hidden_states, "encoder_attention_mask": encoder_attention_mask}
return encoder_hidden_states, encoder_attention_mask
def _prepare_negative_condition(self, pipe, lyric_hidden_states, lyric_attention_mask,
refer_audio_acoustic_hidden_states_packed=None,
refer_audio_order_mask=None):
"""Generate negative condition using null_condition_emb."""
if pipe.conditioner is None or not hasattr(pipe.conditioner, 'null_condition_emb'):
return None, None
null_emb = pipe.conditioner.null_condition_emb # [1, 1, hidden_size]
bsz = 1
if lyric_hidden_states is not None:
bsz = lyric_hidden_states.shape[0]
null_hidden_states = null_emb.expand(bsz, -1, -1)
null_attn_mask = torch.ones(bsz, 1, device=pipe.device, dtype=pipe.torch_dtype)
# For negative: use null_condition_emb as text, keep lyrics and timbre
neg_encoder_hidden_states, neg_encoder_attention_mask = pipe.conditioner(
text_hidden_states=null_hidden_states,
text_attention_mask=null_attn_mask,
lyric_hidden_states=lyric_hidden_states,
lyric_attention_mask=lyric_attention_mask,
refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed,
refer_audio_order_mask=refer_audio_order_mask,
class AceStepUnit_ContextLatentBuilder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("duration", "src_audio"),
output_params=("context_latents", "src_latents", "chunk_masks", "attention_mask"),
)
return neg_encoder_hidden_states, neg_encoder_attention_mask
def _get_silence_latent_slice(self, pipe, length: int) -> torch.Tensor:
available = pipe.silence_latent.shape[1]
if length <= available:
return pipe.silence_latent[0, :length, :]
repeats = (length + available - 1) // available
tiled = pipe.silence_latent[0].repeat(repeats, 1)
return tiled[:length, :]
def process(self, pipe, text_hidden_states, text_attention_mask,
lyric_hidden_states, lyric_attention_mask,
reference_audio=None, refer_audio_order_mask=None,
negative_prompt=None, cfg_scale=1.0):
# Positive condition
pos_enc_hs, pos_enc_mask = self._prepare_condition(
pipe, text_hidden_states, text_attention_mask,
lyric_hidden_states, lyric_attention_mask,
None, refer_audio_order_mask,
)
# Negative condition: only needed when CFG is active (cfg_scale > 1.0)
# For cfg_scale=1.0 (turbo), skip to avoid null_condition_emb dimension mismatch
result = {
"encoder_hidden_states": pos_enc_hs,
"encoder_attention_mask": pos_enc_mask,
}
if cfg_scale > 1.0:
neg_enc_hs, neg_enc_mask = self._prepare_negative_condition(
pipe, lyric_hidden_states, lyric_attention_mask,
None, refer_audio_order_mask,
)
if neg_enc_hs is not None:
result["negative_encoder_hidden_states"] = neg_enc_hs
result["negative_encoder_attention_mask"] = neg_enc_mask
return result
def process(self, pipe, duration, src_audio):
if src_audio is not None:
raise NotImplementedError("Src audio conditioning is not implemented yet. Please set src_audio to None.")
else:
max_latent_length = duration * pipe.sample_rate // 1920
src_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
chunk_masks = torch.ones((1, max_latent_length, src_latents.shape[-1]), dtype=torch.bool, device=pipe.device)
attention_mask = torch.ones((1, max_latent_length), device=src_latents.device, dtype=pipe.torch_dtype)
context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
return {"context_latents": context_latents, "attention_mask": attention_mask}
class AceStepUnit_NoiseInitializer(PipelineUnit):
"""Generate initial noise tensor.
Target library reference: `prepare_noise()` line 1781-1818:
src_latents_shape = (bsz, context_latents.shape[1], context_latents.shape[-1] // 2)
Noise shape = [B, T, context_latents.shape[-1] // 2] = [B, T, 128 // 2] = [B, T, 64]
"""
def __init__(self):
super().__init__(
input_params=("seed", "seq_len", "rand_device", "context_latents"),
input_params=("context_latents", "seed", "rand_device"),
output_params=("noise",),
)
def process(self, pipe, seed, seq_len, rand_device, context_latents):
# Noise shape: [B, T, context_latents.shape[-1] // 2]
# context_latents = [B, T, 128] → noise = [B, T, 64]
# This matches the target library's prepare_noise() at line 1796
noise_shape = (context_latents.shape[0], context_latents.shape[1],
context_latents.shape[-1] // 2)
noise = pipe.generate_noise(
noise_shape,
seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype
)
def process(self, pipe, context_latents, seed, rand_device):
src_latents_shape = (context_latents.shape[0], context_latents.shape[1], context_latents.shape[-1] // 2)
noise = pipe.generate_noise(src_latents_shape, seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
return {"noise": noise}
class AceStepUnit_InputAudioEmbedder(PipelineUnit):
"""Set up latents for denoise loop.
For text2music (no input audio): latents = noise, input_latents = None.
Target library reference: `generate_audio()` line 1972:
xt = noise (when cover_noise_strength == 0)
"""
def __init__(self):
super().__init__(
input_params=("noise",),
input_params=("noise", "input_audio"),
output_params=("latents", "input_latents"),
)
def process(self, pipe, noise):
# For text2music: start from pure noise
def process(self, pipe, noise, input_audio):
if input_audio is None:
return {"latents": noise}
# TODO: support for train
return {"latents": noise, "input_latents": None}
class AceStepUnit_AudioCodeDecoder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("audio_codes", "seq_len", "silence_latent"),
output_params=("lm_hints_25Hz",),
onload_model_names=("tokenizer_model",),
)
@staticmethod
def _parse_audio_code_string(code_str: str) -> list:
"""Extract integer audio codes from tokens like <|audio_code_123|>."""
if not code_str:
return []
codes = []
max_audio_code = 63999
for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str):
code_value = int(x)
codes.append(max(0, min(code_value, max_audio_code)))
return codes
def process(self, pipe, audio_codes, seq_len, silence_latent):
if audio_codes is None or not audio_codes.strip():
return {"lm_hints_25Hz": None}
code_ids = self._parse_audio_code_string(audio_codes)
if len(code_ids) == 0:
return {"lm_hints_25Hz": None}
pipe.load_models_to_device(["tokenizer_model"])
quantizer = pipe.tokenizer_model.tokenizer.quantizer
detokenizer = pipe.tokenizer_model.detokenizer
indices = torch.tensor(code_ids, device=pipe.device, dtype=torch.long)
indices = indices.unsqueeze(0).unsqueeze(-1) # [1, N, 1]
quantized = quantizer.get_output_from_indices(indices) # [1, N, 2048]
if quantized.dtype != pipe.torch_dtype:
quantized = quantized.to(pipe.torch_dtype)
lm_hints = detokenizer(quantized) # [1, N*5, 64]
# Pad or truncate to seq_len
current_len = lm_hints.shape[1]
if current_len < seq_len:
pad_len = seq_len - current_len
pad = silence_latent[:, :pad_len, :]
lm_hints = torch.cat([lm_hints, pad], dim=1)
elif current_len > seq_len:
lm_hints = lm_hints[:, :seq_len, :]
return {"lm_hints_25Hz": lm_hints}
def model_fn_ace_step(
dit: AceStepDiTModel,
latents=None,
@@ -468,49 +486,11 @@ def model_fn_ace_step(
encoder_attention_mask=None,
context_latents=None,
attention_mask=None,
past_key_values=None,
negative_encoder_hidden_states=None,
negative_encoder_attention_mask=None,
negative_context_latents=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
"""Model function for ACE-Step DiT forward.
Timestep is already in [0, 1] range — no scaling needed.
Target library reference: `generate_audio()` line 2009-2020:
decoder_outputs = self.decoder(
hidden_states=x, timestep=t_curr_tensor, timestep_r=t_curr_tensor,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
context_latents=context_latents,
use_cache=True, past_key_values=past_key_values,
)
Args:
dit: AceStepDiTModel
latents: [B, T, 64] noise/latent tensor (same shape as src_latents)
timestep: scalar tensor in [0, 1]
encoder_hidden_states: [B, T_text, 2048] condition from Conditioner
(positive or negative depending on CFG pass — the cfg_guided_model_fn
passes inputs_posi for positive, inputs_nega for negative)
encoder_attention_mask: [B, T_text]
context_latents: [B, T, 128] = cat([src_latents, chunk_masks], dim=-1)
(same for both CFG+/- paths in text2music mode)
attention_mask: [B, T] ones mask for DiT
past_key_values: EncoderDecoderCache for KV caching
The DiT internally concatenates: cat([context_latents, latents], dim=-1) = [B, T, 192]
as the actual input (128 + 64 = 192 channels).
"""
# ACE-Step uses timestep directly in [0, 1] range — no /1000 scaling
timestep = timestep.squeeze()
# Expand timestep to match batch size
bsz = latents.shape[0]
timestep = timestep.expand(bsz)
timestep = timestep.unsqueeze(0)
decoder_outputs = dit(
hidden_states=latents,
timestep=timestep,
@@ -519,9 +499,5 @@ def model_fn_ace_step(
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
context_latents=context_latents,
use_cache=True,
past_key_values=past_key_values,
)
# Return velocity prediction (first element of decoder_outputs)
return decoder_outputs[0]
)[0]
return decoder_outputs

View File

@@ -99,6 +99,7 @@ def save_audio(waveform: torch.Tensor, sample_rate: int, save_path: str, backend
"""
if waveform.dim() == 3:
waveform = waveform[0]
waveform.cpu()
if backend == "torchcodec":
from torchcodec.encoders import AudioEncoder

View File

@@ -2,8 +2,8 @@
Ace-Step 1.5 — Text-to-Music with Simple Mode (LLM expansion).
Uses the ACE-Step LLM to expand a simple description into structured
parameters (caption, lyrics, bpm, keyscale, etc.), then feeds them
to the DiffSynth Pipeline.
parameters (caption, lyrics, bpm, keyscale, etc.) AND audio codes,
then feeds them to the DiffSynth Pipeline.
The LLM expansion uses the target library's LLMHandler. If vLLM is
not available, it falls back to using pre-structured parameters.
@@ -47,11 +47,14 @@ def try_load_llm_handler(checkpoint_dir: str, lm_model_path: str = "acestep-5Hz-
def expand_with_llm(llm_handler, description: str, duration: float = 30.0):
"""Expand a simple description using LLM Chain-of-Thought."""
"""Expand a simple description using LLM Chain-of-Thought.
Returns (params_dict, audio_codes_string).
"""
result = llm_handler.generate_with_stop_condition(
caption=description,
lyrics="",
infer_type="dit", # metadata only
infer_type="dit", # metadata + audio codes
temperature=0.85,
cfg_scale=1.0,
use_cot_metas=True,
@@ -62,7 +65,7 @@ def expand_with_llm(llm_handler, description: str, duration: float = 30.0):
if result.get("success") and result.get("metadata"):
meta = result["metadata"]
return {
params = {
"caption": meta.get("caption", description),
"lyrics": meta.get("lyrics", ""),
"bpm": meta.get("bpm", 100),
@@ -71,9 +74,11 @@ def expand_with_llm(llm_handler, description: str, duration: float = 30.0):
"timesignature": meta.get("timesignature", "4"),
"duration": meta.get("duration", duration),
}
audio_codes = result.get("audio_codes", "")
return params, audio_codes
print(f"[Simple Mode] LLM expansion failed: {result.get('error', 'unknown')}")
return None
return None, ""
def fallback_expand(description: str, duration: float = 30.0):
@@ -87,7 +92,7 @@ def fallback_expand(description: str, duration: float = 30.0):
"language": "en",
"timesignature": "4",
"duration": duration,
}
}, ""
# ---------------------------------------------------------------------------
@@ -114,13 +119,13 @@ def main():
lm_model_path="acestep-5Hz-lm-1.7B",
)
# 2. Expand parameters
# 2. Expand parameters + audio codes
if llm_ok:
params = expand_with_llm(llm_handler, description, duration=duration)
params, audio_codes = expand_with_llm(llm_handler, description, duration=duration)
if params is None:
params = fallback_expand(description, duration)
params, audio_codes = fallback_expand(description, duration)
else:
params = fallback_expand(description, duration)
params, audio_codes = fallback_expand(description, duration)
print(f"\n[Simple Mode] Parameters:")
print(f" Caption: {params['caption'][:100]}...")
@@ -128,6 +133,7 @@ def main():
print(f" BPM: {params['bpm']}, Keyscale: {params['keyscale']}")
print(f" Language: {params['language']}, Time Sig: {params['timesignature']}")
print(f" Duration: {params['duration']}s")
print(f" Audio codes: {len(audio_codes)} chars" if audio_codes else " Audio codes: None (fallback)")
# 3. Load Pipeline
print(f"\n[Pipeline] Loading Ace-Step 1.5 (turbo)...")
@@ -141,21 +147,17 @@ def main():
),
ModelConfig(
model_id="ACE-Step/Ace-Step1.5",
origin_file_pattern="acestep-v15-turbo/model.safetensors"
origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"
),
ModelConfig(
model_id="ACE-Step/Ace-Step1.5",
origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"
origin_file_pattern="vae/diffusion_pytorch_model.safetensors"
),
],
tokenizer_config=ModelConfig(
text_tokenizer_config=ModelConfig(
model_id="ACE-Step/Ace-Step1.5",
origin_file_pattern="Qwen3-Embedding-0.6B/"
),
vae_config=ModelConfig(
model_id="ACE-Step/Ace-Step1.5",
origin_file_pattern="vae/"
),
)
# 4. Generate
@@ -164,6 +166,7 @@ def main():
prompt=params["caption"],
lyrics=params["lyrics"],
duration=params["duration"],
audio_codes=audio_codes if audio_codes else None,
seed=42,
num_inference_steps=8,
cfg_scale=1.0,

View File

@@ -1,16 +1,6 @@
"""
Ace-Step 1.5 — Text-to-Music (Turbo) inference example.
Demonstrates the standard text2music pipeline with structured parameters
(caption, lyrics, duration, etc.) — no LLM expansion needed.
For Simple Mode (LLM expands a short description), see:
- Ace-Step1.5-SimpleMode.py
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
import torch
import soundfile as sf
from diffsynth.utils.data.audio import save_audio
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
@@ -35,29 +25,21 @@ pipe = AceStepPipeline.from_pretrained(
),
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline."
lyrics = """[Intro - Synth Brass Fanfare]
[Verse 1]
黑夜里的风吹过耳畔
甜蜜时光转瞬即逝
脚步飘摇在星光上
[Chorus]
心电感应在震动间
拥抱未来勇敢冒险
[Outro - Instrumental]"""
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=30.0,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
num_inference_steps=8,
cfg_scale=1.0,
shift=3.0,
)
sf.write("Ace-Step1.5.wav", audio.cpu().numpy(), pipe.sample_rate)
save_audio(audio.cpu(), pipe.vae.sampling_rate, "Ace-Step1.5.wav")
print(f"Saved to Ace-Step1.5.wav, shape: {audio.shape}, duration: {audio.shape[-1] / pipe.sample_rate:.1f}s")