mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-24 15:06:17 +00:00
acestep t2m
This commit is contained in:
@@ -157,10 +157,10 @@ class FlowMatchScheduler():
|
|||||||
"""
|
"""
|
||||||
num_train_timesteps = 1000
|
num_train_timesteps = 1000
|
||||||
sigma_start = denoising_strength
|
sigma_start = denoising_strength
|
||||||
sigmas = torch.linspace(sigma_start, 0.0, num_inference_steps)
|
sigmas = torch.linspace(sigma_start, 0.0, num_inference_steps + 1)[:-1]
|
||||||
if shift is not None and shift != 1.0:
|
if shift is not None and shift != 1.0:
|
||||||
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
||||||
timesteps = sigmas # ACE-Step uses [0, 1] range directly
|
timesteps = sigmas * num_train_timesteps
|
||||||
return sigmas, timesteps
|
return sigmas, timesteps
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -540,17 +540,9 @@ class AceStepTimbreEncoder(nn.Module):
|
|||||||
) -> BaseModelOutput:
|
) -> BaseModelOutput:
|
||||||
inputs_embeds = refer_audio_acoustic_hidden_states_packed
|
inputs_embeds = refer_audio_acoustic_hidden_states_packed
|
||||||
inputs_embeds = self.embed_tokens(inputs_embeds)
|
inputs_embeds = self.embed_tokens(inputs_embeds)
|
||||||
# Handle 2D (packed) or 3D (batched) input
|
seq_len = inputs_embeds.shape[1]
|
||||||
is_packed = inputs_embeds.dim() == 2
|
cache_position = torch.arange(0, seq_len, device=inputs_embeds.device)
|
||||||
if is_packed:
|
position_ids = cache_position.unsqueeze(0)
|
||||||
seq_len = inputs_embeds.shape[0]
|
|
||||||
cache_position = torch.arange(0, seq_len, device=inputs_embeds.device)
|
|
||||||
position_ids = cache_position.unsqueeze(0)
|
|
||||||
inputs_embeds = inputs_embeds.unsqueeze(0)
|
|
||||||
else:
|
|
||||||
seq_len = inputs_embeds.shape[1]
|
|
||||||
cache_position = torch.arange(0, seq_len, device=inputs_embeds.device)
|
|
||||||
position_ids = cache_position.unsqueeze(0)
|
|
||||||
|
|
||||||
dtype = inputs_embeds.dtype
|
dtype = inputs_embeds.dtype
|
||||||
device = inputs_embeds.device
|
device = inputs_embeds.device
|
||||||
@@ -586,9 +578,8 @@ class AceStepTimbreEncoder(nn.Module):
|
|||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
hidden_states = self.norm(hidden_states)
|
hidden_states = self.norm(hidden_states)
|
||||||
|
hidden_states = hidden_states[:, 0, :]
|
||||||
# For packed input: reshape [1, T, D] -> [T, D] for unpacking
|
# For packed input: reshape [1, T, D] -> [T, D] for unpacking
|
||||||
if is_packed:
|
|
||||||
hidden_states = hidden_states.squeeze(0)
|
|
||||||
timbre_embs_unpack, timbre_embs_mask = self.unpack_timbre_embeddings(hidden_states, refer_audio_order_mask)
|
timbre_embs_unpack, timbre_embs_mask = self.unpack_timbre_embeddings(hidden_states, refer_audio_order_mask)
|
||||||
return timbre_embs_unpack, timbre_embs_mask
|
return timbre_embs_unpack, timbre_embs_mask
|
||||||
|
|
||||||
@@ -686,7 +677,7 @@ class AceStepConditionEncoder(nn.Module):
|
|||||||
text_attention_mask: Optional[torch.Tensor] = None,
|
text_attention_mask: Optional[torch.Tensor] = None,
|
||||||
lyric_hidden_states: Optional[torch.LongTensor] = None,
|
lyric_hidden_states: Optional[torch.LongTensor] = None,
|
||||||
lyric_attention_mask: Optional[torch.Tensor] = None,
|
lyric_attention_mask: Optional[torch.Tensor] = None,
|
||||||
refer_audio_acoustic_hidden_states_packed: Optional[torch.Tensor] = None,
|
reference_latents: Optional[torch.Tensor] = None,
|
||||||
refer_audio_order_mask: Optional[torch.LongTensor] = None,
|
refer_audio_order_mask: Optional[torch.LongTensor] = None,
|
||||||
):
|
):
|
||||||
text_hidden_states = self.text_projector(text_hidden_states)
|
text_hidden_states = self.text_projector(text_hidden_states)
|
||||||
@@ -695,11 +686,7 @@ class AceStepConditionEncoder(nn.Module):
|
|||||||
attention_mask=lyric_attention_mask,
|
attention_mask=lyric_attention_mask,
|
||||||
)
|
)
|
||||||
lyric_hidden_states = lyric_encoder_outputs.last_hidden_state
|
lyric_hidden_states = lyric_encoder_outputs.last_hidden_state
|
||||||
timbre_embs_unpack, timbre_embs_mask = self.timbre_encoder(
|
timbre_embs_unpack, timbre_embs_mask = self.timbre_encoder(reference_latents, refer_audio_order_mask)
|
||||||
refer_audio_acoustic_hidden_states_packed,
|
|
||||||
refer_audio_order_mask
|
|
||||||
)
|
|
||||||
|
|
||||||
encoder_hidden_states, encoder_attention_mask = pack_sequences(
|
encoder_hidden_states, encoder_attention_mask = pack_sequences(
|
||||||
lyric_hidden_states, timbre_embs_unpack, lyric_attention_mask, timbre_embs_mask
|
lyric_hidden_states, timbre_embs_unpack, lyric_attention_mask, timbre_embs_mask
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -165,7 +165,7 @@ class TimestepEmbedding(nn.Module):
|
|||||||
self,
|
self,
|
||||||
in_channels: int,
|
in_channels: int,
|
||||||
time_embed_dim: int,
|
time_embed_dim: int,
|
||||||
scale: float = 1000,
|
scale: float = 1,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -711,7 +711,7 @@ class AceStepDiTModel(nn.Module):
|
|||||||
encoder_hidden_states: torch.Tensor,
|
encoder_hidden_states: torch.Tensor,
|
||||||
encoder_attention_mask: torch.Tensor,
|
encoder_attention_mask: torch.Tensor,
|
||||||
context_latents: torch.Tensor,
|
context_latents: torch.Tensor,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = False,
|
||||||
past_key_values: Optional[EncoderDecoderCache] = None,
|
past_key_values: Optional[EncoderDecoderCache] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
|||||||
@@ -2,17 +2,6 @@ import torch
|
|||||||
|
|
||||||
|
|
||||||
class AceStepTextEncoder(torch.nn.Module):
|
class AceStepTextEncoder(torch.nn.Module):
|
||||||
"""
|
|
||||||
Text encoder for ACE-Step using Qwen3-Embedding-0.6B.
|
|
||||||
|
|
||||||
Converts text/lyric tokens to hidden state embeddings that are
|
|
||||||
further processed by the ACE-Step ConditionEncoder.
|
|
||||||
|
|
||||||
Wraps a Qwen3Model transformers model. Config is manually
|
|
||||||
constructed, and model weights are loaded via DiffSynth's
|
|
||||||
standard mechanism from safetensors files.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
):
|
):
|
||||||
@@ -49,8 +38,6 @@ class AceStepTextEncoder(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.model = Qwen3Model(config)
|
self.model = Qwen3Model(config)
|
||||||
self.config = config
|
|
||||||
self.hidden_size = config.hidden_size
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -58,23 +45,9 @@ class AceStepTextEncoder(torch.nn.Module):
|
|||||||
input_ids: torch.LongTensor,
|
input_ids: torch.LongTensor,
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
):
|
):
|
||||||
"""
|
|
||||||
Encode text/lyric tokens to hidden states.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_ids: [B, T] token IDs
|
|
||||||
attention_mask: [B, T] attention mask
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
last_hidden_state: [B, T, hidden_size]
|
|
||||||
"""
|
|
||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
)
|
)
|
||||||
return outputs.last_hidden_state
|
return outputs.last_hidden_state
|
||||||
|
|
||||||
def to(self, *args, **kwargs):
|
|
||||||
self.model.to(*args, **kwargs)
|
|
||||||
return self
|
|
||||||
|
|||||||
@@ -226,6 +226,7 @@ class AceStepVAE(nn.Module):
|
|||||||
upsampling_ratios=upsampling_ratios,
|
upsampling_ratios=upsampling_ratios,
|
||||||
channel_multiples=channel_multiples,
|
channel_multiples=channel_multiples,
|
||||||
)
|
)
|
||||||
|
self.sampling_rate = sampling_rate
|
||||||
|
|
||||||
def encode(self, x: torch.Tensor) -> torch.Tensor:
|
def encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
"""Audio waveform [B, audio_channels, T] → latent [B, encoder_hidden_size, T']."""
|
"""Audio waveform [B, audio_channels, T] → latent [B, encoder_hidden_size, T']."""
|
||||||
|
|||||||
@@ -3,8 +3,9 @@ ACE-Step Pipeline for DiffSynth-Studio.
|
|||||||
|
|
||||||
Text-to-Music generation pipeline using ACE-Step 1.5 model.
|
Text-to-Music generation pipeline using ACE-Step 1.5 model.
|
||||||
"""
|
"""
|
||||||
|
import re
|
||||||
import torch
|
import torch
|
||||||
from typing import Optional
|
from typing import Optional, Dict, Any, List, Tuple
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from ..core.device.npu_compatible_device import get_device_type
|
from ..core.device.npu_compatible_device import get_device_type
|
||||||
@@ -16,6 +17,7 @@ from ..models.ace_step_dit import AceStepDiTModel
|
|||||||
from ..models.ace_step_conditioner import AceStepConditionEncoder
|
from ..models.ace_step_conditioner import AceStepConditionEncoder
|
||||||
from ..models.ace_step_text_encoder import AceStepTextEncoder
|
from ..models.ace_step_text_encoder import AceStepTextEncoder
|
||||||
from ..models.ace_step_vae import AceStepVAE
|
from ..models.ace_step_vae import AceStepVAE
|
||||||
|
from ..models.ace_step_tokenizer import AceStepTokenizer
|
||||||
|
|
||||||
|
|
||||||
class AceStepPipeline(BasePipeline):
|
class AceStepPipeline(BasePipeline):
|
||||||
@@ -32,29 +34,18 @@ class AceStepPipeline(BasePipeline):
|
|||||||
self.text_encoder: AceStepTextEncoder = None
|
self.text_encoder: AceStepTextEncoder = None
|
||||||
self.conditioner: AceStepConditionEncoder = None
|
self.conditioner: AceStepConditionEncoder = None
|
||||||
self.dit: AceStepDiTModel = None
|
self.dit: AceStepDiTModel = None
|
||||||
self.vae = None # AutoencoderOobleck (diffusers) or AceStepVAE
|
self.vae: AceStepVAE = None
|
||||||
|
self.tokenizer_model: AceStepTokenizer = None # AceStepTokenizer (tokenizer + detokenizer)
|
||||||
|
|
||||||
# Unit chain order — 7 units total
|
|
||||||
#
|
|
||||||
# 1. ShapeChecker: duration → seq_len
|
|
||||||
# 2. PromptEmbedder: prompt/lyrics → text/lyric embeddings (shared for CFG)
|
|
||||||
# 3. SilenceLatentInitializer: seq_len → src_latents + chunk_masks
|
|
||||||
# 4. ContextLatentBuilder: src_latents + chunk_masks → context_latents (shared, same for CFG+)
|
|
||||||
# 5. ConditionEmbedder: text/lyric → encoder_hidden_states (separate for CFG+/-)
|
|
||||||
# 6. NoiseInitializer: context_latents → noise
|
|
||||||
# 7. InputAudioEmbedder: noise → latents
|
|
||||||
#
|
|
||||||
# ContextLatentBuilder runs before ConditionEmbedder so that
|
|
||||||
# context_latents is available for noise shape computation.
|
|
||||||
self.in_iteration_models = ("dit",)
|
self.in_iteration_models = ("dit",)
|
||||||
self.units = [
|
self.units = [
|
||||||
AceStepUnit_ShapeChecker(),
|
|
||||||
AceStepUnit_PromptEmbedder(),
|
AceStepUnit_PromptEmbedder(),
|
||||||
AceStepUnit_SilenceLatentInitializer(),
|
AceStepUnit_ReferenceAudioEmbedder(),
|
||||||
AceStepUnit_ContextLatentBuilder(),
|
|
||||||
AceStepUnit_ConditionEmbedder(),
|
AceStepUnit_ConditionEmbedder(),
|
||||||
|
AceStepUnit_ContextLatentBuilder(),
|
||||||
AceStepUnit_NoiseInitializer(),
|
AceStepUnit_NoiseInitializer(),
|
||||||
AceStepUnit_InputAudioEmbedder(),
|
AceStepUnit_InputAudioEmbedder(),
|
||||||
|
AceStepUnit_AudioCodeDecoder(),
|
||||||
]
|
]
|
||||||
self.model_fn = model_fn_ace_step
|
self.model_fn = model_fn_ace_step
|
||||||
self.compilable_models = ["dit"]
|
self.compilable_models = ["dit"]
|
||||||
@@ -66,7 +57,8 @@ class AceStepPipeline(BasePipeline):
|
|||||||
torch_dtype: torch.dtype = torch.bfloat16,
|
torch_dtype: torch.dtype = torch.bfloat16,
|
||||||
device: str = get_device_type(),
|
device: str = get_device_type(),
|
||||||
model_configs: list[ModelConfig] = [],
|
model_configs: list[ModelConfig] = [],
|
||||||
text_tokenizer_config: ModelConfig = None,
|
text_tokenizer_config: ModelConfig = ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||||
|
silence_latent_config: ModelConfig = ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
|
||||||
vram_limit: float = None,
|
vram_limit: float = None,
|
||||||
):
|
):
|
||||||
"""Load pipeline from pretrained checkpoints."""
|
"""Load pipeline from pretrained checkpoints."""
|
||||||
@@ -77,11 +69,15 @@ class AceStepPipeline(BasePipeline):
|
|||||||
pipe.conditioner = model_pool.fetch_model("ace_step_conditioner")
|
pipe.conditioner = model_pool.fetch_model("ace_step_conditioner")
|
||||||
pipe.dit = model_pool.fetch_model("ace_step_dit")
|
pipe.dit = model_pool.fetch_model("ace_step_dit")
|
||||||
pipe.vae = model_pool.fetch_model("ace_step_vae")
|
pipe.vae = model_pool.fetch_model("ace_step_vae")
|
||||||
|
pipe.tokenizer_model = model_pool.fetch_model("ace_step_tokenizer")
|
||||||
|
|
||||||
if text_tokenizer_config is not None:
|
if text_tokenizer_config is not None:
|
||||||
text_tokenizer_config.download_if_necessary()
|
text_tokenizer_config.download_if_necessary()
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
pipe.tokenizer = AutoTokenizer.from_pretrained(text_tokenizer_config.path)
|
pipe.tokenizer = AutoTokenizer.from_pretrained(text_tokenizer_config.path)
|
||||||
|
if silence_latent_config is not None:
|
||||||
|
silence_latent_config.download_if_necessary()
|
||||||
|
pipe.silence_latent = torch.load(silence_latent_config.path, weights_only=True).transpose(1, 2).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||||
|
|
||||||
# VRAM Management
|
# VRAM Management
|
||||||
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||||
@@ -97,9 +93,19 @@ class AceStepPipeline(BasePipeline):
|
|||||||
# Lyrics
|
# Lyrics
|
||||||
lyrics: str = "",
|
lyrics: str = "",
|
||||||
# Reference audio (optional, for timbre conditioning)
|
# Reference audio (optional, for timbre conditioning)
|
||||||
reference_audio = None,
|
reference_audios: List[torch.Tensor] = None,
|
||||||
|
# Src audio
|
||||||
|
src_audio: torch.Tensor = None,
|
||||||
|
denoising_strength: float = 1.0,
|
||||||
|
# Simple Mode: LLM-generated audio codes (optional)
|
||||||
|
audio_codes: str = None,
|
||||||
# Shape
|
# Shape
|
||||||
duration: float = 60.0,
|
duration: int = 60,
|
||||||
|
# Audio Meta
|
||||||
|
bpm: Optional[int] = 100,
|
||||||
|
keyscale: Optional[str] = "B minor",
|
||||||
|
timesignature: Optional[str] = "4",
|
||||||
|
vocal_language: Optional[str] = 'zh',
|
||||||
# Randomness
|
# Randomness
|
||||||
seed: int = None,
|
seed: int = None,
|
||||||
rand_device: str = "cpu",
|
rand_device: str = "cpu",
|
||||||
@@ -111,11 +117,7 @@ class AceStepPipeline(BasePipeline):
|
|||||||
progress_bar_cmd=tqdm,
|
progress_bar_cmd=tqdm,
|
||||||
):
|
):
|
||||||
# 1. Scheduler
|
# 1. Scheduler
|
||||||
self.scheduler.set_timesteps(
|
self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, denoising_strength=1.0, shift=shift)
|
||||||
num_inference_steps=num_inference_steps,
|
|
||||||
denoising_strength=1.0,
|
|
||||||
shift=shift,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. 三字典输入
|
# 2. 三字典输入
|
||||||
inputs_posi = {"prompt": prompt}
|
inputs_posi = {"prompt": prompt}
|
||||||
@@ -123,8 +125,11 @@ class AceStepPipeline(BasePipeline):
|
|||||||
inputs_shared = {
|
inputs_shared = {
|
||||||
"cfg_scale": cfg_scale,
|
"cfg_scale": cfg_scale,
|
||||||
"lyrics": lyrics,
|
"lyrics": lyrics,
|
||||||
"reference_audio": reference_audio,
|
"reference_audios": reference_audios,
|
||||||
|
"src_audio": src_audio,
|
||||||
|
"audio_codes": audio_codes,
|
||||||
"duration": duration,
|
"duration": duration,
|
||||||
|
"bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "vocal_language": vocal_language,
|
||||||
"seed": seed,
|
"seed": seed,
|
||||||
"rand_device": rand_device,
|
"rand_device": rand_device,
|
||||||
"num_inference_steps": num_inference_steps,
|
"num_inference_steps": num_inference_steps,
|
||||||
@@ -159,6 +164,10 @@ class AceStepPipeline(BasePipeline):
|
|||||||
# VAE returns OobleckDecoderOutput with .sample attribute
|
# VAE returns OobleckDecoderOutput with .sample attribute
|
||||||
audio_output = vae_output.sample if hasattr(vae_output, 'sample') else vae_output
|
audio_output = vae_output.sample if hasattr(vae_output, 'sample') else vae_output
|
||||||
audio = self.output_audio_format_check(audio_output)
|
audio = self.output_audio_format_check(audio_output)
|
||||||
|
|
||||||
|
# Peak normalization to match target library behavior
|
||||||
|
audio = self.normalize_audio(audio, target_db=-1.0)
|
||||||
|
|
||||||
self.load_models_to_device([])
|
self.load_models_to_device([])
|
||||||
return audio
|
return audio
|
||||||
|
|
||||||
@@ -172,294 +181,303 @@ class AceStepPipeline(BasePipeline):
|
|||||||
audio_output = audio_output.squeeze(0)
|
audio_output = audio_output.squeeze(0)
|
||||||
return audio_output.float()
|
return audio_output.float()
|
||||||
|
|
||||||
|
def normalize_audio(self, audio: torch.Tensor, target_db: float = -1.0) -> torch.Tensor:
|
||||||
|
"""Apply peak normalization to audio data, matching target library behavior.
|
||||||
|
|
||||||
class AceStepUnit_ShapeChecker(PipelineUnit):
|
Target library reference: `acestep/audio_utils.py:normalize_audio()`
|
||||||
"""Check and compute sequence length from duration."""
|
peak = max(abs(audio))
|
||||||
def __init__(self):
|
gain = 10^(target_db/20) / peak
|
||||||
super().__init__(
|
audio = audio * gain
|
||||||
input_params=("duration",),
|
|
||||||
output_params=("duration", "seq_len"),
|
|
||||||
)
|
|
||||||
|
|
||||||
def process(self, pipe, duration):
|
Args:
|
||||||
# ACE-Step: 25 Hz latent rate
|
audio: Audio tensor [C, T]
|
||||||
seq_len = int(duration * 25)
|
target_db: Target peak level in dB (default: -1.0)
|
||||||
return {"duration": duration, "seq_len": seq_len}
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Normalized audio tensor
|
||||||
|
"""
|
||||||
|
peak = torch.max(torch.abs(audio))
|
||||||
|
if peak < 1e-6:
|
||||||
|
return audio
|
||||||
|
target_amp = 10 ** (target_db / 20.0)
|
||||||
|
gain = target_amp / peak
|
||||||
|
return audio * gain
|
||||||
|
|
||||||
class AceStepUnit_PromptEmbedder(PipelineUnit):
|
class AceStepUnit_PromptEmbedder(PipelineUnit):
|
||||||
"""Encode prompt and lyrics using Qwen3-Embedding.
|
SFT_GEN_PROMPT = "# Instruction\n{}\n\n# Caption\n{}\n\n# Metas\n{}<|endoftext|>\n"
|
||||||
|
INSTRUCTION = "Fill the audio semantic mask based on the given conditions:"
|
||||||
|
LYRIC_PROMPT = "# Languages\n{}\n\n# Lyric\n{}<|endoftext|>"
|
||||||
|
|
||||||
Uses seperate_cfg=True to read prompt from inputs_posi (not inputs_shared).
|
|
||||||
The negative condition uses null_condition_emb (handled by ConditionEmbedder),
|
|
||||||
so negative text encoding is not needed here.
|
|
||||||
"""
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
seperate_cfg=True,
|
seperate_cfg=True,
|
||||||
input_params_posi={"prompt": "prompt"},
|
input_params_posi={"prompt": "prompt"},
|
||||||
input_params_nega={},
|
input_params_nega={"prompt": "prompt"},
|
||||||
input_params=("lyrics",),
|
input_params=("lyrics", "duration", "bpm", "keyscale", "timesignature", "vocal_language"),
|
||||||
output_params=("text_hidden_states", "text_attention_mask", "lyric_hidden_states", "lyric_attention_mask"),
|
output_params=("text_hidden_states", "text_attention_mask", "lyric_hidden_states", "lyric_attention_mask"),
|
||||||
onload_model_names=("text_encoder",)
|
onload_model_names=("text_encoder",)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _encode_text(self, pipe, text):
|
def _encode_text(self, pipe, text, max_length=256):
|
||||||
"""Encode text using Qwen3-Embedding → [B, T, 1024]."""
|
"""Encode text using Qwen3-Embedding → [B, T, 1024]."""
|
||||||
if pipe.tokenizer is None:
|
|
||||||
return None, None
|
|
||||||
text_inputs = pipe.tokenizer(
|
text_inputs = pipe.tokenizer(
|
||||||
text,
|
text,
|
||||||
padding="max_length",
|
max_length=max_length,
|
||||||
max_length=512,
|
|
||||||
truncation=True,
|
truncation=True,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
)
|
)
|
||||||
input_ids = text_inputs.input_ids.to(pipe.device)
|
input_ids = text_inputs.input_ids.to(pipe.device)
|
||||||
attention_mask = text_inputs.attention_mask.to(pipe.device)
|
attention_mask = text_inputs.attention_mask.bool().to(pipe.device)
|
||||||
hidden_states = pipe.text_encoder(input_ids, attention_mask)
|
hidden_states = pipe.text_encoder(input_ids, attention_mask)
|
||||||
return hidden_states, attention_mask
|
return hidden_states, attention_mask
|
||||||
|
|
||||||
def process(self, pipe, prompt, lyrics, negative_prompt=None):
|
def _encode_lyrics(self, pipe, lyric_text, max_length=2048):
|
||||||
|
text_inputs = pipe.tokenizer(
|
||||||
|
lyric_text,
|
||||||
|
max_length=max_length,
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
input_ids = text_inputs.input_ids.to(pipe.device)
|
||||||
|
attention_mask = text_inputs.attention_mask.bool().to(pipe.device)
|
||||||
|
hidden_states = pipe.text_encoder.model.embed_tokens(input_ids)
|
||||||
|
return hidden_states, attention_mask
|
||||||
|
|
||||||
|
def _dict_to_meta_string(self, meta_dict: Dict[str, Any]) -> str:
|
||||||
|
bpm = meta_dict.get("bpm", "N/A")
|
||||||
|
timesignature = meta_dict.get("timesignature", "N/A")
|
||||||
|
keyscale = meta_dict.get("keyscale", "N/A")
|
||||||
|
duration = meta_dict.get("duration", 30)
|
||||||
|
duration = f"{int(duration)} seconds"
|
||||||
|
return (
|
||||||
|
f"- bpm: {bpm}\n"
|
||||||
|
f"- timesignature: {timesignature}\n"
|
||||||
|
f"- keyscale: {keyscale}\n"
|
||||||
|
f"- duration: {duration}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, pipe, prompt, lyrics, duration, bpm, keyscale, timesignature, vocal_language):
|
||||||
pipe.load_models_to_device(['text_encoder'])
|
pipe.load_models_to_device(['text_encoder'])
|
||||||
|
meta_dict = {"bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "duration": duration}
|
||||||
|
prompt = self.SFT_GEN_PROMPT.format(self.INSTRUCTION, prompt, self._dict_to_meta_string(meta_dict))
|
||||||
|
text_hidden_states, text_attention_mask = self._encode_text(pipe, prompt, max_length=256)
|
||||||
|
|
||||||
text_hidden_states, text_attention_mask = self._encode_text(pipe, prompt)
|
lyric_text = self.LYRIC_PROMPT.format(vocal_language, lyrics)
|
||||||
|
lyric_hidden_states, lyric_attention_mask = self._encode_lyrics(pipe, lyric_text, max_length=2048)
|
||||||
|
|
||||||
# Lyrics encoding — use empty string if not provided
|
# TODO: remove this
|
||||||
lyric_text = lyrics if lyrics else ""
|
newtext = prompt + "\n\n" + lyric_text
|
||||||
lyric_hidden_states, lyric_attention_mask = self._encode_text(pipe, lyric_text)
|
return {
|
||||||
|
"text_hidden_states": text_hidden_states,
|
||||||
if text_hidden_states is not None and lyric_hidden_states is not None:
|
"text_attention_mask": text_attention_mask,
|
||||||
return {
|
"lyric_hidden_states": lyric_hidden_states,
|
||||||
"text_hidden_states": text_hidden_states,
|
"lyric_attention_mask": lyric_attention_mask,
|
||||||
"text_attention_mask": text_attention_mask,
|
}
|
||||||
"lyric_hidden_states": lyric_hidden_states,
|
|
||||||
"lyric_attention_mask": lyric_attention_mask,
|
|
||||||
}
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
class AceStepUnit_SilenceLatentInitializer(PipelineUnit):
|
class AceStepUnit_ReferenceAudioEmbedder(PipelineUnit):
|
||||||
"""Generate silence latent (all zeros) and chunk_masks for text2music.
|
|
||||||
|
|
||||||
Target library reference: `prepare_condition()` line 1698-1699:
|
|
||||||
context_latents = torch.cat([src_latents, chunk_masks.to(dtype)], dim=-1)
|
|
||||||
|
|
||||||
For text2music mode:
|
|
||||||
- src_latents = zeros [B, T, 64] (VAE latent dimension)
|
|
||||||
- chunk_masks = ones [B, T, 64] (full visibility mask for text2music)
|
|
||||||
- context_latents = [B, T, 128] (concat of src_latents + chunk_masks)
|
|
||||||
"""
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
input_params=("seq_len",),
|
input_params=("reference_audios",),
|
||||||
output_params=("silence_latent", "src_latents", "chunk_masks"),
|
output_params=("reference_latents", "refer_audio_order_mask"),
|
||||||
|
onload_model_names=("vae",)
|
||||||
)
|
)
|
||||||
|
|
||||||
def process(self, pipe, seq_len):
|
def process(self, pipe, reference_audios):
|
||||||
# silence_latent shape: [B, T, 64] — 64 is the VAE latent dimension
|
pipe.load_models_to_device(['vae'])
|
||||||
silence_latent = torch.zeros(1, seq_len, 64, device=pipe.device, dtype=pipe.torch_dtype)
|
if reference_audios is not None and len(reference_audios) > 0:
|
||||||
# For text2music: src_latents = silence_latent
|
# TODO: implement reference audio embedding using VAE encode, and generate refer_audio_order_mask
|
||||||
src_latents = silence_latent.clone()
|
pass
|
||||||
|
else:
|
||||||
|
reference_audios = [[torch.zeros(2, 30 * pipe.vae.sampling_rate).to(dtype=pipe.torch_dtype, device=pipe.device)]]
|
||||||
|
reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, reference_audios)
|
||||||
|
return {"reference_latents": reference_latents, "refer_audio_order_mask": refer_audio_order_mask}
|
||||||
|
|
||||||
# chunk_masks: [B, T, 64] of ones (same shape as src_latents)
|
def infer_refer_latent(self, pipe, refer_audioss: List[List[torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
# In text2music mode (is_covers=0), chunk_masks are all 1.0
|
"""Infer packed reference-audio latents and order mask."""
|
||||||
# This matches the target library's behavior at line 1699
|
refer_audio_order_mask = []
|
||||||
chunk_masks = torch.ones(1, seq_len, 64, device=pipe.device, dtype=pipe.torch_dtype)
|
refer_audio_latents = []
|
||||||
|
|
||||||
return {"silence_latent": silence_latent, "src_latents": src_latents, "chunk_masks": chunk_masks}
|
def _normalize_audio_2d(a: torch.Tensor) -> torch.Tensor:
|
||||||
|
if not isinstance(a, torch.Tensor):
|
||||||
|
raise TypeError(f"refer_audio must be a torch.Tensor, got {type(a)!r}")
|
||||||
|
if a.dim() == 3 and a.shape[0] == 1:
|
||||||
|
a = a.squeeze(0)
|
||||||
|
if a.dim() == 1:
|
||||||
|
a = a.unsqueeze(0)
|
||||||
|
if a.dim() != 2:
|
||||||
|
raise ValueError(f"refer_audio must be 1D/2D/3D(1,2,T); got shape={tuple(a.shape)}")
|
||||||
|
if a.shape[0] == 1:
|
||||||
|
a = torch.cat([a, a], dim=0)
|
||||||
|
return a[:2]
|
||||||
|
|
||||||
|
def _ensure_latent_3d(z: torch.Tensor) -> torch.Tensor:
|
||||||
|
if z.dim() == 4 and z.shape[0] == 1:
|
||||||
|
z = z.squeeze(0)
|
||||||
|
if z.dim() == 2:
|
||||||
|
z = z.unsqueeze(0)
|
||||||
|
return z
|
||||||
|
|
||||||
class AceStepUnit_ContextLatentBuilder(PipelineUnit):
|
refer_encode_cache: Dict[int, torch.Tensor] = {}
|
||||||
"""Build context_latents from src_latents and chunk_masks.
|
for batch_idx, refer_audios in enumerate(refer_audioss):
|
||||||
|
if len(refer_audios) == 1 and torch.all(refer_audios[0] == 0.0):
|
||||||
|
refer_audio_latent = _ensure_latent_3d(pipe.silence_latent[:, :750, :])
|
||||||
|
refer_audio_latents.append(refer_audio_latent)
|
||||||
|
refer_audio_order_mask.append(batch_idx)
|
||||||
|
else:
|
||||||
|
# TODO: check
|
||||||
|
for refer_audio in refer_audios:
|
||||||
|
cache_key = refer_audio.data_ptr()
|
||||||
|
if cache_key in refer_encode_cache:
|
||||||
|
refer_audio_latent = refer_encode_cache[cache_key].clone()
|
||||||
|
else:
|
||||||
|
refer_audio = _normalize_audio_2d(refer_audio)
|
||||||
|
refer_audio_latent = pipe.vae.encode(refer_audio)
|
||||||
|
refer_audio_latent = refer_audio_latent.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||||
|
if refer_audio_latent.dim() == 2:
|
||||||
|
refer_audio_latent = refer_audio_latent.unsqueeze(0)
|
||||||
|
refer_audio_latent = _ensure_latent_3d(refer_audio_latent.transpose(1, 2))
|
||||||
|
refer_encode_cache[cache_key] = refer_audio_latent
|
||||||
|
refer_audio_latents.append(refer_audio_latent)
|
||||||
|
refer_audio_order_mask.append(batch_idx)
|
||||||
|
|
||||||
Target library reference: `prepare_condition()` line 1699:
|
refer_audio_latents = torch.cat(refer_audio_latents, dim=0)
|
||||||
context_latents = torch.cat([src_latents, chunk_masks.to(dtype)], dim=-1)
|
refer_audio_order_mask = torch.tensor(refer_audio_order_mask, device=pipe.device, dtype=torch.long)
|
||||||
|
return refer_audio_latents, refer_audio_order_mask
|
||||||
context_latents is the SAME for positive and negative CFG paths
|
|
||||||
(it comes from src_latents + chunk_masks, not from text encoding).
|
|
||||||
So this is a普通模式 Unit — outputs go to inputs_shared.
|
|
||||||
"""
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
input_params=("src_latents", "chunk_masks"),
|
|
||||||
output_params=("context_latents", "attention_mask"),
|
|
||||||
)
|
|
||||||
|
|
||||||
def process(self, pipe, src_latents, chunk_masks):
|
|
||||||
# context_latents: cat([src_latents, chunk_masks], dim=-1) → [B, T, 128]
|
|
||||||
context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
|
|
||||||
|
|
||||||
# attention_mask for the DiT: ones [B, T]
|
|
||||||
# The target library uses this for cross-attention with context_latents
|
|
||||||
attention_mask = torch.ones(src_latents.shape[0], src_latents.shape[1],
|
|
||||||
device=pipe.device, dtype=pipe.torch_dtype)
|
|
||||||
|
|
||||||
return {"context_latents": context_latents, "attention_mask": attention_mask}
|
|
||||||
|
|
||||||
|
|
||||||
class AceStepUnit_ConditionEmbedder(PipelineUnit):
|
class AceStepUnit_ConditionEmbedder(PipelineUnit):
|
||||||
"""Generate encoder_hidden_states via ACEStepConditioner.
|
|
||||||
|
|
||||||
Target library reference: `prepare_condition()` line 1674-1681:
|
|
||||||
encoder_hidden_states, encoder_attention_mask = self.encoder(...)
|
|
||||||
|
|
||||||
Uses seperate_cfg mode:
|
|
||||||
- Positive: encode with full condition (text + lyrics + reference audio)
|
|
||||||
- Negative: replace text with null_condition_emb, keep lyrics/timbre same
|
|
||||||
|
|
||||||
context_latents is handled by ContextLatentBuilder (普通模式), not here.
|
|
||||||
"""
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
seperate_cfg=True,
|
seperate_cfg=True,
|
||||||
input_params_posi={
|
input_params_posi={"text_hidden_states": "text_hidden_states", "text_attention_mask": "text_attention_mask", "lyric_hidden_states": "lyric_hidden_states", "lyric_attention_mask": "lyric_attention_mask"},
|
||||||
"text_hidden_states": "text_hidden_states",
|
input_params_nega={"text_hidden_states": "text_hidden_states", "text_attention_mask": "text_attention_mask", "lyric_hidden_states": "lyric_hidden_states", "lyric_attention_mask": "lyric_attention_mask"},
|
||||||
"text_attention_mask": "text_attention_mask",
|
input_params=("reference_latents", "refer_audio_order_mask"),
|
||||||
"lyric_hidden_states": "lyric_hidden_states",
|
output_params=("encoder_hidden_states", "encoder_attention_mask"),
|
||||||
"lyric_attention_mask": "lyric_attention_mask",
|
onload_model_names=("conditioner",),
|
||||||
"reference_audio": "reference_audio",
|
|
||||||
"refer_audio_order_mask": "refer_audio_order_mask",
|
|
||||||
},
|
|
||||||
input_params_nega={},
|
|
||||||
input_params=("cfg_scale",),
|
|
||||||
output_params=(
|
|
||||||
"encoder_hidden_states", "encoder_attention_mask",
|
|
||||||
"negative_encoder_hidden_states", "negative_encoder_attention_mask",
|
|
||||||
),
|
|
||||||
onload_model_names=("conditioner",)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _prepare_condition(self, pipe, text_hidden_states, text_attention_mask,
|
def process(self, pipe, text_hidden_states, text_attention_mask, lyric_hidden_states, lyric_attention_mask, reference_latents, refer_audio_order_mask):
|
||||||
lyric_hidden_states, lyric_attention_mask,
|
|
||||||
refer_audio_acoustic_hidden_states_packed=None,
|
|
||||||
refer_audio_order_mask=None):
|
|
||||||
"""Call ACEStepConditioner forward to produce encoder_hidden_states."""
|
|
||||||
pipe.load_models_to_device(['conditioner'])
|
pipe.load_models_to_device(['conditioner'])
|
||||||
|
|
||||||
# Handle reference audio
|
|
||||||
if refer_audio_acoustic_hidden_states_packed is None:
|
|
||||||
# No reference audio: create 2D packed zeros [N=1, d=64]
|
|
||||||
# TimbreEncoder.unpack expects [N, d], not [B, T, d]
|
|
||||||
refer_audio_acoustic_hidden_states_packed = torch.zeros(
|
|
||||||
1, 64, device=pipe.device, dtype=pipe.torch_dtype
|
|
||||||
)
|
|
||||||
refer_audio_order_mask = torch.LongTensor([0]).to(pipe.device)
|
|
||||||
|
|
||||||
encoder_hidden_states, encoder_attention_mask = pipe.conditioner(
|
encoder_hidden_states, encoder_attention_mask = pipe.conditioner(
|
||||||
text_hidden_states=text_hidden_states,
|
text_hidden_states=text_hidden_states,
|
||||||
text_attention_mask=text_attention_mask,
|
text_attention_mask=text_attention_mask,
|
||||||
lyric_hidden_states=lyric_hidden_states,
|
lyric_hidden_states=lyric_hidden_states,
|
||||||
lyric_attention_mask=lyric_attention_mask,
|
lyric_attention_mask=lyric_attention_mask,
|
||||||
refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed,
|
reference_latents=reference_latents,
|
||||||
refer_audio_order_mask=refer_audio_order_mask,
|
refer_audio_order_mask=refer_audio_order_mask,
|
||||||
)
|
)
|
||||||
|
return {"encoder_hidden_states": encoder_hidden_states, "encoder_attention_mask": encoder_attention_mask}
|
||||||
|
|
||||||
return encoder_hidden_states, encoder_attention_mask
|
|
||||||
|
|
||||||
def _prepare_negative_condition(self, pipe, lyric_hidden_states, lyric_attention_mask,
|
class AceStepUnit_ContextLatentBuilder(PipelineUnit):
|
||||||
refer_audio_acoustic_hidden_states_packed=None,
|
def __init__(self):
|
||||||
refer_audio_order_mask=None):
|
super().__init__(
|
||||||
"""Generate negative condition using null_condition_emb."""
|
input_params=("duration", "src_audio"),
|
||||||
if pipe.conditioner is None or not hasattr(pipe.conditioner, 'null_condition_emb'):
|
output_params=("context_latents", "src_latents", "chunk_masks", "attention_mask"),
|
||||||
return None, None
|
|
||||||
|
|
||||||
null_emb = pipe.conditioner.null_condition_emb # [1, 1, hidden_size]
|
|
||||||
bsz = 1
|
|
||||||
if lyric_hidden_states is not None:
|
|
||||||
bsz = lyric_hidden_states.shape[0]
|
|
||||||
null_hidden_states = null_emb.expand(bsz, -1, -1)
|
|
||||||
null_attn_mask = torch.ones(bsz, 1, device=pipe.device, dtype=pipe.torch_dtype)
|
|
||||||
|
|
||||||
# For negative: use null_condition_emb as text, keep lyrics and timbre
|
|
||||||
neg_encoder_hidden_states, neg_encoder_attention_mask = pipe.conditioner(
|
|
||||||
text_hidden_states=null_hidden_states,
|
|
||||||
text_attention_mask=null_attn_mask,
|
|
||||||
lyric_hidden_states=lyric_hidden_states,
|
|
||||||
lyric_attention_mask=lyric_attention_mask,
|
|
||||||
refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed,
|
|
||||||
refer_audio_order_mask=refer_audio_order_mask,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return neg_encoder_hidden_states, neg_encoder_attention_mask
|
def _get_silence_latent_slice(self, pipe, length: int) -> torch.Tensor:
|
||||||
|
available = pipe.silence_latent.shape[1]
|
||||||
|
if length <= available:
|
||||||
|
return pipe.silence_latent[0, :length, :]
|
||||||
|
repeats = (length + available - 1) // available
|
||||||
|
tiled = pipe.silence_latent[0].repeat(repeats, 1)
|
||||||
|
return tiled[:length, :]
|
||||||
|
|
||||||
def process(self, pipe, text_hidden_states, text_attention_mask,
|
def process(self, pipe, duration, src_audio):
|
||||||
lyric_hidden_states, lyric_attention_mask,
|
if src_audio is not None:
|
||||||
reference_audio=None, refer_audio_order_mask=None,
|
raise NotImplementedError("Src audio conditioning is not implemented yet. Please set src_audio to None.")
|
||||||
negative_prompt=None, cfg_scale=1.0):
|
else:
|
||||||
|
max_latent_length = duration * pipe.sample_rate // 1920
|
||||||
# Positive condition
|
src_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
|
||||||
pos_enc_hs, pos_enc_mask = self._prepare_condition(
|
chunk_masks = torch.ones((1, max_latent_length, src_latents.shape[-1]), dtype=torch.bool, device=pipe.device)
|
||||||
pipe, text_hidden_states, text_attention_mask,
|
attention_mask = torch.ones((1, max_latent_length), device=src_latents.device, dtype=pipe.torch_dtype)
|
||||||
lyric_hidden_states, lyric_attention_mask,
|
context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
|
||||||
None, refer_audio_order_mask,
|
return {"context_latents": context_latents, "attention_mask": attention_mask}
|
||||||
)
|
|
||||||
|
|
||||||
# Negative condition: only needed when CFG is active (cfg_scale > 1.0)
|
|
||||||
# For cfg_scale=1.0 (turbo), skip to avoid null_condition_emb dimension mismatch
|
|
||||||
result = {
|
|
||||||
"encoder_hidden_states": pos_enc_hs,
|
|
||||||
"encoder_attention_mask": pos_enc_mask,
|
|
||||||
}
|
|
||||||
|
|
||||||
if cfg_scale > 1.0:
|
|
||||||
neg_enc_hs, neg_enc_mask = self._prepare_negative_condition(
|
|
||||||
pipe, lyric_hidden_states, lyric_attention_mask,
|
|
||||||
None, refer_audio_order_mask,
|
|
||||||
)
|
|
||||||
if neg_enc_hs is not None:
|
|
||||||
result["negative_encoder_hidden_states"] = neg_enc_hs
|
|
||||||
result["negative_encoder_attention_mask"] = neg_enc_mask
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class AceStepUnit_NoiseInitializer(PipelineUnit):
|
class AceStepUnit_NoiseInitializer(PipelineUnit):
|
||||||
"""Generate initial noise tensor.
|
|
||||||
|
|
||||||
Target library reference: `prepare_noise()` line 1781-1818:
|
|
||||||
src_latents_shape = (bsz, context_latents.shape[1], context_latents.shape[-1] // 2)
|
|
||||||
|
|
||||||
Noise shape = [B, T, context_latents.shape[-1] // 2] = [B, T, 128 // 2] = [B, T, 64]
|
|
||||||
"""
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
input_params=("seed", "seq_len", "rand_device", "context_latents"),
|
input_params=("context_latents", "seed", "rand_device"),
|
||||||
output_params=("noise",),
|
output_params=("noise",),
|
||||||
)
|
)
|
||||||
|
|
||||||
def process(self, pipe, seed, seq_len, rand_device, context_latents):
|
def process(self, pipe, context_latents, seed, rand_device):
|
||||||
# Noise shape: [B, T, context_latents.shape[-1] // 2]
|
src_latents_shape = (context_latents.shape[0], context_latents.shape[1], context_latents.shape[-1] // 2)
|
||||||
# context_latents = [B, T, 128] → noise = [B, T, 64]
|
noise = pipe.generate_noise(src_latents_shape, seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
|
||||||
# This matches the target library's prepare_noise() at line 1796
|
|
||||||
noise_shape = (context_latents.shape[0], context_latents.shape[1],
|
|
||||||
context_latents.shape[-1] // 2)
|
|
||||||
noise = pipe.generate_noise(
|
|
||||||
noise_shape,
|
|
||||||
seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype
|
|
||||||
)
|
|
||||||
return {"noise": noise}
|
return {"noise": noise}
|
||||||
|
|
||||||
|
|
||||||
class AceStepUnit_InputAudioEmbedder(PipelineUnit):
|
class AceStepUnit_InputAudioEmbedder(PipelineUnit):
|
||||||
"""Set up latents for denoise loop.
|
|
||||||
|
|
||||||
For text2music (no input audio): latents = noise, input_latents = None.
|
|
||||||
|
|
||||||
Target library reference: `generate_audio()` line 1972:
|
|
||||||
xt = noise (when cover_noise_strength == 0)
|
|
||||||
"""
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
input_params=("noise",),
|
input_params=("noise", "input_audio"),
|
||||||
output_params=("latents", "input_latents"),
|
output_params=("latents", "input_latents"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def process(self, pipe, noise):
|
def process(self, pipe, noise, input_audio):
|
||||||
# For text2music: start from pure noise
|
if input_audio is None:
|
||||||
|
return {"latents": noise}
|
||||||
|
# TODO: support for train
|
||||||
return {"latents": noise, "input_latents": None}
|
return {"latents": noise, "input_latents": None}
|
||||||
|
|
||||||
|
|
||||||
|
class AceStepUnit_AudioCodeDecoder(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
input_params=("audio_codes", "seq_len", "silence_latent"),
|
||||||
|
output_params=("lm_hints_25Hz",),
|
||||||
|
onload_model_names=("tokenizer_model",),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_audio_code_string(code_str: str) -> list:
|
||||||
|
"""Extract integer audio codes from tokens like <|audio_code_123|>."""
|
||||||
|
if not code_str:
|
||||||
|
return []
|
||||||
|
codes = []
|
||||||
|
max_audio_code = 63999
|
||||||
|
for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str):
|
||||||
|
code_value = int(x)
|
||||||
|
codes.append(max(0, min(code_value, max_audio_code)))
|
||||||
|
return codes
|
||||||
|
|
||||||
|
def process(self, pipe, audio_codes, seq_len, silence_latent):
|
||||||
|
if audio_codes is None or not audio_codes.strip():
|
||||||
|
return {"lm_hints_25Hz": None}
|
||||||
|
|
||||||
|
code_ids = self._parse_audio_code_string(audio_codes)
|
||||||
|
if len(code_ids) == 0:
|
||||||
|
return {"lm_hints_25Hz": None}
|
||||||
|
|
||||||
|
pipe.load_models_to_device(["tokenizer_model"])
|
||||||
|
|
||||||
|
quantizer = pipe.tokenizer_model.tokenizer.quantizer
|
||||||
|
detokenizer = pipe.tokenizer_model.detokenizer
|
||||||
|
|
||||||
|
indices = torch.tensor(code_ids, device=pipe.device, dtype=torch.long)
|
||||||
|
indices = indices.unsqueeze(0).unsqueeze(-1) # [1, N, 1]
|
||||||
|
|
||||||
|
quantized = quantizer.get_output_from_indices(indices) # [1, N, 2048]
|
||||||
|
if quantized.dtype != pipe.torch_dtype:
|
||||||
|
quantized = quantized.to(pipe.torch_dtype)
|
||||||
|
|
||||||
|
lm_hints = detokenizer(quantized) # [1, N*5, 64]
|
||||||
|
|
||||||
|
# Pad or truncate to seq_len
|
||||||
|
current_len = lm_hints.shape[1]
|
||||||
|
if current_len < seq_len:
|
||||||
|
pad_len = seq_len - current_len
|
||||||
|
pad = silence_latent[:, :pad_len, :]
|
||||||
|
lm_hints = torch.cat([lm_hints, pad], dim=1)
|
||||||
|
elif current_len > seq_len:
|
||||||
|
lm_hints = lm_hints[:, :seq_len, :]
|
||||||
|
|
||||||
|
return {"lm_hints_25Hz": lm_hints}
|
||||||
|
|
||||||
|
|
||||||
def model_fn_ace_step(
|
def model_fn_ace_step(
|
||||||
dit: AceStepDiTModel,
|
dit: AceStepDiTModel,
|
||||||
latents=None,
|
latents=None,
|
||||||
@@ -468,49 +486,11 @@ def model_fn_ace_step(
|
|||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
context_latents=None,
|
context_latents=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
past_key_values=None,
|
use_gradient_checkpointing=False,
|
||||||
negative_encoder_hidden_states=None,
|
use_gradient_checkpointing_offload=False,
|
||||||
negative_encoder_attention_mask=None,
|
|
||||||
negative_context_latents=None,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Model function for ACE-Step DiT forward.
|
timestep = timestep.unsqueeze(0)
|
||||||
|
|
||||||
Timestep is already in [0, 1] range — no scaling needed.
|
|
||||||
|
|
||||||
Target library reference: `generate_audio()` line 2009-2020:
|
|
||||||
decoder_outputs = self.decoder(
|
|
||||||
hidden_states=x, timestep=t_curr_tensor, timestep_r=t_curr_tensor,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
|
||||||
context_latents=context_latents,
|
|
||||||
use_cache=True, past_key_values=past_key_values,
|
|
||||||
)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dit: AceStepDiTModel
|
|
||||||
latents: [B, T, 64] noise/latent tensor (same shape as src_latents)
|
|
||||||
timestep: scalar tensor in [0, 1]
|
|
||||||
encoder_hidden_states: [B, T_text, 2048] condition from Conditioner
|
|
||||||
(positive or negative depending on CFG pass — the cfg_guided_model_fn
|
|
||||||
passes inputs_posi for positive, inputs_nega for negative)
|
|
||||||
encoder_attention_mask: [B, T_text]
|
|
||||||
context_latents: [B, T, 128] = cat([src_latents, chunk_masks], dim=-1)
|
|
||||||
(same for both CFG+/- paths in text2music mode)
|
|
||||||
attention_mask: [B, T] ones mask for DiT
|
|
||||||
past_key_values: EncoderDecoderCache for KV caching
|
|
||||||
|
|
||||||
The DiT internally concatenates: cat([context_latents, latents], dim=-1) = [B, T, 192]
|
|
||||||
as the actual input (128 + 64 = 192 channels).
|
|
||||||
"""
|
|
||||||
# ACE-Step uses timestep directly in [0, 1] range — no /1000 scaling
|
|
||||||
timestep = timestep.squeeze()
|
|
||||||
|
|
||||||
# Expand timestep to match batch size
|
|
||||||
bsz = latents.shape[0]
|
|
||||||
timestep = timestep.expand(bsz)
|
|
||||||
|
|
||||||
decoder_outputs = dit(
|
decoder_outputs = dit(
|
||||||
hidden_states=latents,
|
hidden_states=latents,
|
||||||
timestep=timestep,
|
timestep=timestep,
|
||||||
@@ -519,9 +499,5 @@ def model_fn_ace_step(
|
|||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
context_latents=context_latents,
|
context_latents=context_latents,
|
||||||
use_cache=True,
|
)[0]
|
||||||
past_key_values=past_key_values,
|
return decoder_outputs
|
||||||
)
|
|
||||||
|
|
||||||
# Return velocity prediction (first element of decoder_outputs)
|
|
||||||
return decoder_outputs[0]
|
|
||||||
|
|||||||
@@ -99,6 +99,7 @@ def save_audio(waveform: torch.Tensor, sample_rate: int, save_path: str, backend
|
|||||||
"""
|
"""
|
||||||
if waveform.dim() == 3:
|
if waveform.dim() == 3:
|
||||||
waveform = waveform[0]
|
waveform = waveform[0]
|
||||||
|
waveform.cpu()
|
||||||
|
|
||||||
if backend == "torchcodec":
|
if backend == "torchcodec":
|
||||||
from torchcodec.encoders import AudioEncoder
|
from torchcodec.encoders import AudioEncoder
|
||||||
|
|||||||
@@ -2,8 +2,8 @@
|
|||||||
Ace-Step 1.5 — Text-to-Music with Simple Mode (LLM expansion).
|
Ace-Step 1.5 — Text-to-Music with Simple Mode (LLM expansion).
|
||||||
|
|
||||||
Uses the ACE-Step LLM to expand a simple description into structured
|
Uses the ACE-Step LLM to expand a simple description into structured
|
||||||
parameters (caption, lyrics, bpm, keyscale, etc.), then feeds them
|
parameters (caption, lyrics, bpm, keyscale, etc.) AND audio codes,
|
||||||
to the DiffSynth Pipeline.
|
then feeds them to the DiffSynth Pipeline.
|
||||||
|
|
||||||
The LLM expansion uses the target library's LLMHandler. If vLLM is
|
The LLM expansion uses the target library's LLMHandler. If vLLM is
|
||||||
not available, it falls back to using pre-structured parameters.
|
not available, it falls back to using pre-structured parameters.
|
||||||
@@ -47,11 +47,14 @@ def try_load_llm_handler(checkpoint_dir: str, lm_model_path: str = "acestep-5Hz-
|
|||||||
|
|
||||||
|
|
||||||
def expand_with_llm(llm_handler, description: str, duration: float = 30.0):
|
def expand_with_llm(llm_handler, description: str, duration: float = 30.0):
|
||||||
"""Expand a simple description using LLM Chain-of-Thought."""
|
"""Expand a simple description using LLM Chain-of-Thought.
|
||||||
|
|
||||||
|
Returns (params_dict, audio_codes_string).
|
||||||
|
"""
|
||||||
result = llm_handler.generate_with_stop_condition(
|
result = llm_handler.generate_with_stop_condition(
|
||||||
caption=description,
|
caption=description,
|
||||||
lyrics="",
|
lyrics="",
|
||||||
infer_type="dit", # metadata only
|
infer_type="dit", # metadata + audio codes
|
||||||
temperature=0.85,
|
temperature=0.85,
|
||||||
cfg_scale=1.0,
|
cfg_scale=1.0,
|
||||||
use_cot_metas=True,
|
use_cot_metas=True,
|
||||||
@@ -62,7 +65,7 @@ def expand_with_llm(llm_handler, description: str, duration: float = 30.0):
|
|||||||
|
|
||||||
if result.get("success") and result.get("metadata"):
|
if result.get("success") and result.get("metadata"):
|
||||||
meta = result["metadata"]
|
meta = result["metadata"]
|
||||||
return {
|
params = {
|
||||||
"caption": meta.get("caption", description),
|
"caption": meta.get("caption", description),
|
||||||
"lyrics": meta.get("lyrics", ""),
|
"lyrics": meta.get("lyrics", ""),
|
||||||
"bpm": meta.get("bpm", 100),
|
"bpm": meta.get("bpm", 100),
|
||||||
@@ -71,9 +74,11 @@ def expand_with_llm(llm_handler, description: str, duration: float = 30.0):
|
|||||||
"timesignature": meta.get("timesignature", "4"),
|
"timesignature": meta.get("timesignature", "4"),
|
||||||
"duration": meta.get("duration", duration),
|
"duration": meta.get("duration", duration),
|
||||||
}
|
}
|
||||||
|
audio_codes = result.get("audio_codes", "")
|
||||||
|
return params, audio_codes
|
||||||
|
|
||||||
print(f"[Simple Mode] LLM expansion failed: {result.get('error', 'unknown')}")
|
print(f"[Simple Mode] LLM expansion failed: {result.get('error', 'unknown')}")
|
||||||
return None
|
return None, ""
|
||||||
|
|
||||||
|
|
||||||
def fallback_expand(description: str, duration: float = 30.0):
|
def fallback_expand(description: str, duration: float = 30.0):
|
||||||
@@ -87,7 +92,7 @@ def fallback_expand(description: str, duration: float = 30.0):
|
|||||||
"language": "en",
|
"language": "en",
|
||||||
"timesignature": "4",
|
"timesignature": "4",
|
||||||
"duration": duration,
|
"duration": duration,
|
||||||
}
|
}, ""
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -114,13 +119,13 @@ def main():
|
|||||||
lm_model_path="acestep-5Hz-lm-1.7B",
|
lm_model_path="acestep-5Hz-lm-1.7B",
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. Expand parameters
|
# 2. Expand parameters + audio codes
|
||||||
if llm_ok:
|
if llm_ok:
|
||||||
params = expand_with_llm(llm_handler, description, duration=duration)
|
params, audio_codes = expand_with_llm(llm_handler, description, duration=duration)
|
||||||
if params is None:
|
if params is None:
|
||||||
params = fallback_expand(description, duration)
|
params, audio_codes = fallback_expand(description, duration)
|
||||||
else:
|
else:
|
||||||
params = fallback_expand(description, duration)
|
params, audio_codes = fallback_expand(description, duration)
|
||||||
|
|
||||||
print(f"\n[Simple Mode] Parameters:")
|
print(f"\n[Simple Mode] Parameters:")
|
||||||
print(f" Caption: {params['caption'][:100]}...")
|
print(f" Caption: {params['caption'][:100]}...")
|
||||||
@@ -128,6 +133,7 @@ def main():
|
|||||||
print(f" BPM: {params['bpm']}, Keyscale: {params['keyscale']}")
|
print(f" BPM: {params['bpm']}, Keyscale: {params['keyscale']}")
|
||||||
print(f" Language: {params['language']}, Time Sig: {params['timesignature']}")
|
print(f" Language: {params['language']}, Time Sig: {params['timesignature']}")
|
||||||
print(f" Duration: {params['duration']}s")
|
print(f" Duration: {params['duration']}s")
|
||||||
|
print(f" Audio codes: {len(audio_codes)} chars" if audio_codes else " Audio codes: None (fallback)")
|
||||||
|
|
||||||
# 3. Load Pipeline
|
# 3. Load Pipeline
|
||||||
print(f"\n[Pipeline] Loading Ace-Step 1.5 (turbo)...")
|
print(f"\n[Pipeline] Loading Ace-Step 1.5 (turbo)...")
|
||||||
@@ -141,21 +147,17 @@ def main():
|
|||||||
),
|
),
|
||||||
ModelConfig(
|
ModelConfig(
|
||||||
model_id="ACE-Step/Ace-Step1.5",
|
model_id="ACE-Step/Ace-Step1.5",
|
||||||
origin_file_pattern="acestep-v15-turbo/model.safetensors"
|
origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"
|
||||||
),
|
),
|
||||||
ModelConfig(
|
ModelConfig(
|
||||||
model_id="ACE-Step/Ace-Step1.5",
|
model_id="ACE-Step/Ace-Step1.5",
|
||||||
origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"
|
origin_file_pattern="vae/diffusion_pytorch_model.safetensors"
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(
|
text_tokenizer_config=ModelConfig(
|
||||||
model_id="ACE-Step/Ace-Step1.5",
|
model_id="ACE-Step/Ace-Step1.5",
|
||||||
origin_file_pattern="Qwen3-Embedding-0.6B/"
|
origin_file_pattern="Qwen3-Embedding-0.6B/"
|
||||||
),
|
),
|
||||||
vae_config=ModelConfig(
|
|
||||||
model_id="ACE-Step/Ace-Step1.5",
|
|
||||||
origin_file_pattern="vae/"
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4. Generate
|
# 4. Generate
|
||||||
@@ -164,6 +166,7 @@ def main():
|
|||||||
prompt=params["caption"],
|
prompt=params["caption"],
|
||||||
lyrics=params["lyrics"],
|
lyrics=params["lyrics"],
|
||||||
duration=params["duration"],
|
duration=params["duration"],
|
||||||
|
audio_codes=audio_codes if audio_codes else None,
|
||||||
seed=42,
|
seed=42,
|
||||||
num_inference_steps=8,
|
num_inference_steps=8,
|
||||||
cfg_scale=1.0,
|
cfg_scale=1.0,
|
||||||
|
|||||||
@@ -1,16 +1,6 @@
|
|||||||
"""
|
|
||||||
Ace-Step 1.5 — Text-to-Music (Turbo) inference example.
|
|
||||||
|
|
||||||
Demonstrates the standard text2music pipeline with structured parameters
|
|
||||||
(caption, lyrics, duration, etc.) — no LLM expansion needed.
|
|
||||||
|
|
||||||
For Simple Mode (LLM expands a short description), see:
|
|
||||||
- Ace-Step1.5-SimpleMode.py
|
|
||||||
"""
|
|
||||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||||
import torch
|
import torch
|
||||||
import soundfile as sf
|
from diffsynth.utils.data.audio import save_audio
|
||||||
|
|
||||||
|
|
||||||
pipe = AceStepPipeline.from_pretrained(
|
pipe = AceStepPipeline.from_pretrained(
|
||||||
torch_dtype=torch.bfloat16,
|
torch_dtype=torch.bfloat16,
|
||||||
@@ -35,29 +25,21 @@ pipe = AceStepPipeline.from_pretrained(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline."
|
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||||
lyrics = """[Intro - Synth Brass Fanfare]
|
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||||
|
|
||||||
[Verse 1]
|
|
||||||
黑夜里的风吹过耳畔
|
|
||||||
甜蜜时光转瞬即逝
|
|
||||||
脚步飘摇在星光上
|
|
||||||
|
|
||||||
[Chorus]
|
|
||||||
心电感应在震动间
|
|
||||||
拥抱未来勇敢冒险
|
|
||||||
|
|
||||||
[Outro - Instrumental]"""
|
|
||||||
|
|
||||||
audio = pipe(
|
audio = pipe(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
lyrics=lyrics,
|
lyrics=lyrics,
|
||||||
duration=30.0,
|
duration=160,
|
||||||
|
bpm=100,
|
||||||
|
keyscale="B minor",
|
||||||
|
timesignature="4",
|
||||||
|
vocal_language="zh",
|
||||||
seed=42,
|
seed=42,
|
||||||
num_inference_steps=8,
|
num_inference_steps=8,
|
||||||
cfg_scale=1.0,
|
cfg_scale=1.0,
|
||||||
shift=3.0,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
sf.write("Ace-Step1.5.wav", audio.cpu().numpy(), pipe.sample_rate)
|
save_audio(audio.cpu(), pipe.vae.sampling_rate, "Ace-Step1.5.wav")
|
||||||
print(f"Saved to Ace-Step1.5.wav, shape: {audio.shape}, duration: {audio.shape[-1] / pipe.sample_rate:.1f}s")
|
print(f"Saved to Ace-Step1.5.wav, shape: {audio.shape}, duration: {audio.shape[-1] / pipe.sample_rate:.1f}s")
|
||||||
|
|||||||
Reference in New Issue
Block a user