acestep t2m

This commit is contained in:
mi804
2026-04-21 13:16:15 +08:00
parent a604d76339
commit 9d09e0431c
9 changed files with 300 additions and 377 deletions

View File

@@ -157,10 +157,10 @@ class FlowMatchScheduler():
""" """
num_train_timesteps = 1000 num_train_timesteps = 1000
sigma_start = denoising_strength sigma_start = denoising_strength
sigmas = torch.linspace(sigma_start, 0.0, num_inference_steps) sigmas = torch.linspace(sigma_start, 0.0, num_inference_steps + 1)[:-1]
if shift is not None and shift != 1.0: if shift is not None and shift != 1.0:
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
timesteps = sigmas # ACE-Step uses [0, 1] range directly timesteps = sigmas * num_train_timesteps
return sigmas, timesteps return sigmas, timesteps
@staticmethod @staticmethod

View File

@@ -540,17 +540,9 @@ class AceStepTimbreEncoder(nn.Module):
) -> BaseModelOutput: ) -> BaseModelOutput:
inputs_embeds = refer_audio_acoustic_hidden_states_packed inputs_embeds = refer_audio_acoustic_hidden_states_packed
inputs_embeds = self.embed_tokens(inputs_embeds) inputs_embeds = self.embed_tokens(inputs_embeds)
# Handle 2D (packed) or 3D (batched) input seq_len = inputs_embeds.shape[1]
is_packed = inputs_embeds.dim() == 2 cache_position = torch.arange(0, seq_len, device=inputs_embeds.device)
if is_packed: position_ids = cache_position.unsqueeze(0)
seq_len = inputs_embeds.shape[0]
cache_position = torch.arange(0, seq_len, device=inputs_embeds.device)
position_ids = cache_position.unsqueeze(0)
inputs_embeds = inputs_embeds.unsqueeze(0)
else:
seq_len = inputs_embeds.shape[1]
cache_position = torch.arange(0, seq_len, device=inputs_embeds.device)
position_ids = cache_position.unsqueeze(0)
dtype = inputs_embeds.dtype dtype = inputs_embeds.dtype
device = inputs_embeds.device device = inputs_embeds.device
@@ -586,9 +578,8 @@ class AceStepTimbreEncoder(nn.Module):
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
hidden_states = hidden_states[:, 0, :]
# For packed input: reshape [1, T, D] -> [T, D] for unpacking # For packed input: reshape [1, T, D] -> [T, D] for unpacking
if is_packed:
hidden_states = hidden_states.squeeze(0)
timbre_embs_unpack, timbre_embs_mask = self.unpack_timbre_embeddings(hidden_states, refer_audio_order_mask) timbre_embs_unpack, timbre_embs_mask = self.unpack_timbre_embeddings(hidden_states, refer_audio_order_mask)
return timbre_embs_unpack, timbre_embs_mask return timbre_embs_unpack, timbre_embs_mask
@@ -686,7 +677,7 @@ class AceStepConditionEncoder(nn.Module):
text_attention_mask: Optional[torch.Tensor] = None, text_attention_mask: Optional[torch.Tensor] = None,
lyric_hidden_states: Optional[torch.LongTensor] = None, lyric_hidden_states: Optional[torch.LongTensor] = None,
lyric_attention_mask: Optional[torch.Tensor] = None, lyric_attention_mask: Optional[torch.Tensor] = None,
refer_audio_acoustic_hidden_states_packed: Optional[torch.Tensor] = None, reference_latents: Optional[torch.Tensor] = None,
refer_audio_order_mask: Optional[torch.LongTensor] = None, refer_audio_order_mask: Optional[torch.LongTensor] = None,
): ):
text_hidden_states = self.text_projector(text_hidden_states) text_hidden_states = self.text_projector(text_hidden_states)
@@ -695,11 +686,7 @@ class AceStepConditionEncoder(nn.Module):
attention_mask=lyric_attention_mask, attention_mask=lyric_attention_mask,
) )
lyric_hidden_states = lyric_encoder_outputs.last_hidden_state lyric_hidden_states = lyric_encoder_outputs.last_hidden_state
timbre_embs_unpack, timbre_embs_mask = self.timbre_encoder( timbre_embs_unpack, timbre_embs_mask = self.timbre_encoder(reference_latents, refer_audio_order_mask)
refer_audio_acoustic_hidden_states_packed,
refer_audio_order_mask
)
encoder_hidden_states, encoder_attention_mask = pack_sequences( encoder_hidden_states, encoder_attention_mask = pack_sequences(
lyric_hidden_states, timbre_embs_unpack, lyric_attention_mask, timbre_embs_mask lyric_hidden_states, timbre_embs_unpack, lyric_attention_mask, timbre_embs_mask
) )

View File

@@ -165,7 +165,7 @@ class TimestepEmbedding(nn.Module):
self, self,
in_channels: int, in_channels: int,
time_embed_dim: int, time_embed_dim: int,
scale: float = 1000, scale: float = 1,
): ):
super().__init__() super().__init__()
@@ -711,7 +711,7 @@ class AceStepDiTModel(nn.Module):
encoder_hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor,
encoder_attention_mask: torch.Tensor, encoder_attention_mask: torch.Tensor,
context_latents: torch.Tensor, context_latents: torch.Tensor,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = False,
past_key_values: Optional[EncoderDecoderCache] = None, past_key_values: Optional[EncoderDecoderCache] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,

View File

@@ -2,17 +2,6 @@ import torch
class AceStepTextEncoder(torch.nn.Module): class AceStepTextEncoder(torch.nn.Module):
"""
Text encoder for ACE-Step using Qwen3-Embedding-0.6B.
Converts text/lyric tokens to hidden state embeddings that are
further processed by the ACE-Step ConditionEncoder.
Wraps a Qwen3Model transformers model. Config is manually
constructed, and model weights are loaded via DiffSynth's
standard mechanism from safetensors files.
"""
def __init__( def __init__(
self, self,
): ):
@@ -49,8 +38,6 @@ class AceStepTextEncoder(torch.nn.Module):
) )
self.model = Qwen3Model(config) self.model = Qwen3Model(config)
self.config = config
self.hidden_size = config.hidden_size
@torch.no_grad() @torch.no_grad()
def forward( def forward(
@@ -58,23 +45,9 @@ class AceStepTextEncoder(torch.nn.Module):
input_ids: torch.LongTensor, input_ids: torch.LongTensor,
attention_mask: torch.Tensor, attention_mask: torch.Tensor,
): ):
"""
Encode text/lyric tokens to hidden states.
Args:
input_ids: [B, T] token IDs
attention_mask: [B, T] attention mask
Returns:
last_hidden_state: [B, T, hidden_size]
"""
outputs = self.model( outputs = self.model(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
return_dict=True, return_dict=True,
) )
return outputs.last_hidden_state return outputs.last_hidden_state
def to(self, *args, **kwargs):
self.model.to(*args, **kwargs)
return self

