Files
DiffSynth-Studio/diffsynth/pipelines/ace_step.py
2026-04-17 17:06:26 +08:00

528 lines
21 KiB
Python

"""
ACE-Step Pipeline for DiffSynth-Studio.
Text-to-Music generation pipeline using ACE-Step 1.5 model.
"""
import torch
from typing import Optional
from tqdm import tqdm
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
from ..models.ace_step_dit import AceStepDiTModel
from ..models.ace_step_conditioner import AceStepConditionEncoder
from ..models.ace_step_text_encoder import AceStepTextEncoder
from ..models.ace_step_vae import AceStepVAE
class AceStepPipeline(BasePipeline):
    """Pipeline for ACE-Step text-to-music generation.

    Chains a text encoder (Qwen3-Embedding), a condition encoder, a DiT
    denoiser and an audio VAE to turn a text prompt (plus optional lyrics
    and reference audio) into a 48 kHz waveform via flow matching.
    """

    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
        super().__init__(
            device=device,
            torch_dtype=torch_dtype,
            height_division_factor=1,
            width_division_factor=1,
        )
        self.scheduler = FlowMatchScheduler("ACE-Step")
        self.text_encoder: AceStepTextEncoder = None
        self.conditioner: AceStepConditionEncoder = None
        self.dit: AceStepDiTModel = None
        self.vae = None  # AutoencoderOobleck (diffusers) or AceStepVAE
        # Fix: always define the tokenizer attribute. from_pretrained() only
        # assigns it when a text_tokenizer_config is supplied, and
        # AceStepUnit_PromptEmbedder reads pipe.tokenizer unconditionally.
        self.tokenizer = None
        # Unit chain order — 7 units total
        #
        # 1. ShapeChecker: duration → seq_len
        # 2. PromptEmbedder: prompt/lyrics → text/lyric embeddings (shared for CFG)
        # 3. SilenceLatentInitializer: seq_len → src_latents + chunk_masks
        # 4. ContextLatentBuilder: src_latents + chunk_masks → context_latents (shared, same for CFG+)
        # 5. ConditionEmbedder: text/lyric → encoder_hidden_states (separate for CFG+/-)
        # 6. NoiseInitializer: context_latents → noise
        # 7. InputAudioEmbedder: noise → latents
        #
        # ContextLatentBuilder runs before ConditionEmbedder so that
        # context_latents is available for noise shape computation.
        self.in_iteration_models = ("dit",)
        self.units = [
            AceStepUnit_ShapeChecker(),
            AceStepUnit_PromptEmbedder(),
            AceStepUnit_SilenceLatentInitializer(),
            AceStepUnit_ContextLatentBuilder(),
            AceStepUnit_ConditionEmbedder(),
            AceStepUnit_NoiseInitializer(),
            AceStepUnit_InputAudioEmbedder(),
        ]
        self.model_fn = model_fn_ace_step
        self.compilable_models = ["dit"]
        self.sample_rate = 48000

    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: str = get_device_type(),
        model_configs: Optional[list[ModelConfig]] = None,
        text_tokenizer_config: ModelConfig = None,
        vram_limit: float = None,
    ):
        """Load pipeline from pretrained checkpoints.

        Args:
            torch_dtype: Compute dtype for all models.
            device: Target device string.
            model_configs: Checkpoint configs to download/load. Defaults to
                an empty list (no models loaded).
            text_tokenizer_config: Optional tokenizer checkpoint; when None,
                pipe.tokenizer stays None and text encoding is skipped.
            vram_limit: Optional VRAM budget forwarded to model loading.

        Returns:
            A configured AceStepPipeline instance.
        """
        # Fix: avoid a mutable default argument ([]) shared across calls.
        if model_configs is None:
            model_configs = []
        pipe = AceStepPipeline(device=device, torch_dtype=torch_dtype)
        model_pool = pipe.download_and_load_models(model_configs, vram_limit)
        pipe.text_encoder = model_pool.fetch_model("ace_step_text_encoder")
        pipe.conditioner = model_pool.fetch_model("ace_step_conditioner")
        pipe.dit = model_pool.fetch_model("ace_step_dit")
        pipe.vae = model_pool.fetch_model("ace_step_vae")
        if text_tokenizer_config is not None:
            text_tokenizer_config.download_if_necessary()
            from transformers import AutoTokenizer
            pipe.tokenizer = AutoTokenizer.from_pretrained(text_tokenizer_config.path)
        # VRAM Management
        pipe.vram_management_enabled = pipe.check_vram_management_state()
        return pipe

    @torch.no_grad()
    def __call__(
        self,
        # Prompt
        prompt: str,
        negative_prompt: str = "",
        cfg_scale: float = 1.0,
        # Lyrics
        lyrics: str = "",
        # Reference audio (optional, for timbre conditioning)
        reference_audio = None,
        # Shape
        duration: float = 60.0,
        # Randomness
        seed: int = None,
        rand_device: str = "cpu",
        # Steps
        num_inference_steps: int = 8,
        # Scheduler-specific parameters
        shift: float = 3.0,
        # Progress
        progress_bar_cmd=tqdm,
    ):
        """Generate audio from a text prompt.

        Args:
            prompt: Text description of the music to generate.
            negative_prompt: Text for the negative CFG branch.
            cfg_scale: Classifier-free guidance scale; 1.0 disables CFG.
            lyrics: Optional lyrics to condition on.
            reference_audio: Optional reference audio for timbre conditioning.
            duration: Target audio length in seconds (25 latent frames/sec).
            seed: Random seed; None for nondeterministic noise.
            rand_device: Device on which the initial noise is sampled.
            num_inference_steps: Number of denoising steps.
            shift: Flow-matching timestep shift.
            progress_bar_cmd: Progress-bar wrapper for the step iterator.

        Returns:
            float32 waveform tensor of shape [C, T] at 48 kHz.
        """
        # 1. Scheduler
        self.scheduler.set_timesteps(
            num_inference_steps=num_inference_steps,
            denoising_strength=1.0,
            shift=shift,
        )
        # 2. Build the three input dicts (shared / positive / negative)
        inputs_posi = {"prompt": prompt}
        inputs_nega = {"negative_prompt": negative_prompt}
        inputs_shared = {
            "cfg_scale": cfg_scale,
            "lyrics": lyrics,
            "reference_audio": reference_audio,
            "duration": duration,
            "seed": seed,
            "rand_device": rand_device,
            "num_inference_steps": num_inference_steps,
            "shift": shift,
        }
        # 3. Run the unit chain
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
                unit, self, inputs_shared, inputs_posi, inputs_nega
            )
        # 4. Denoise loop
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.to(dtype=self.torch_dtype, device=self.device)
            noise_pred = self.cfg_guided_model_fn(
                self.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            inputs_shared["latents"] = self.step(
                self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared
            )
        # 5. VAE decode
        self.load_models_to_device(['vae'])
        # DiT output is [B, T, 64] (channels-last), VAE expects [B, 64, T] (channels-first)
        latents = inputs_shared["latents"].transpose(1, 2)
        vae_output = self.vae.decode(latents)
        # VAE returns OobleckDecoderOutput with .sample attribute
        audio_output = vae_output.sample if hasattr(vae_output, 'sample') else vae_output
        audio = self.output_audio_format_check(audio_output)
        self.load_models_to_device([])
        return audio

    def output_audio_format_check(self, audio_output):
        """Convert VAE output to standard audio format [C, T], float32.

        VAE decode outputs [B, C, T] (audio waveform).
        We squeeze batch dim and return [C, T].
        """
        if audio_output.ndim == 3:
            audio_output = audio_output.squeeze(0)
        return audio_output.float()
class AceStepUnit_ShapeChecker(PipelineUnit):
    """Derive the latent sequence length from the requested audio duration."""

    def __init__(self):
        super().__init__(
            input_params=("duration",),
            output_params=("duration", "seq_len"),
        )

    def process(self, pipe, duration):
        # ACE-Step latents run at 25 frames per second of audio.
        frames_per_second = 25
        seq_len = int(duration * frames_per_second)
        return {"duration": duration, "seq_len": seq_len}
class AceStepUnit_PromptEmbedder(PipelineUnit):
    """Encode prompt and lyrics using Qwen3-Embedding.

    Uses seperate_cfg=True to read prompt from inputs_posi (not inputs_shared).
    The negative condition uses null_condition_emb (handled by ConditionEmbedder),
    so negative text encoding is not needed here.
    """

    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params_posi={"prompt": "prompt"},
            input_params_nega={},
            input_params=("lyrics",),
            output_params=("text_hidden_states", "text_attention_mask", "lyric_hidden_states", "lyric_attention_mask"),
            onload_model_names=("text_encoder",)
        )

    def _encode_text(self, pipe, text):
        """Encode text with Qwen3-Embedding.

        Returns:
            (hidden_states, attention_mask) — hidden_states is [B, T, 1024];
            (None, None) when no tokenizer is available.
        """
        # Fix: from_pretrained() only assigns pipe.tokenizer when a
        # text_tokenizer_config is supplied, so guard the attribute access
        # instead of assuming it exists.
        tokenizer = getattr(pipe, "tokenizer", None)
        if tokenizer is None:
            return None, None
        text_inputs = tokenizer(
            text,
            padding="max_length",
            max_length=512,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = text_inputs.input_ids.to(pipe.device)
        attention_mask = text_inputs.attention_mask.to(pipe.device)
        hidden_states = pipe.text_encoder(input_ids, attention_mask)
        return hidden_states, attention_mask

    def process(self, pipe, prompt, lyrics, negative_prompt=None):
        pipe.load_models_to_device(['text_encoder'])
        text_hidden_states, text_attention_mask = self._encode_text(pipe, prompt)
        # Lyrics encoding — use empty string if not provided
        lyric_text = lyrics if lyrics else ""
        lyric_hidden_states, lyric_attention_mask = self._encode_text(pipe, lyric_text)
        if text_hidden_states is not None and lyric_hidden_states is not None:
            return {
                "text_hidden_states": text_hidden_states,
                "text_attention_mask": text_attention_mask,
                "lyric_hidden_states": lyric_hidden_states,
                "lyric_attention_mask": lyric_attention_mask,
            }
        # No tokenizer → no embeddings to emit; downstream units receive nothing.
        return {}
class AceStepUnit_SilenceLatentInitializer(PipelineUnit):
    """Create the zero ("silence") source latents and chunk masks for text2music.

    Target library reference: `prepare_condition()` line 1698-1699:
        context_latents = torch.cat([src_latents, chunk_masks.to(dtype)], dim=-1)

    For text2music mode:
        - src_latents = zeros [B, T, 64] (VAE latent dimension)
        - chunk_masks = ones [B, T, 64] (full visibility mask)
        - context_latents = [B, T, 128] (built later by ContextLatentBuilder)
    """

    def __init__(self):
        super().__init__(
            input_params=("seq_len",),
            output_params=("silence_latent", "src_latents", "chunk_masks"),
        )

    def process(self, pipe, seq_len):
        latent_dim = 64  # VAE latent channel count
        # Text2music starts from silence: an all-zero latent [1, T, 64].
        silence_latent = torch.zeros(
            1, seq_len, latent_dim, device=pipe.device, dtype=pipe.torch_dtype
        )
        # In text2music mode (is_covers=0) the whole sequence is visible,
        # so the chunk mask is all ones with the same shape as src_latents
        # (matches the target library's behavior at line 1699).
        chunk_masks = torch.ones_like(silence_latent)
        return {
            "silence_latent": silence_latent,
            "src_latents": silence_latent.clone(),
            "chunk_masks": chunk_masks,
        }
class AceStepUnit_ContextLatentBuilder(PipelineUnit):
    """Concatenate src_latents and chunk_masks into context_latents.

    Target library reference: `prepare_condition()` line 1699:
        context_latents = torch.cat([src_latents, chunk_masks.to(dtype)], dim=-1)

    context_latents depends only on src_latents + chunk_masks (not on text),
    so it is identical for the positive and negative CFG paths and this unit
    runs in plain (non-CFG) mode — its outputs land in inputs_shared.
    """

    def __init__(self):
        super().__init__(
            input_params=("src_latents", "chunk_masks"),
            output_params=("context_latents", "attention_mask"),
        )

    def process(self, pipe, src_latents, chunk_masks):
        batch, length = src_latents.shape[0], src_latents.shape[1]
        # [B, T, 64] ⊕ [B, T, 64] → [B, T, 128]
        context_latents = torch.cat((src_latents, chunk_masks), dim=-1)
        # Full-ones [B, T] mask; the target library feeds this to the DiT
        # for cross-attention with context_latents.
        attention_mask = torch.ones(
            batch, length, device=pipe.device, dtype=pipe.torch_dtype
        )
        return {"context_latents": context_latents, "attention_mask": attention_mask}
class AceStepUnit_ConditionEmbedder(PipelineUnit):
    """Generate encoder_hidden_states via ACEStepConditioner.

    Target library reference: `prepare_condition()` line 1674-1681:
        encoder_hidden_states, encoder_attention_mask = self.encoder(...)

    Uses seperate_cfg mode:
        - Positive: encode with full condition (text + lyrics + timbre)
        - Negative: replace text with null_condition_emb, keep lyrics/timbre same

    context_latents is handled by ContextLatentBuilder (plain mode), not here.

    NOTE(review): `reference_audio` is accepted by process() but never
    forwarded — both conditioner calls receive a packed zero timbre tensor.
    Presumably reference-audio encoding is not wired up yet; confirm before
    relying on timbre conditioning.
    """

    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params_posi={
                "text_hidden_states": "text_hidden_states",
                "text_attention_mask": "text_attention_mask",
                "lyric_hidden_states": "lyric_hidden_states",
                "lyric_attention_mask": "lyric_attention_mask",
                "reference_audio": "reference_audio",
                "refer_audio_order_mask": "refer_audio_order_mask",
            },
            input_params_nega={},
            input_params=("cfg_scale",),
            output_params=(
                "encoder_hidden_states", "encoder_attention_mask",
                "negative_encoder_hidden_states", "negative_encoder_attention_mask",
            ),
            onload_model_names=("conditioner",)
        )

    def _default_timbre(self, pipe):
        """Packed zero timbre input for the no-reference-audio case.

        TimbreEncoder.unpack expects a packed 2D tensor [N, d], not [B, T, d],
        so a single zero row [1, 64] with order mask [0] stands in for
        "no reference audio".
        """
        packed = torch.zeros(1, 64, device=pipe.device, dtype=pipe.torch_dtype)
        order_mask = torch.LongTensor([0]).to(pipe.device)
        return packed, order_mask

    def _prepare_condition(self, pipe, text_hidden_states, text_attention_mask,
                           lyric_hidden_states, lyric_attention_mask,
                           refer_audio_acoustic_hidden_states_packed=None,
                           refer_audio_order_mask=None):
        """Call ACEStepConditioner forward to produce encoder_hidden_states."""
        pipe.load_models_to_device(['conditioner'])
        # Handle missing reference audio with the packed-zeros fallback.
        if refer_audio_acoustic_hidden_states_packed is None:
            refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask = \
                self._default_timbre(pipe)
        encoder_hidden_states, encoder_attention_mask = pipe.conditioner(
            text_hidden_states=text_hidden_states,
            text_attention_mask=text_attention_mask,
            lyric_hidden_states=lyric_hidden_states,
            lyric_attention_mask=lyric_attention_mask,
            refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed,
            refer_audio_order_mask=refer_audio_order_mask,
        )
        return encoder_hidden_states, encoder_attention_mask

    def _prepare_negative_condition(self, pipe, lyric_hidden_states, lyric_attention_mask,
                                    refer_audio_acoustic_hidden_states_packed=None,
                                    refer_audio_order_mask=None):
        """Generate negative condition using null_condition_emb.

        Returns (None, None) when the conditioner lacks null_condition_emb.
        """
        if pipe.conditioner is None or not hasattr(pipe.conditioner, 'null_condition_emb'):
            return None, None
        # Fix (consistency): mirror _prepare_condition's packed-zeros fallback.
        # Previously None was passed straight through here, so the negative
        # branch used a different (missing) timbre input than the positive
        # branch, contradicting "keep lyrics/timbre same".
        if refer_audio_acoustic_hidden_states_packed is None:
            refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask = \
                self._default_timbre(pipe)
        null_emb = pipe.conditioner.null_condition_emb  # [1, 1, hidden_size]
        bsz = 1
        if lyric_hidden_states is not None:
            bsz = lyric_hidden_states.shape[0]
        null_hidden_states = null_emb.expand(bsz, -1, -1)
        null_attn_mask = torch.ones(bsz, 1, device=pipe.device, dtype=pipe.torch_dtype)
        # For negative: use null_condition_emb as text, keep lyrics and timbre
        neg_encoder_hidden_states, neg_encoder_attention_mask = pipe.conditioner(
            text_hidden_states=null_hidden_states,
            text_attention_mask=null_attn_mask,
            lyric_hidden_states=lyric_hidden_states,
            lyric_attention_mask=lyric_attention_mask,
            refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed,
            refer_audio_order_mask=refer_audio_order_mask,
        )
        return neg_encoder_hidden_states, neg_encoder_attention_mask

    def process(self, pipe, text_hidden_states, text_attention_mask,
                lyric_hidden_states, lyric_attention_mask,
                reference_audio=None, refer_audio_order_mask=None,
                negative_prompt=None, cfg_scale=1.0):
        # Positive condition
        pos_enc_hs, pos_enc_mask = self._prepare_condition(
            pipe, text_hidden_states, text_attention_mask,
            lyric_hidden_states, lyric_attention_mask,
            None, refer_audio_order_mask,
        )
        result = {
            "encoder_hidden_states": pos_enc_hs,
            "encoder_attention_mask": pos_enc_mask,
        }
        # Negative condition: only needed when CFG is active (cfg_scale > 1.0).
        # For cfg_scale=1.0 (turbo), skip to avoid null_condition_emb dimension mismatch
        if cfg_scale > 1.0:
            neg_enc_hs, neg_enc_mask = self._prepare_negative_condition(
                pipe, lyric_hidden_states, lyric_attention_mask,
                None, refer_audio_order_mask,
            )
            if neg_enc_hs is not None:
                result["negative_encoder_hidden_states"] = neg_enc_hs
                result["negative_encoder_attention_mask"] = neg_enc_mask
        return result