View File

@@ -226,6 +226,7 @@ class AceStepVAE(nn.Module):
upsampling_ratios=upsampling_ratios, upsampling_ratios=upsampling_ratios,
channel_multiples=channel_multiples, channel_multiples=channel_multiples,
) )
self.sampling_rate = sampling_rate
def encode(self, x: torch.Tensor) -> torch.Tensor: def encode(self, x: torch.Tensor) -> torch.Tensor:
"""Audio waveform [B, audio_channels, T] → latent [B, encoder_hidden_size, T'].""" """Audio waveform [B, audio_channels, T] → latent [B, encoder_hidden_size, T']."""

View File

@@ -3,8 +3,9 @@ ACE-Step Pipeline for DiffSynth-Studio.
Text-to-Music generation pipeline using ACE-Step 1.5 model. Text-to-Music generation pipeline using ACE-Step 1.5 model.
""" """
import re
import torch import torch
from typing import Optional from typing import Optional, Dict, Any, List, Tuple
from tqdm import tqdm from tqdm import tqdm
from ..core.device.npu_compatible_device import get_device_type from ..core.device.npu_compatible_device import get_device_type
@@ -16,6 +17,7 @@ from ..models.ace_step_dit import AceStepDiTModel
from ..models.ace_step_conditioner import AceStepConditionEncoder from ..models.ace_step_conditioner import AceStepConditionEncoder
from ..models.ace_step_text_encoder import AceStepTextEncoder from ..models.ace_step_text_encoder import AceStepTextEncoder
from ..models.ace_step_vae import AceStepVAE from ..models.ace_step_vae import AceStepVAE
from ..models.ace_step_tokenizer import AceStepTokenizer
class AceStepPipeline(BasePipeline): class AceStepPipeline(BasePipeline):
@@ -32,29 +34,18 @@ class AceStepPipeline(BasePipeline):
self.text_encoder: AceStepTextEncoder = None self.text_encoder: AceStepTextEncoder = None
self.conditioner: AceStepConditionEncoder = None self.conditioner: AceStepConditionEncoder = None
self.dit: AceStepDiTModel = None self.dit: AceStepDiTModel = None
self.vae = None # AutoencoderOobleck (diffusers) or AceStepVAE self.vae: AceStepVAE = None
self.tokenizer_model: AceStepTokenizer = None # AceStepTokenizer (tokenizer + detokenizer)
# Unit chain order — 7 units total
#
# 1. ShapeChecker: duration → seq_len
# 2. PromptEmbedder: prompt/lyrics → text/lyric embeddings (shared for CFG)
# 3. SilenceLatentInitializer: seq_len → src_latents + chunk_masks
# 4. ContextLatentBuilder: src_latents + chunk_masks → context_latents (shared, same for CFG+)
# 5. ConditionEmbedder: text/lyric → encoder_hidden_states (separate for CFG+/-)
# 6. NoiseInitializer: context_latents → noise
# 7. InputAudioEmbedder: noise → latents
#
# ContextLatentBuilder runs before ConditionEmbedder so that
# context_latents is available for noise shape computation.
self.in_iteration_models = ("dit",) self.in_iteration_models = ("dit",)
self.units = [ self.units = [
AceStepUnit_ShapeChecker(),
AceStepUnit_PromptEmbedder(), AceStepUnit_PromptEmbedder(),
AceStepUnit_SilenceLatentInitializer(), AceStepUnit_ReferenceAudioEmbedder(),
AceStepUnit_ContextLatentBuilder(),
AceStepUnit_ConditionEmbedder(), AceStepUnit_ConditionEmbedder(),
AceStepUnit_ContextLatentBuilder(),
AceStepUnit_NoiseInitializer(), AceStepUnit_NoiseInitializer(),
AceStepUnit_InputAudioEmbedder(), AceStepUnit_InputAudioEmbedder(),
AceStepUnit_AudioCodeDecoder(),
] ]
self.model_fn = model_fn_ace_step self.model_fn = model_fn_ace_step
self.compilable_models = ["dit"] self.compilable_models = ["dit"]
@@ -66,7 +57,8 @@ class AceStepPipeline(BasePipeline):
torch_dtype: torch.dtype = torch.bfloat16, torch_dtype: torch.dtype = torch.bfloat16,
device: str = get_device_type(), device: str = get_device_type(),
model_configs: list[ModelConfig] = [], model_configs: list[ModelConfig] = [],
text_tokenizer_config: ModelConfig = None, text_tokenizer_config: ModelConfig = ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config: ModelConfig = ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
vram_limit: float = None, vram_limit: float = None,
): ):
"""Load pipeline from pretrained checkpoints.""" """Load pipeline from pretrained checkpoints."""
@@ -77,11 +69,15 @@ class AceStepPipeline(BasePipeline):
pipe.conditioner = model_pool.fetch_model("ace_step_conditioner") pipe.conditioner = model_pool.fetch_model("ace_step_conditioner")
pipe.dit = model_pool.fetch_model("ace_step_dit") pipe.dit = model_pool.fetch_model("ace_step_dit")
pipe.vae = model_pool.fetch_model("ace_step_vae") pipe.vae = model_pool.fetch_model("ace_step_vae")
pipe.tokenizer_model = model_pool.fetch_model("ace_step_tokenizer")
if text_tokenizer_config is not None: if text_tokenizer_config is not None:
text_tokenizer_config.download_if_necessary() text_tokenizer_config.download_if_necessary()
from transformers import AutoTokenizer from transformers import AutoTokenizer
pipe.tokenizer = AutoTokenizer.from_pretrained(text_tokenizer_config.path) pipe.tokenizer = AutoTokenizer.from_pretrained(text_tokenizer_config.path)
if silence_latent_config is not None:
silence_latent_config.download_if_necessary()
pipe.silence_latent = torch.load(silence_latent_config.path, weights_only=True).transpose(1, 2).to(dtype=pipe.torch_dtype, device=pipe.device)
# VRAM Management # VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state() pipe.vram_management_enabled = pipe.check_vram_management_state()
@@ -97,9 +93,19 @@ class AceStepPipeline(BasePipeline):
# Lyrics # Lyrics
lyrics: str = "", lyrics: str = "",
# Reference audio (optional, for timbre conditioning) # Reference audio (optional, for timbre conditioning)
reference_audio = None, reference_audios: List[torch.Tensor] = None,
# Src audio
src_audio: torch.Tensor = None,
denoising_strength: float = 1.0,
# Simple Mode: LLM-generated audio codes (optional)
audio_codes: str = None,
# Shape # Shape
duration: float = 60.0, duration: int = 60,
# Audio Meta
bpm: Optional[int] = 100,
keyscale: Optional[str] = "B minor",
timesignature: Optional[str] = "4",
vocal_language: Optional[str] = 'zh',
# Randomness # Randomness
seed: int = None, seed: int = None,
rand_device: str = "cpu", rand_device: str = "cpu",
@@ -111,11 +117,7 @@ class AceStepPipeline(BasePipeline):
progress_bar_cmd=tqdm, progress_bar_cmd=tqdm,
): ):
# 1. Scheduler # 1. Scheduler
self.scheduler.set_timesteps( self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, denoising_strength=1.0, shift=shift)
num_inference_steps=num_inference_steps,
denoising_strength=1.0,
shift=shift,
)
# 2. 三字典输入 # 2. 三字典输入
inputs_posi = {"prompt": prompt} inputs_posi = {"prompt": prompt}
@@ -123,8 +125,11 @@ class AceStepPipeline(BasePipeline):
inputs_shared = { inputs_shared = {
"cfg_scale": cfg_scale, "cfg_scale": cfg_scale,
"lyrics": lyrics, "lyrics": lyrics,
"reference_audio": reference_audio, "reference_audios": reference_audios,
"src_audio": src_audio,
"audio_codes": audio_codes,
"duration": duration, "duration": duration,
"bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "vocal_language": vocal_language,
"seed": seed, "seed": seed,
"rand_device": rand_device, "rand_device": rand_device,
"num_inference_steps": num_inference_steps, "num_inference_steps": num_inference_steps,
@@ -159,6 +164,10 @@ class AceStepPipeline(BasePipeline):
# VAE returns OobleckDecoderOutput with .sample attribute # VAE returns OobleckDecoderOutput with .sample attribute
audio_output = vae_output.sample if hasattr(vae_output, 'sample') else vae_output audio_output = vae_output.sample if hasattr(vae_output, 'sample') else vae_output
audio = self.output_audio_format_check(audio_output) audio = self.output_audio_format_check(audio_output)
# Peak normalization to match target library behavior
audio = self.normalize_audio(audio, target_db=-1.0)
self.load_models_to_device([]) self.load_models_to_device([])
return audio return audio
@@ -172,294 +181,303 @@ class AceStepPipeline(BasePipeline):
audio_output = audio_output.squeeze(0) audio_output = audio_output.squeeze(0)
return audio_output.float() return audio_output.float()
def normalize_audio(self, audio: torch.Tensor, target_db: float = -1.0) -> torch.Tensor:
"""Apply peak normalization to audio data, matching target library behavior.
class AceStepUnit_ShapeChecker(PipelineUnit): Target library reference: `acestep/audio_utils.py:normalize_audio()`
"""Check and compute sequence length from duration.""" peak = max(abs(audio))
def __init__(self): gain = 10^(target_db/20) / peak
super().__init__( audio = audio * gain
input_params=("duration",),
output_params=("duration", "seq_len"),
)
def process(self, pipe, duration): Args:
# ACE-Step: 25 Hz latent rate audio: Audio tensor [C, T]
seq_len = int(duration * 25) target_db: Target peak level in dB (default: -1.0)
return {"duration": duration, "seq_len": seq_len}
Returns:
Normalized audio tensor
"""
peak = torch.max(torch.abs(audio))
if peak < 1e-6:
return audio
target_amp = 10 ** (target_db / 20.0)
gain = target_amp / peak
return audio * gain
class AceStepUnit_PromptEmbedder(PipelineUnit): class AceStepUnit_PromptEmbedder(PipelineUnit):
"""Encode prompt and lyrics using Qwen3-Embedding. SFT_GEN_PROMPT = "# Instruction\n{}\n\n# Caption\n{}\n\n# Metas\n{}<|endoftext|>\n"
INSTRUCTION = "Fill the audio semantic mask based on the given conditions:"
LYRIC_PROMPT = "# Languages\n{}\n\n# Lyric\n{}<|endoftext|>"
Uses seperate_cfg=True to read prompt from inputs_posi (not inputs_shared).
The negative condition uses null_condition_emb (handled by ConditionEmbedder),
so negative text encoding is not needed here.
"""
def __init__(self): def __init__(self):
super().__init__( super().__init__(
seperate_cfg=True, seperate_cfg=True,
input_params_posi={"prompt": "prompt"}, input_params_posi={"prompt": "prompt"},
input_params_nega={}, input_params_nega={"prompt": "prompt"},
input_params=("lyrics",), input_params=("lyrics", "duration", "bpm", "keyscale", "timesignature", "vocal_language"),
output_params=("text_hidden_states", "text_attention_mask", "lyric_hidden_states", "lyric_attention_mask"), output_params=("text_hidden_states", "text_attention_mask", "lyric_hidden_states", "lyric_attention_mask"),
onload_model_names=("text_encoder",) onload_model_names=("text_encoder",)
) )
def _encode_text(self, pipe, text): def _encode_text(self, pipe, text, max_length=256):
"""Encode text using Qwen3-Embedding → [B, T, 1024].""" """Encode text using Qwen3-Embedding → [B, T, 1024]."""
if pipe.tokenizer is None:
return None, None
text_inputs = pipe.tokenizer( text_inputs = pipe.tokenizer(
text, text,
padding="max_length", max_length=max_length,
max_length=512,
truncation=True, truncation=True,
return_tensors="pt", return_tensors="pt",
) )
input_ids = text_inputs.input_ids.to(pipe.device) input_ids = text_inputs.input_ids.to(pipe.device)
attention_mask = text_inputs.attention_mask.to(pipe.device) attention_mask = text_inputs.attention_mask.bool().to(pipe.device)
hidden_states = pipe.text_encoder(input_ids, attention_mask) hidden_states = pipe.text_encoder(input_ids, attention_mask)
return hidden_states, attention_mask return hidden_states, attention_mask
def process(self, pipe, prompt, lyrics, negative_prompt=None): def _encode_lyrics(self, pipe, lyric_text, max_length=2048):
text_inputs = pipe.tokenizer(
lyric_text,
max_length=max_length,
truncation=True,
return_tensors="pt",
)
input_ids = text_inputs.input_ids.to(pipe.device)
attention_mask = text_inputs.attention_mask.bool().to(pipe.device)
hidden_states = pipe.text_encoder.model.embed_tokens(input_ids)
return hidden_states, attention_mask
def _dict_to_meta_string(self, meta_dict: Dict[str, Any]) -> str:
bpm = meta_dict.get("bpm", "N/A")
timesignature = meta_dict.get("timesignature", "N/A")
keyscale = meta_dict.get("keyscale", "N/A")
duration = meta_dict.get("duration", 30)
duration = f"{int(duration)} seconds"
return (
f"- bpm: {bpm}\n"
f"- timesignature: {timesignature}\n"
f"- keyscale: {keyscale}\n"
f"- duration: {duration}\n"
)
def process(self, pipe, prompt, lyrics, duration, bpm, keyscale, timesignature, vocal_language):
pipe.load_models_to_device(['text_encoder']) pipe.load_models_to_device(['text_encoder'])
meta_dict = {"bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "duration": duration}
prompt = self.SFT_GEN_PROMPT.format(self.INSTRUCTION, prompt, self._dict_to_meta_string(meta_dict))
text_hidden_states, text_attention_mask = self._encode_text(pipe, prompt, max_length=256)
text_hidden_states, text_attention_mask = self._encode_text(pipe, prompt) lyric_text = self.LYRIC_PROMPT.format(vocal_language, lyrics)
lyric_hidden_states, lyric_attention_mask = self._encode_lyrics(pipe, lyric_text, max_length=2048)
# Lyrics encoding — use empty string if not provided # TODO: remove this
lyric_text = lyrics if lyrics else "" newtext = prompt + "\n\n" + lyric_text
lyric_hidden_states, lyric_attention_mask = self._encode_text(pipe, lyric_text) return {
"text_hidden_states": text_hidden_states,
if text_hidden_states is not None and lyric_hidden_states is not None: "text_attention_mask": text_attention_mask,
return { "lyric_hidden_states": lyric_hidden_states,
"text_hidden_states": text_hidden_states, "lyric_attention_mask": lyric_attention_mask,
"text_attention_mask": text_attention_mask, }
"lyric_hidden_states": lyric_hidden_states,
"lyric_attention_mask": lyric_attention_mask,
}
return {}
class AceStepUnit_SilenceLatentInitializer(PipelineUnit): class AceStepUnit_ReferenceAudioEmbedder(PipelineUnit):
"""Generate silence latent (all zeros) and chunk_masks for text2music.
Target library reference: `prepare_condition()` line 1698-1699:
context_latents = torch.cat([src_latents, chunk_masks.to(dtype)], dim=-1)
For text2music mode:
- src_latents = zeros [B, T, 64] (VAE latent dimension)
- chunk_masks = ones [B, T, 64] (full visibility mask for text2music)
- context_latents = [B, T, 128] (concat of src_latents + chunk_masks)
"""
def __init__(self): def __init__(self):
super().__init__( super().__init__(
input_params=("seq_len",), input_params=("reference_audios",),
output_params=("silence_latent", "src_latents", "chunk_masks"), output_params=("reference_latents", "refer_audio_order_mask"),
onload_model_names=("vae",)
) )
def process(self, pipe, seq_len): def process(self, pipe, reference_audios):
# silence_latent shape: [B, T, 64] — 64 is the VAE latent dimension pipe.load_models_to_device(['vae'])
silence_latent = torch.zeros(1, seq_len, 64, device=pipe.device, dtype=pipe.torch_dtype) if reference_audios is not None and len(reference_audios) > 0:
# For text2music: src_latents = silence_latent # TODO: implement reference audio embedding using VAE encode, and generate refer_audio_order_mask
src_latents = silence_latent.clone() pass
else:
reference_audios = [[torch.zeros(2, 30 * pipe.vae.sampling_rate).to(dtype=pipe.torch_dtype, device=pipe.device)]]
reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, reference_audios)
return {"reference_latents": reference_latents, "refer_audio_order_mask": refer_audio_order_mask}
# chunk_masks: [B, T, 64] of ones (same shape as src_latents) def infer_refer_latent(self, pipe, refer_audioss: List[List[torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
# In text2music mode (is_covers=0), chunk_masks are all 1.0 """Infer packed reference-audio latents and order mask."""
# This matches the target library's behavior at line 1699 refer_audio_order_mask = []
chunk_masks = torch.ones(1, seq_len, 64, device=pipe.device, dtype=pipe.torch_dtype) refer_audio_latents = []
return {"silence_latent": silence_latent, "src_latents": src_latents, "chunk_masks": chunk_masks} def _normalize_audio_2d(a: torch.Tensor) -> torch.Tensor:
if not isinstance(a, torch.Tensor):
raise TypeError(f"refer_audio must be a torch.Tensor, got {type(a)!r}")
if a.dim() == 3 and a.shape[0] == 1:
a = a.squeeze(0)
if a.dim() == 1:
a = a.unsqueeze(0)
if a.dim() != 2:
raise ValueError(f"refer_audio must be 1D/2D/3D(1,2,T); got shape={tuple(a.shape)}")
if a.shape[0] == 1:
a = torch.cat([a, a], dim=0)
return a[:2]
def _ensure_latent_3d(z: torch.Tensor) -> torch.Tensor:
if z.dim() == 4 and z.shape[0] == 1:
z = z.squeeze(0)
if z.dim() == 2:
z = z.unsqueeze(0)
return z
class AceStepUnit_ContextLatentBuilder(PipelineUnit): refer_encode_cache: Dict[int, torch.Tensor] = {}
"""Build context_latents from src_latents and chunk_masks. for batch_idx, refer_audios in enumerate(refer_audioss):
if len(refer_audios) == 1 and torch.all(refer_audios[0] == 0.0):
refer_audio_latent = _ensure_latent_3d(pipe.silence_latent[:, :750, :])
refer_audio_latents.append(refer_audio_latent)
refer_audio_order_mask.append(batch_idx)
else:
# TODO: check
for refer_audio in refer_audios:
cache_key = refer_audio.data_ptr()
if cache_key in refer_encode_cache:
refer_audio_latent = refer_encode_cache[cache_key].clone()
else:
refer_audio = _normalize_audio_2d(refer_audio)
refer_audio_latent = pipe.vae.encode(refer_audio)
refer_audio_latent = refer_audio_latent.to(dtype=pipe.torch_dtype, device=pipe.device)
if refer_audio_latent.dim() == 2:
refer_audio_latent = refer_audio_latent.unsqueeze(0)
refer_audio_latent = _ensure_latent_3d(refer_audio_latent.transpose(1, 2))
refer_encode_cache[cache_key] = refer_audio_latent
refer_audio_latents.append(refer_audio_latent)
refer_audio_order_mask.append(batch_idx)
Target library reference: `prepare_condition()` line 1699: refer_audio_latents = torch.cat(refer_audio_latents, dim=0)
context_latents = torch.cat([src_latents, chunk_masks.to(dtype)], dim=-1) refer_audio_order_mask = torch.tensor(refer_audio_order_mask, device=pipe.device, dtype=torch.long)
return refer_audio_latents, refer_audio_order_mask
context_latents is the SAME for positive and negative CFG paths
(it comes from src_latents + chunk_masks, not from text encoding).
So this is a普通模式 Unit — outputs go to inputs_shared.
"""
def __init__(self):
super().__init__(
input_params=("src_latents", "chunk_masks"),
output_params=("context_latents", "attention_mask"),
)
def process(self, pipe, src_latents, chunk_masks):
# context_latents: cat([src_latents, chunk_masks], dim=-1) → [B, T, 128]
context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
# attention_mask for the DiT: ones [B, T]
# The target library uses this for cross-attention with context_latents
attention_mask = torch.ones(src_latents.shape[0], src_latents.shape[1],
device=pipe.device, dtype=pipe.torch_dtype)
return {"context_latents": context_latents, "attention_mask": attention_mask}
class AceStepUnit_ConditionEmbedder(PipelineUnit): class AceStepUnit_ConditionEmbedder(PipelineUnit):
"""Generate encoder_hidden_states via ACEStepConditioner.
Target library reference: `prepare_condition()` line 1674-1681:
encoder_hidden_states, encoder_attention_mask = self.encoder(...)
Uses seperate_cfg mode:
- Positive: encode with full condition (text + lyrics + reference audio)
- Negative: replace text with null_condition_emb, keep lyrics/timbre same
context_latents is handled by ContextLatentBuilder (普通模式), not here.
"""
def __init__(self): def __init__(self):
super().__init__( super().__init__(
seperate_cfg=True, seperate_cfg=True,
input_params_posi={ input_params_posi={"text_hidden_states": "text_hidden_states", "text_attention_mask": "text_attention_mask", "lyric_hidden_states": "lyric_hidden_states", "lyric_attention_mask": "lyric_attention_mask"},
"text_hidden_states": "text_hidden_states", input_params_nega={"text_hidden_states": "text_hidden_states", "text_attention_mask": "text_attention_mask", "lyric_hidden_states": "lyric_hidden_states", "lyric_attention_mask": "lyric_attention_mask"},
"text_attention_mask": "text_attention_mask", input_params=("reference_latents", "refer_audio_order_mask"),
"lyric_hidden_states": "lyric_hidden_states", output_params=("encoder_hidden_states", "encoder_attention_mask"),
"lyric_attention_mask": "lyric_attention_mask", onload_model_names=("conditioner",),
"reference_audio": "reference_audio",
"refer_audio_order_mask": "refer_audio_order_mask",
},
input_params_nega={},
input_params=("cfg_scale",),
output_params=(
"encoder_hidden_states", "encoder_attention_mask",
"negative_encoder_hidden_states", "negative_encoder_attention_mask",
),
onload_model_names=("conditioner",)
) )
def _prepare_condition(self, pipe, text_hidden_states, text_attention_mask, def process(self, pipe, text_hidden_states, text_attention_mask, lyric_hidden_states, lyric_attention_mask, reference_latents, refer_audio_order_mask):
lyric_hidden_states, lyric_attention_mask,
refer_audio_acoustic_hidden_states_packed=None,
refer_audio_order_mask=None):
"""Call ACEStepConditioner forward to produce encoder_hidden_states."""
pipe.load_models_to_device(['conditioner']) pipe.load_models_to_device(['conditioner'])
# Handle reference audio
if refer_audio_acoustic_hidden_states_packed is None:
# No reference audio: create 2D packed zeros [N=1, d=64]
# TimbreEncoder.unpack expects [N, d], not [B, T, d]
refer_audio_acoustic_hidden_states_packed = torch.zeros(
1, 64, device=pipe.device, dtype=pipe.torch_dtype
)
refer_audio_order_mask = torch.LongTensor([0]).to(pipe.device)
encoder_hidden_states, encoder_attention_mask = pipe.conditioner( encoder_hidden_states, encoder_attention_mask = pipe.conditioner(
text_hidden_states=text_hidden_states, text_hidden_states=text_hidden_states,
text_attention_mask=text_attention_mask, text_attention_mask=text_attention_mask,
lyric_hidden_states=lyric_hidden_states, lyric_hidden_states=lyric_hidden_states,
lyric_attention_mask=lyric_attention_mask, lyric_attention_mask=lyric_attention_mask,
refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed, reference_latents=reference_latents,
refer_audio_order_mask=refer_audio_order_mask, refer_audio_order_mask=refer_audio_order_mask,
) )
return {"encoder_hidden_states": encoder_hidden_states, "encoder_attention_mask": encoder_attention_mask}
return encoder_hidden_states, encoder_attention_mask
def _prepare_negative_condition(self, pipe, lyric_hidden_states, lyric_attention_mask, class AceStepUnit_ContextLatentBuilder(PipelineUnit):
refer_audio_acoustic_hidden_states_packed=None, def __init__(self):
refer_audio_order_mask=None): super().__init__(
"""Generate negative condition using null_condition_emb.""" input_params=("duration", "src_audio"),
if pipe.conditioner is None or not hasattr(pipe.conditioner, 'null_condition_emb'): output_params=("context_latents", "src_latents", "chunk_masks", "attention_mask"),
return None, None
null_emb = pipe.conditioner.null_condition_emb # [1, 1, hidden_size]
bsz = 1
if lyric_hidden_states is not None:
bsz = lyric_hidden_states.shape[0]
null_hidden_states = null_emb.expand(bsz, -1, -1)
null_attn_mask = torch.ones(bsz, 1, device=pipe.device, dtype=pipe.torch_dtype)
# For negative: use null_condition_emb as text, keep lyrics and timbre
neg_encoder_hidden_states, neg_encoder_attention_mask = pipe.conditioner(
text_hidden_states=null_hidden_states,
text_attention_mask=null_attn_mask,
lyric_hidden_states=lyric_hidden_states,
lyric_attention_mask=lyric_attention_mask,
refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed,
refer_audio_order_mask=refer_audio_order_mask,
) )
return neg_encoder_hidden_states, neg_encoder_attention_mask def _get_silence_latent_slice(self, pipe, length: int) -> torch.Tensor:
available = pipe.silence_latent.shape[1]
if length <= available:
return pipe.silence_latent[0, :length, :]
repeats = (length + available - 1) // available
tiled = pipe.silence_latent[0].repeat(repeats, 1)
return tiled[:length, :]
def process(self, pipe, text_hidden_states, text_attention_mask, def process(self, pipe, duration, src_audio):
lyric_hidden_states, lyric_attention_mask, if src_audio is not None:
reference_audio=None, refer_audio_order_mask=None, raise NotImplementedError("Src audio conditioning is not implemented yet. Please set src_audio to None.")
negative_prompt=None, cfg_scale=1.0): else:
max_latent_length = duration * pipe.sample_rate // 1920
# Positive condition src_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
pos_enc_hs, pos_enc_mask = self._prepare_condition( chunk_masks = torch.ones((1, max_latent_length, src_latents.shape[-1]), dtype=torch.bool, device=pipe.device)
pipe, text_hidden_states, text_attention_mask, attention_mask = torch.ones((1, max_latent_length), device=src_latents.device, dtype=pipe.torch_dtype)
lyric_hidden_states, lyric_attention_mask, context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
None, refer_audio_order_mask, return {"context_latents": context_latents, "attention_mask": attention_mask}
)
# Negative condition: only needed when CFG is active (cfg_scale > 1.0)
# For cfg_scale=1.0 (turbo), skip to avoid null_condition_emb dimension mismatch
result = {
"encoder_hidden_states": pos_enc_hs,
"encoder_attention_mask": pos_enc_mask,
}
if cfg_scale > 1.0:
neg_enc_hs, neg_enc_mask = self._prepare_negative_condition(
pipe, lyric_hidden_states, lyric_attention_mask,
None, refer_audio_order_mask,
)
if neg_enc_hs is not None:
result["negative_encoder_hidden_states"] = neg_enc_hs
result["negative_encoder_attention_mask"] = neg_enc_mask
return result
class AceStepUnit_NoiseInitializer(PipelineUnit):
    """Pipeline unit that draws the initial noise tensor.

    The noise shape is taken from ``context_latents``: its first two
    dimensions are kept and its last dimension is halved.
    """

    def __init__(self):
        super().__init__(
            input_params=("context_latents", "seed", "rand_device"),
            output_params=("noise",),
        )

    def process(self, pipe, context_latents, seed, rand_device):
        """Generate noise shaped after ``context_latents``.

        Args:
            pipe: Pipeline object providing ``generate_noise`` and ``torch_dtype``.
            context_latents: Tensor whose shape determines the noise shape.
            seed: Forwarded to ``pipe.generate_noise``.
            rand_device: Forwarded to ``pipe.generate_noise``.

        Returns:
            A dict with the single key ``"noise"``.
        """
        batch_size, num_frames = context_latents.shape[0], context_latents.shape[1]
        # Only half of the last dimension is used for the noise
        # (presumably the other half is the mask part — TODO confirm).
        num_channels = context_latents.shape[-1] // 2
        noise = pipe.generate_noise(
            (batch_size, num_frames, num_channels),
            seed=seed,
            rand_device=rand_device,
            rand_torch_dtype=pipe.torch_dtype,
        )
        return {"noise": noise}
class AceStepUnit_InputAudioEmbedder(PipelineUnit):
    """Pipeline unit that selects the starting latents for the denoise loop.

    The noise tensor is always used as the starting latents; no encoding of
    ``input_audio`` happens here.
    """

    def __init__(self):
        super().__init__(
            input_params=("noise", "input_audio"),
            output_params=("latents", "input_latents"),
        )

    def process(self, pipe, noise, input_audio):
        """Return the noise as ``latents``.

        Args:
            pipe: Pipeline object (unused here).
            noise: Initial noise tensor.
            input_audio: Optional input audio; only checked against ``None``.

        Returns:
            ``{"latents": noise}`` when ``input_audio`` is ``None``; otherwise
            the same dict with an extra ``"input_latents": None`` entry.
        """
        outputs = {"latents": noise}
        if input_audio is not None:
            # TODO: support for train
            outputs["input_latents"] = None
        return outputs
class AceStepUnit_AudioCodeDecoder(PipelineUnit):
    """Pipeline unit that turns an audio-code string into latent hints.

    The codes are parsed from ``<|audio_code_N|>`` tokens, looked up in the
    tokenizer model's quantizer, passed through its detokenizer, and then
    padded or truncated to ``seq_len`` along dimension 1.
    """

    def __init__(self):
        super().__init__(
            input_params=("audio_codes", "seq_len", "silence_latent"),
            output_params=("lm_hints_25Hz",),
            onload_model_names=("tokenizer_model",),
        )

    @staticmethod
    def _parse_audio_code_string(code_str: str) -> list:
        """Extract integer audio codes from tokens like <|audio_code_123|>.

        Each value is clamped to the range ``[0, 63999]``. An empty or
        ``None`` string yields an empty list.
        """
        if not code_str:
            return []
        max_audio_code = 63999
        return [
            max(0, min(int(raw), max_audio_code))
            for raw in re.findall(r"<\|audio_code_(\d+)\|>", code_str)
        ]

    def process(self, pipe, audio_codes, seq_len, silence_latent):
        """Decode ``audio_codes`` into hints of length ``seq_len``.

        Args:
            pipe: Pipeline object providing ``tokenizer_model``, ``device``,
                ``torch_dtype`` and ``load_models_to_device``.
            audio_codes: String of ``<|audio_code_N|>`` tokens, or ``None``.
            seq_len: Target length along dimension 1.
            silence_latent: Tensor whose leading frames are used as padding
                (indexed as a 3-D tensor).

        Returns:
            ``{"lm_hints_25Hz": tensor}``, or ``{"lm_hints_25Hz": None}`` when
            ``audio_codes`` is ``None``, blank, or contains no code tokens.
        """
        no_hints = {"lm_hints_25Hz": None}
        if audio_codes is None or not audio_codes.strip():
            return no_hints
        code_ids = self._parse_audio_code_string(audio_codes)
        if not code_ids:
            return no_hints

        pipe.load_models_to_device(["tokenizer_model"])
        quantizer = pipe.tokenizer_model.tokenizer.quantizer
        detokenizer = pipe.tokenizer_model.detokenizer

        # Shape the flat code list as [1, N, 1] before the quantizer lookup.
        indices = torch.tensor(code_ids, device=pipe.device, dtype=torch.long)[None, :, None]
        # ``Tensor.to`` returns the tensor itself when the dtype already matches.
        quantized = quantizer.get_output_from_indices(indices).to(pipe.torch_dtype)
        lm_hints = detokenizer(quantized)

        # Pad with silence frames or truncate so dimension 1 equals seq_len.
        shortfall = seq_len - lm_hints.shape[1]
        if shortfall > 0:
            lm_hints = torch.cat([lm_hints, silence_latent[:, :shortfall, :]], dim=1)
        elif shortfall < 0:
            lm_hints = lm_hints[:, :seq_len, :]
        return {"lm_hints_25Hz": lm_hints}
def model_fn_ace_step( def model_fn_ace_step(
dit: AceStepDiTModel, dit: AceStepDiTModel,
latents=None, latents=None,
@@ -468,49 +486,11 @@ def model_fn_ace_step(
encoder_attention_mask=None, encoder_attention_mask=None,
context_latents=None, context_latents=None,
attention_mask=None, attention_mask=None,
past_key_values=None, use_gradient_checkpointing=False,
negative_encoder_hidden_states=None, use_gradient_checkpointing_offload=False,
negative_encoder_attention_mask=None,
negative_context_latents=None,
**kwargs, **kwargs,
): ):
"""Model function for ACE-Step DiT forward. timestep = timestep.unsqueeze(0)
Timestep is already in [0, 1] range — no scaling needed.
Target library reference: `generate_audio()` line 2009-2020:
decoder_outputs = self.decoder(
hidden_states=x, timestep=t_curr_tensor, timestep_r=t_curr_tensor,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
context_latents=context_latents,
use_cache=True, past_key_values=past_key_values,
)
Args:
dit: AceStepDiTModel
latents: [B, T, 64] noise/latent tensor (same shape as src_latents)
timestep: scalar tensor in [0, 1]
encoder_hidden_states: [B, T_text, 2048] condition from Conditioner
(positive or negative depending on CFG pass — the cfg_guided_model_fn
passes inputs_posi for positive, inputs_nega for negative)
encoder_attention_mask: [B, T_text]
context_latents: [B, T, 128] = cat([src_latents, chunk_masks], dim=-1)
(same for both CFG+/- paths in text2music mode)
attention_mask: [B, T] ones mask for DiT
past_key_values: EncoderDecoderCache for KV caching
The DiT internally concatenates: cat([context_latents, latents], dim=-1) = [B, T, 192]
as the actual input (128 + 64 = 192 channels).
"""
# ACE-Step uses timestep directly in [0, 1] range — no /1000 scaling
timestep = timestep.squeeze()
# Expand timestep to match batch size
bsz = latents.shape[0]
timestep = timestep.expand(bsz)
decoder_outputs = dit( decoder_outputs = dit(
hidden_states=latents, hidden_states=latents,
timestep=timestep, timestep=timestep,
@@ -519,9 +499,5 @@ def model_fn_ace_step(
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
context_latents=context_latents, context_latents=context_latents,
use_cache=True, )[0]
past_key_values=past_key_values, return decoder_outputs
)
# Return velocity prediction (first element of decoder_outputs)
return decoder_outputs[0]

View File

@@ -99,6 +99,7 @@ def save_audio(waveform: torch.Tensor, sample_rate: int, save_path: str, backend
""" """
if waveform.dim() == 3: if waveform.dim() == 3:
waveform = waveform[0] waveform = waveform[0]
waveform.cpu()
if backend == "torchcodec": if backend == "torchcodec":
from torchcodec.encoders import AudioEncoder from torchcodec.encoders import AudioEncoder

View File

@@ -2,8 +2,8 @@
Ace-Step 1.5 — Text-to-Music with Simple Mode (LLM expansion). Ace-Step 1.5 — Text-to-Music with Simple Mode (LLM expansion).
Uses the ACE-Step LLM to expand a simple description into structured Uses the ACE-Step LLM to expand a simple description into structured
parameters (caption, lyrics, bpm, keyscale, etc.), then feeds them parameters (caption, lyrics, bpm, keyscale, etc.) AND audio codes,
to the DiffSynth Pipeline. then feeds them to the DiffSynth Pipeline.
The LLM expansion uses the target library's LLMHandler. If vLLM is The LLM expansion uses the target library's LLMHandler. If vLLM is
not available, it falls back to using pre-structured parameters. not available, it falls back to using pre-structured parameters.
@@ -47,11 +47,14 @@ def try_load_llm_handler(checkpoint_dir: str, lm_model_path: str = "acestep-5Hz-
def expand_with_llm(llm_handler, description: str, duration: float = 30.0): def expand_with_llm(llm_handler, description: str, duration: float = 30.0):
"""Expand a simple description using LLM Chain-of-Thought.""" """Expand a simple description using LLM Chain-of-Thought.
Returns (params_dict, audio_codes_string).
"""
result = llm_handler.generate_with_stop_condition( result = llm_handler.generate_with_stop_condition(
caption=description, caption=description,
lyrics="", lyrics="",
infer_type="dit", # metadata only infer_type="dit", # metadata + audio codes
temperature=0.85, temperature=0.85,
cfg_scale=1.0, cfg_scale=1.0,
use_cot_metas=True, use_cot_metas=True,
@@ -62,7 +65,7 @@ def expand_with_llm(llm_handler, description: str, duration: float = 30.0):
if result.get("success") and result.get("metadata"): if result.get("success") and result.get("metadata"):
meta = result["metadata"] meta = result["metadata"]
return { params = {
"caption": meta.get("caption", description), "caption": meta.get("caption", description),
"lyrics": meta.get("lyrics", ""), "lyrics": meta.get("lyrics", ""),
"bpm": meta.get("bpm", 100), "bpm": meta.get("bpm", 100),
@@ -71,9 +74,11 @@ def expand_with_llm(llm_handler, description: str, duration: float = 30.0):
"timesignature": meta.get("timesignature", "4"), "timesignature": meta.get("timesignature", "4"),
"duration": meta.get("duration", duration), "duration": meta.get("duration", duration),
} }
audio_codes = result.get("audio_codes", "")
return params, audio_codes
print(f"[Simple Mode] LLM expansion failed: {result.get('error', 'unknown')}") print(f"[Simple Mode] LLM expansion failed: {result.get('error', 'unknown')}")
return None return None, ""
def fallback_expand(description: str, duration: float = 30.0): def fallback_expand(description: str, duration: float = 30.0):
@@ -87,7 +92,7 @@ def fallback_expand(description: str, duration: float = 30.0):
"language": "en", "language": "en",
"timesignature": "4", "timesignature": "4",
"duration": duration, "duration": duration,
} }, ""
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -114,13 +119,13 @@ def main():
lm_model_path="acestep-5Hz-lm-1.7B", lm_model_path="acestep-5Hz-lm-1.7B",
) )
# 2. Expand parameters # 2. Expand parameters + audio codes
if llm_ok: if llm_ok:
params = expand_with_llm(llm_handler, description, duration=duration) params, audio_codes = expand_with_llm(llm_handler, description, duration=duration)
if params is None: if params is None:
params = fallback_expand(description, duration) params, audio_codes = fallback_expand(description, duration)
else: else:
params = fallback_expand(description, duration) params, audio_codes = fallback_expand(description, duration)
print(f"\n[Simple Mode] Parameters:") print(f"\n[Simple Mode] Parameters:")
print(f" Caption: {params['caption'][:100]}...") print(f" Caption: {params['caption'][:100]}...")
@@ -128,6 +133,7 @@ def main():
print(f" BPM: {params['bpm']}, Keyscale: {params['keyscale']}") print(f" BPM: {params['bpm']}, Keyscale: {params['keyscale']}")
print(f" Language: {params['language']}, Time Sig: {params['timesignature']}") print(f" Language: {params['language']}, Time Sig: {params['timesignature']}")
print(f" Duration: {params['duration']}s") print(f" Duration: {params['duration']}s")
print(f" Audio codes: {len(audio_codes)} chars" if audio_codes else " Audio codes: None (fallback)")
# 3. Load Pipeline # 3. Load Pipeline
print(f"\n[Pipeline] Loading Ace-Step 1.5 (turbo)...") print(f"\n[Pipeline] Loading Ace-Step 1.5 (turbo)...")
@@ -141,21 +147,17 @@ def main():
), ),
ModelConfig( ModelConfig(
model_id="ACE-Step/Ace-Step1.5", model_id="ACE-Step/Ace-Step1.5",
origin_file_pattern="acestep-v15-turbo/model.safetensors" origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"
), ),
ModelConfig( ModelConfig(
model_id="ACE-Step/Ace-Step1.5", model_id="ACE-Step/Ace-Step1.5",
origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors" origin_file_pattern="vae/diffusion_pytorch_model.safetensors"
), ),
], ],
tokenizer_config=ModelConfig( text_tokenizer_config=ModelConfig(
model_id="ACE-Step/Ace-Step1.5", model_id="ACE-Step/Ace-Step1.5",
origin_file_pattern="Qwen3-Embedding-0.6B/" origin_file_pattern="Qwen3-Embedding-0.6B/"
), ),
vae_config=ModelConfig(
model_id="ACE-Step/Ace-Step1.5",
origin_file_pattern="vae/"
),
) )
# 4. Generate # 4. Generate
@@ -164,6 +166,7 @@ def main():
prompt=params["caption"], prompt=params["caption"],
lyrics=params["lyrics"], lyrics=params["lyrics"],
duration=params["duration"], duration=params["duration"],
audio_codes=audio_codes if audio_codes else None,
seed=42, seed=42,
num_inference_steps=8, num_inference_steps=8,
cfg_scale=1.0, cfg_scale=1.0,