class AceStepUnit_NoiseInitializer(PipelineUnit):
    """Sample the initial noise tensor for the denoising loop.

    Target library reference: `prepare_noise()` line 1781-1818:
        src_latents_shape = (bsz, context_latents.shape[1], context_latents.shape[-1] // 2)

    Noise has half the channel width of context_latents:
    [B, T, 128] context → [B, T, 64] noise.
    """

    def __init__(self):
        super().__init__(
            input_params=("seed", "seq_len", "rand_device", "context_latents"),
            output_params=("noise",),
        )

    def process(self, pipe, seed, seq_len, rand_device, context_latents):
        # Shape is derived from context_latents (matches the target library's
        # prepare_noise() at line 1796); seq_len is received only because of
        # the unit wiring and is not used directly here.
        target_shape = (
            context_latents.shape[0],
            context_latents.shape[1],
            context_latents.shape[-1] // 2,
        )
        return {
            "noise": pipe.generate_noise(
                target_shape,
                seed=seed,
                rand_device=rand_device,
                rand_torch_dtype=pipe.torch_dtype,
            )
        }
class AceStepUnit_InputAudioEmbedder(PipelineUnit):
    """Seed the denoise loop's latents.

    Text2music has no input audio, so denoising starts from pure noise and
    input_latents stays None.

    Target library reference: `generate_audio()` line 1972:
        xt = noise (when cover_noise_strength == 0)
    """

    def __init__(self):
        super().__init__(
            input_params=("noise",),
            output_params=("latents", "input_latents"),
        )

    def process(self, pipe, noise):
        # Pure-noise start for text2music.
        return {"latents": noise, "input_latents": None}