View File

@@ -1,16 +1,6 @@
"""
Ace-Step 1.5 — Text-to-Music (Turbo) inference example.
Demonstrates the standard text2music pipeline with structured parameters
(caption, lyrics, duration, etc.) — no LLM expansion needed.
For Simple Mode (LLM expands a short description), see:
- Ace-Step1.5-SimpleMode.py
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
import torch import torch
import soundfile as sf from diffsynth.utils.data.audio import save_audio
pipe = AceStepPipeline.from_pretrained( pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
@@ -35,29 +25,21 @@ pipe = AceStepPipeline.from_pretrained(
), ),
) )
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline." prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = """[Intro - Synth Brass Fanfare] lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
[Verse 1]
黑夜里的风吹过耳畔
甜蜜时光转瞬即逝
脚步飘摇在星光上
[Chorus]
心电感应在震动间
拥抱未来勇敢冒险
[Outro - Instrumental]"""
audio = pipe( audio = pipe(
prompt=prompt, prompt=prompt,
lyrics=lyrics, lyrics=lyrics,
duration=30.0, duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42, seed=42,
num_inference_steps=8, num_inference_steps=8,
cfg_scale=1.0, cfg_scale=1.0,
shift=3.0,
) )
sf.write("Ace-Step1.5.wav", audio.cpu().numpy(), pipe.sample_rate) save_audio(audio.cpu(), pipe.vae.sampling_rate, "Ace-Step1.5.wav")
print(f"Saved to Ace-Step1.5.wav, shape: {audio.shape}, duration: {audio.shape[-1] / pipe.sample_rate:.1f}s") print(f"Saved to Ace-Step1.5.wav, shape: {audio.shape}, duration: {audio.shape[-1] / pipe.sample_rate:.1f}s")