def model_fn_ace_step(
    dit: AceStepDiTModel,
    latents=None,
    timestep=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    context_latents=None,
    attention_mask=None,
    past_key_values=None,
    negative_encoder_hidden_states=None,
    negative_encoder_attention_mask=None,
    negative_context_latents=None,
    **kwargs,
):
    """Single DiT forward pass for ACE-Step.

    Target library reference: `generate_audio()` line 2009-2020:
        decoder_outputs = self.decoder(
            hidden_states=x, timestep=t_curr_tensor, timestep_r=t_curr_tensor,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            context_latents=context_latents,
            use_cache=True, past_key_values=past_key_values,
        )

    Args:
        dit: AceStepDiTModel.
        latents: [B, T, 64] noise/latent tensor (same shape as src_latents).
        timestep: scalar tensor already in [0, 1] — no /1000 scaling applied.
        encoder_hidden_states: [B, T_text, 2048] condition from the Conditioner;
            cfg_guided_model_fn supplies the positive or negative variant
            depending on the CFG pass.
        encoder_attention_mask: [B, T_text].
        context_latents: [B, T, 128] = cat([src_latents, chunk_masks], dim=-1),
            shared between both CFG paths in text2music mode.
        attention_mask: [B, T] ones mask for the DiT.
        past_key_values: EncoderDecoderCache for KV caching.
        negative_*: accepted for the CFG wrapper's calling convention; unused
            in this forward.

    The DiT internally concatenates cat([context_latents, latents], dim=-1)
    = [B, T, 192] (128 + 64 channels) as its actual input.

    Returns:
        The velocity prediction (first element of the decoder outputs).
    """
    # Collapse any singleton dims, then broadcast the scalar timestep to the
    # batch; ACE-Step consumes it directly in [0, 1].
    batch_size = latents.shape[0]
    t = timestep.squeeze().expand(batch_size)
    decoder_outputs = dit(
        hidden_states=latents,
        timestep=t,
        timestep_r=t,
        attention_mask=attention_mask,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        context_latents=context_latents,
        use_cache=True,
        past_key_values=past_key_values,
    )
    return decoder_outputs[0]