This commit is contained in:
mi804
2026-04-23 17:31:34 +08:00
parent 394db06d86
commit a80fb84220
14 changed files with 99 additions and 243 deletions

View File

@@ -925,7 +925,7 @@ ace_step_series = [
"model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
"model_name": "ace_step_dit",
"model_class": "diffsynth.models.ace_step_dit.AceStepDiTModel",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.ace_step_dit_converter",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.AceStepDiTModelStateDictConverter",
},
# === XL DiT variants (32 layers, hidden_size=2560) ===
# Covers: xl-base, xl-sft, xl-turbo
@@ -934,7 +934,7 @@ ace_step_series = [
"model_hash": "3a28a410c2246f125153ef792d8bc828",
"model_name": "ace_step_dit",
"model_class": "diffsynth.models.ace_step_dit.AceStepDiTModel",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.ace_step_dit_converter",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.AceStepDiTModelStateDictConverter",
"extra_kwargs": {
"hidden_size": 2560,
"intermediate_size": 9728,
@@ -952,7 +952,7 @@ ace_step_series = [
"model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
"model_name": "ace_step_conditioner",
"model_class": "diffsynth.models.ace_step_conditioner.AceStepConditionEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.ace_step_conditioner_converter",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.AceStepConditionEncoderStateDictConverter",
},
# === XL Conditioner (same architecture, but checkpoint includes XL decoder → different file hash) ===
{
@@ -960,7 +960,7 @@ ace_step_series = [
"model_hash": "3a28a410c2246f125153ef792d8bc828",
"model_name": "ace_step_conditioner",
"model_class": "diffsynth.models.ace_step_conditioner.AceStepConditionEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.ace_step_conditioner_converter",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.AceStepConditionEncoderStateDictConverter",
},
# === Qwen3-Embedding (text encoder) ===
{
@@ -968,7 +968,7 @@ ace_step_series = [
"model_hash": "3509bea17b0e8cffc3dd4a15cc7899d0",
"model_name": "ace_step_text_encoder",
"model_class": "diffsynth.models.ace_step_text_encoder.AceStepTextEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_text_encoder.ace_step_text_encoder_converter",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_text_encoder.AceStepTextEncoderStateDictConverter",
},
# === VAE (AutoencoderOobleck CNN) ===
{
@@ -983,7 +983,7 @@ ace_step_series = [
"model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
"model_name": "ace_step_tokenizer",
"model_class": "diffsynth.models.ace_step_tokenizer.AceStepTokenizer",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_tokenizer.ace_step_tokenizer_converter",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_tokenizer.AceStepTokenizerStateDictConverter",
},
# === XL Tokenizer (XL models share same tokenizer architecture) ===
{
@@ -991,8 +991,11 @@ ace_step_series = [
"model_hash": "3a28a410c2246f125153ef792d8bc828",
"model_name": "ace_step_tokenizer",
"model_class": "diffsynth.models.ace_step_tokenizer.AceStepTokenizer",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_tokenizer.ace_step_tokenizer_converter",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_tokenizer.AceStepTokenizerStateDictConverter",
},
]
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series + z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series + ace_step_series
MODEL_CONFIGS = (
qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series
+ z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series + ace_step_series
)

View File

@@ -3,12 +3,10 @@ ACE-Step Pipeline for DiffSynth-Studio.
Text-to-Music generation pipeline using ACE-Step 1.5 model.
"""
import re
import torch
import re, torch
from typing import Optional, Dict, Any, List, Tuple
from tqdm import tqdm
import random
import math
import random, math
import torch.nn.functional as F
from einops import rearrange
@@ -39,7 +37,7 @@ class AceStepPipeline(BasePipeline):
self.conditioner: AceStepConditionEncoder = None
self.dit: AceStepDiTModel = None
self.vae: AceStepVAE = None
self.tokenizer_model: AceStepTokenizer = None # AceStepTokenizer (tokenizer + detokenizer)
self.tokenizer_model: AceStepTokenizer = None
self.in_iteration_models = ("dit",)
self.units = [
@@ -65,7 +63,6 @@ class AceStepPipeline(BasePipeline):
silence_latent_config: ModelConfig = ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
vram_limit: float = None,
):
"""Load pipeline from pretrained checkpoints."""
pipe = AceStepPipeline(device=device, torch_dtype=torch_dtype)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
@@ -102,7 +99,7 @@ class AceStepPipeline(BasePipeline):
reference_audios: List[torch.Tensor] = None,
# Source audio
src_audio: torch.Tensor = None,
denoising_strength: float = 1.0, # denoising_strength = 1 - cover_noise_strength
denoising_strength: float = 1.0, # denoising_strength = 1 - cover_noise_strength
audio_cover_strength: float = 1.0,
# Audio codes
audio_code_string: Optional[str] = None,
@@ -115,7 +112,7 @@ class AceStepPipeline(BasePipeline):
bpm: Optional[int] = 100,
keyscale: Optional[str] = "B minor",
timesignature: Optional[str] = "4",
vocal_language: Optional[str] = 'unknown',
vocal_language: Optional[str] = "unknown",
# Randomness
seed: int = None,
rand_device: str = "cpu",
@@ -126,10 +123,10 @@ class AceStepPipeline(BasePipeline):
# Progress
progress_bar_cmd=tqdm,
):
# 1. Scheduler
# Scheduler
self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, denoising_strength=denoising_strength, shift=shift)
# 2. 三字典输入
# Parameters
inputs_posi = {"prompt": prompt, "positive": True}
inputs_nega = {"positive": False}
inputs_shared = {
@@ -147,13 +144,12 @@ class AceStepPipeline(BasePipeline):
"shift": shift,
}
# 3. Unit 链执行
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
unit, self, inputs_shared, inputs_posi, inputs_nega
)
# 4. Denoise loop
# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
@@ -164,17 +160,17 @@ class AceStepPipeline(BasePipeline):
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id,
)
inputs_shared["latents"] = self.step(self.scheduler, inpaint_mask=inputs_shared.get("denoise_mask", None), input_latents=inputs_shared.get("src_latents", None),
progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
inputs_shared["latents"] = self.step(
self.scheduler, inpaint_mask=inputs_shared.get("denoise_mask", None), input_latents=inputs_shared.get("src_latents", None),
progress_id=progress_id, noise_pred=noise_pred, **inputs_shared,
)
# 5. VAE 解码
# Decode
self.load_models_to_device(['vae'])
# DiT output is [B, T, 64] (channels-last), VAE expects [B, 64, T] (channels-first)
latents = inputs_shared["latents"].transpose(1, 2)
vae_output = self.vae.decode(latents)
# VAE returns OobleckDecoderOutput with .sample attribute
audio_output = vae_output.sample if hasattr(vae_output, 'sample') else vae_output
audio_output = self.normalize_audio(audio_output, target_db=-1.0)
audio_output = self.normalize_audio(vae_output, target_db=-1.0)
audio = self.output_audio_format_check(audio_output)
self.load_models_to_device([])
return audio
@@ -188,7 +184,9 @@ class AceStepPipeline(BasePipeline):
return audio * gain
def switch_noncover_condition(self, inputs_shared, inputs_posi, inputs_nega, progress_id):
if inputs_shared["task_type"] != "cover" or inputs_shared["audio_cover_strength"] >= 1.0 or inputs_shared.get("shared_noncover", None) is None:
if inputs_shared["task_type"] != "cover" or inputs_shared["audio_cover_strength"] >= 1.0:
return
if inputs_shared.get("shared_noncover", None) is None:
return
cover_steps = int(len(self.scheduler.timesteps) * inputs_shared["audio_cover_strength"])
if progress_id >= cover_steps:
@@ -312,7 +310,10 @@ class AceStepUnit_ReferenceAudioEmbedder(PipelineUnit):
def process(self, pipe, reference_audios):
if reference_audios is not None:
pipe.load_models_to_device(['vae'])
reference_audios = [self.process_reference_audio(reference_audio).to(dtype=pipe.torch_dtype, device=pipe.device) for reference_audio in reference_audios]
reference_audios = [
self.process_reference_audio(reference_audio).to(dtype=pipe.torch_dtype, device=pipe.device)
for reference_audio in reference_audios
]
reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, [reference_audios])
else:
reference_audios = [[torch.zeros(2, 30 * pipe.vae.sampling_rate).to(dtype=pipe.torch_dtype, device=pipe.device)]]
@@ -357,7 +358,6 @@ class AceStepUnit_ReferenceAudioEmbedder(PipelineUnit):
class AceStepUnit_ConditionEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
take_over=True,
@@ -378,7 +378,9 @@ class AceStepUnit_ConditionEmbedder(PipelineUnit):
inputs_posi["encoder_hidden_states"] = encoder_hidden_states
inputs_posi["encoder_attention_mask"] = encoder_attention_mask
if inputs_shared["cfg_scale"] != 1.0:
inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states).to(dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device)
inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states).to(
dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device,
)
inputs_nega["encoder_attention_mask"] = encoder_attention_mask
if inputs_shared["task_type"] == "cover" and inputs_shared["audio_cover_strength"] < 1.0:
hidden_states_noncover = AceStepUnit_PromptEmbedder().process(
@@ -396,7 +398,9 @@ class AceStepUnit_ConditionEmbedder(PipelineUnit):
inputs_shared["posi_noncover"] = {"encoder_hidden_states": encoder_hidden_states_noncover, "encoder_attention_mask": encoder_attention_mask_noncover}
if inputs_shared["cfg_scale"] != 1.0:
inputs_shared["nega_noncover"] = {
"encoder_hidden_states": pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states_noncover).to(dtype=encoder_hidden_states_noncover.dtype, device=encoder_hidden_states_noncover.device),
"encoder_hidden_states": pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states_noncover).to(
dtype=encoder_hidden_states_noncover.dtype, device=encoder_hidden_states_noncover.device,
),
"encoder_attention_mask": encoder_attention_mask_noncover,
}
return inputs_shared, inputs_posi, inputs_nega
@@ -483,7 +487,7 @@ class AceStepUnit_ContextLatentBuilder(PipelineUnit):
# use audio_code_string to get src_latents.
pipe.load_models_to_device(self.onload_model_names)
code_ids = self._parse_audio_code_string(audio_code_string)
quantizer = pipe.tokenizer_model.tokenizer.quantizer
quantizer = pipe.tokenizer_model.tokenizer.quantizer.to(device=pipe.device)
indices = torch.tensor(code_ids, device=quantizer.codebooks.device, dtype=torch.long).unsqueeze(0).unsqueeze(-1)
codes = quantizer.get_codes_from_indices(indices)
quantized = codes.sum(dim=0).to(pipe.torch_dtype).to(pipe.device)

View File

@@ -1,38 +1,4 @@
"""
State dict converter for ACE-Step Conditioner model.
The original checkpoint stores all model weights in a single file
(nested in AceStepConditionGenerationModel). The Conditioner weights are
prefixed with 'encoder.'.
This converter extracts only keys starting with 'encoder.' and strips
the prefix to match the standalone AceStepConditionEncoder in DiffSynth.
"""
def ace_step_conditioner_converter(state_dict):
"""
Convert ACE-Step Conditioner checkpoint keys to DiffSynth format.
参数 state_dict 是 DiskMap 类型。
遍历时,key 是 key 名,state_dict[key] 获取实际值。
Original checkpoint contains all model weights under prefixes:
- decoder.* (DiT)
- encoder.* (Conditioner)
- tokenizer.* (Audio Tokenizer)
- detokenizer.* (Audio Detokenizer)
- null_condition_emb (CFG null embedding)
This extracts only 'encoder.' keys and strips the prefix.
Example mapping:
encoder.lyric_encoder.layers.0.self_attn.q_proj.weight -> lyric_encoder.layers.0.self_attn.q_proj.weight
encoder.attention_pooler.layers.0.self_attn.q_proj.weight -> attention_pooler.layers.0.self_attn.q_proj.weight
encoder.timbre_encoder.layers.0.self_attn.q_proj.weight -> timbre_encoder.layers.0.self_attn.q_proj.weight
encoder.audio_tokenizer.audio_acoustic_proj.weight -> audio_tokenizer.audio_acoustic_proj.weight
encoder.detokenizer.layers.0.self_attn.q_proj.weight -> detokenizer.layers.0.self_attn.q_proj.weight
"""
def AceStepConditionEncoderStateDictConverter(state_dict):
new_state_dict = {}
prefix = "encoder."
@@ -41,7 +7,6 @@ def ace_step_conditioner_converter(state_dict):
new_key = key[len(prefix):]
new_state_dict[new_key] = state_dict[key]
# Extract null_condition_emb from top level (used for CFG negative condition)
if "null_condition_emb" in state_dict:
new_state_dict["null_condition_emb"] = state_dict["null_condition_emb"]

View File

@@ -1,37 +1,4 @@
"""
State dict converter for ACE-Step DiT model.
The original checkpoint stores all model weights in a single file
(nested in AceStepConditionGenerationModel). The DiT weights are
prefixed with 'decoder.'.
This converter extracts only keys starting with 'decoder.' and strips
the prefix to match the standalone AceStepDiTModel in DiffSynth.
"""
def ace_step_dit_converter(state_dict):
"""
Convert ACE-Step DiT checkpoint keys to DiffSynth format.
参数 state_dict 是 DiskMap 类型。
遍历时,key 是 key 名,state_dict[key] 获取实际值。
Original checkpoint contains all model weights under prefixes:
- decoder.* (DiT)
- encoder.* (Conditioner)
- tokenizer.* (Audio Tokenizer)
- detokenizer.* (Audio Detokenizer)
- null_condition_emb (CFG null embedding)
This extracts only 'decoder.' keys and strips the prefix.
Example mapping:
decoder.layers.0.self_attn.q_proj.weight -> layers.0.self_attn.q_proj.weight
decoder.proj_in.0.linear_1.weight -> proj_in.0.linear_1.weight
decoder.time_embed.linear_1.weight -> time_embed.linear_1.weight
decoder.rotary_emb.inv_freq -> rotary_emb.inv_freq
"""
def AceStepDiTModelStateDictConverter(state_dict):
new_state_dict = {}
prefix = "decoder."

View File

@@ -1,55 +0,0 @@
"""
State dict converter for ACE-Step LLM (Qwen3-based).
The safetensors file stores Qwen3 model weights. Different checkpoints
may have different key formats:
- Qwen3ForCausalLM format: model.embed_tokens.weight, model.layers.0.*
- Qwen3Model format: embed_tokens.weight, layers.0.*
Qwen3ForCausalLM wraps a .model attribute (Qwen3Model), so its
state_dict() has keys:
model.model.embed_tokens.weight
model.model.layers.0.self_attn.q_proj.weight
model.model.norm.weight
model.lm_head.weight (tied to model.model.embed_tokens)
This converter normalizes all keys to the Qwen3ForCausalLM format.
Example mapping:
model.embed_tokens.weight -> model.model.embed_tokens.weight
embed_tokens.weight -> model.model.embed_tokens.weight
model.layers.0.self_attn.q_proj.weight -> model.model.layers.0.self_attn.q_proj.weight
layers.0.self_attn.q_proj.weight -> model.model.layers.0.self_attn.q_proj.weight
model.norm.weight -> model.model.norm.weight
norm.weight -> model.model.norm.weight
"""
def ace_step_lm_converter(state_dict):
    """Normalize ACE-Step LLM checkpoint keys to the Qwen3ForCausalLM layout.

    The ``state_dict`` argument only needs to support iteration over key
    names and ``state_dict[key]`` item access (e.g. a DiskMap).

    Key normalization:
      - keys already under ``model.model.`` are kept unchanged;
      - keys under ``model.`` gain one extra ``model.`` level;
      - bare keys gain the full ``model.model.`` prefix.

    Finally, ``model.lm_head.weight`` is (re)tied to the token embedding
    weights, matching Qwen3ForCausalLM's tied word embeddings.
    """
    converted = {}
    for key in state_dict:
        if key.startswith("model.model."):
            # Already in the fully nested Qwen3ForCausalLM form.
            target = key
        elif key.startswith("model."):
            # Qwen3ForCausalLM top-level form: nest one level deeper.
            target = f"model.{key}"
        else:
            # Bare Qwen3Model key: add the full nesting.
            target = f"model.model.{key}"
        converted[target] = state_dict[key]
    # Tied word embeddings: lm_head shares the embedding matrix.
    if "model.model.embed_tokens.weight" in converted:
        converted["model.lm_head.weight"] = converted["model.model.embed_tokens.weight"]
    return converted

View File

@@ -1,28 +1,4 @@
"""
State dict converter for ACE-Step Text Encoder (Qwen3-Embedding-0.6B).
The safetensors stores Qwen3Model weights with keys:
embed_tokens.weight
layers.0.self_attn.q_proj.weight
norm.weight
AceStepTextEncoder wraps a .model attribute (Qwen3Model), so its
state_dict() has keys with 'model.' prefix:
model.embed_tokens.weight
model.layers.0.self_attn.q_proj.weight
model.norm.weight
This converter adds 'model.' prefix to match the nested structure.
"""
def ace_step_text_encoder_converter(state_dict):
"""
Convert ACE-Step Text Encoder checkpoint keys to match Qwen3Model wrapped state dict.
参数 state_dict 是 DiskMap 类型。
遍历时,key 是 key 名,state_dict[key] 获取实际值。
"""
def AceStepTextEncoderStateDictConverter(state_dict):
new_state_dict = {}
prefix = "model."
nested_prefix = "model.model."

View File

@@ -1,23 +1,4 @@
"""
State dict converter for ACE-Step Tokenizer model.
The original checkpoint stores tokenizer and detokenizer weights at the top level:
- tokenizer.* (AceStepAudioTokenizer: audio_acoustic_proj, attention_pooler, quantizer)
- detokenizer.* (AudioTokenDetokenizer: embed_tokens, layers, proj_out)
These map directly to the AceStepTokenizer class which wraps both as
self.tokenizer and self.detokenizer submodules.
"""
def ace_step_tokenizer_converter(state_dict):
"""
Convert ACE-Step Tokenizer checkpoint keys to DiffSynth format.
The checkpoint keys `tokenizer.*` and `detokenizer.*` already match
the DiffSynth AceStepTokenizer module structure (self.tokenizer, self.detokenizer).
No key remapping needed — just extract the relevant keys.
"""
def AceStepTokenizerStateDictConverter(state_dict):
new_state_dict = {}
for key in state_dict:

View File

@@ -1,5 +1,6 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
from modelscope import dataset_snapshot_download
import torch
pipe = AceStepPipeline.from_pretrained(
@@ -29,6 +30,11 @@ audio = pipe(
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo.wav")
# input audio codes as reference
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="ace_step/Ace-Step1.5/audio_codes_input.txt",
)
with open("data/diffsynth_example_dataset/ace_step/Ace-Step1.5/audio_codes_input.txt", "r") as f:
audio_code_string = f.read().strip()

View File

@@ -1,5 +1,6 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio, read_audio
from modelscope import dataset_snapshot_download
import torch
pipe = AceStepPipeline.from_pretrained(
@@ -15,6 +16,13 @@ pipe = AceStepPipeline.from_pretrained(
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="ace_step/acestep-v15-base/audio.wav",
)
src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
# audio_cover_strength controls the steps of doing cover tasks. [0, num_inference_steps * audio_cover_strength] steps will be cover steps, and the rest will be regular text-to-music generation steps.
# denoising_strength controls how the output audio is influenced by the source audio in cover tasks.

View File

@@ -1,5 +1,6 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio, read_audio
from modelscope import dataset_snapshot_download
import torch
pipe = AceStepPipeline.from_pretrained(
@@ -15,6 +16,13 @@ pipe = AceStepPipeline.from_pretrained(
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="ace_step/acestep-v15-base/audio.wav",
)
src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
# repainting_ranges are in seconds, and will be converted to frames internally in the pipeline. The negative value in repainting_ranges means the padding from the start of the audio.
# For example, repainting_ranges=[(-10, 30), (160, 200)] means we want to repaint the audio from -10s to 30s (with 10s padding before the start) and from 160s to 200s. The non-existent parts will be padded with silence.

View File

@@ -6,6 +6,7 @@ Turbo model: uses num_inference_steps=8, cfg_scale=1.0.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
from modelscope import dataset_snapshot_download
import torch
@@ -49,6 +50,11 @@ audio = pipe(
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-low-vram.wav")
# input audio codes as reference
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="ace_step/Ace-Step1.5/audio_codes_input.txt",
)
with open("data/diffsynth_example_dataset/ace_step/Ace-Step1.5/audio_codes_input.txt", "r") as f:
audio_code_string = f.read().strip()

View File

@@ -1,5 +1,6 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio, read_audio
from modelscope import dataset_snapshot_download
import torch
vram_config = {
@@ -27,6 +28,13 @@ pipe = AceStepPipeline.from_pretrained(
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="ace_step/acestep-v15-base/audio.wav",
)
src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
# audio_cover_strength controls the steps of doing cover tasks. [0, num_inference_steps * audio_cover_strength] steps will be cover steps, and the rest will be regular text-to-music generation steps.
# denoising_strength controls how the output audio is influenced by the source audio in cover tasks.

View File

@@ -1,5 +1,6 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio, read_audio
from modelscope import dataset_snapshot_download
import torch
vram_config = {
@@ -27,6 +28,13 @@ pipe = AceStepPipeline.from_pretrained(
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="ace_step/acestep-v15-base/audio.wav",
)
src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
# repainting_ranges are in seconds, and will be converted to frames internally in the pipeline. The negative value in repainting_ranges means the padding from the start of the audio.
# For example, repainting_ranges=[(-10, 30), (160, 200)] means we want to repaint the audio from -10s to 30s (with 10s padding before the start) and from 160s to 200s. The non-existent parts will be padded with silence.

View File

@@ -1,31 +1,15 @@
import torch, os, argparse, accelerate, warnings, torchaudio
import os
import torch
import math
import argparse
import accelerate
from diffsynth.core import UnifiedDataset
from diffsynth.core.data.operators import ToAbsolutePath, RouteByType, DataProcessingOperator, LoadPureAudioWithTorchaudio
from diffsynth.core.data.operators import ToAbsolutePath, LoadPureAudioWithTorchaudio
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.diffusion import *
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class LoadAceStepAudio(DataProcessingOperator):
"""Load audio file and return waveform tensor [2, T] at 48kHz."""
def __init__(self, target_sr=48000):
self.target_sr = target_sr
def __call__(self, data: str):
try:
waveform, sample_rate = torchaudio.load(data)
if sample_rate != self.target_sr:
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.target_sr)
waveform = resampler(waveform)
if waveform.shape[0] == 1:
waveform = waveform.repeat(2, 1)
return waveform
except Exception as e:
warnings.warn(f"Cannot load audio from {data}: {e}")
return None
class AceStepTrainingModule(DiffusionTrainingModule):
def __init__(
self,
@@ -43,17 +27,15 @@ class AceStepTrainingModule(DiffusionTrainingModule):
task="sft",
):
super().__init__()
# ===== 解析模型配置(固定写法) =====
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
# ===== Tokenizer 配置 =====
text_tokenizer_config = self.parse_path_or_model_id(tokenizer_path, default_value=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"))
silence_latent_config = self.parse_path_or_model_id(silence_latent_path, default_value=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"))
# ===== 构建 Pipeline =====
self.pipe = AceStepPipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, text_tokenizer_config=text_tokenizer_config, silence_latent_config=silence_latent_config)
# ===== 拆分 Pipeline Units固定写法 =====
self.pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16, device=device, model_configs=model_configs,
text_tokenizer_config=text_tokenizer_config, silence_latent_config=silence_latent_config,
)
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
# ===== 切换到训练模式(固定写法) =====
self.switch_pipe_to_training_mode(
self.pipe, trainable_models,
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
@@ -61,13 +43,11 @@ class AceStepTrainingModule(DiffusionTrainingModule):
task=task,
)
# ===== 其他配置(固定写法) =====
self.use_gradient_checkpointing = use_gradient_checkpointing
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
self.fp8_models = fp8_models
self.task = task
# ===== 任务模式路由(固定写法) =====
self.task_to_loss = {
"sft:data_process": lambda pipe, *args: args,
"sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
@@ -78,11 +58,8 @@ class AceStepTrainingModule(DiffusionTrainingModule):
inputs_posi = {"prompt": data["prompt"], "positive": True}
inputs_nega = {"positive": False}
duration = math.floor(data['audio'][0].shape[1] / data['audio'][1]) if data.get("audio") is not None else data.get("duration", 60)
# ===== 共享参数 =====
inputs_shared = {
# ===== 核心字段映射 =====
"input_audio": data["audio"],
# ===== 音频生成任务所需元数据 =====
"lyrics": data["lyrics"],
"task_type": "text2music",
"duration": duration,
@@ -90,18 +67,15 @@ class AceStepTrainingModule(DiffusionTrainingModule):
"keyscale": data.get("keyscale", "C major"),
"timesignature": data.get("timesignature", "4"),
"vocal_language": data.get("vocal_language", "unknown"),
# ===== 框架控制参数(固定写法) =====
"cfg_scale": 1,
"rand_device": self.pipe.device,
"use_gradient_checkpointing": self.use_gradient_checkpointing,
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
}
# ===== 额外字段注入:通过 --extra_inputs 配置的数据集列名(固定写法) =====
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
return inputs_shared, inputs_posi, inputs_nega
def forward(self, data, inputs=None):
# ===== 标准实现,不要修改(固定写法) =====
if inputs is None: inputs = self.get_pipeline_inputs(data)
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
for unit in self.pipe.units:
@@ -122,12 +96,10 @@ def ace_step_parser():
if __name__ == "__main__":
parser = ace_step_parser()
args = parser.parse_args()
# ===== Accelerator 配置(固定写法) =====
accelerator = accelerate.Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
)
# ===== 数据集定义 =====
dataset = UnifiedDataset(
base_path=args.dataset_base_path,
metadata_path=args.dataset_metadata_path,
@@ -135,10 +107,11 @@ if __name__ == "__main__":
data_file_keys=args.data_file_keys.split(","),
main_data_operator=None,
special_operator_map={
"audio": ToAbsolutePath(args.dataset_base_path) >> LoadPureAudioWithTorchaudio(target_sample_rate=48000),
"audio": ToAbsolutePath(args.dataset_base_path) >> LoadPureAudioWithTorchaudio(
target_sample_rate=48000,
),
},
)
# ===== TrainingModule =====
model = AceStepTrainingModule(
model_paths=args.model_paths,
model_id_with_origin_paths=args.model_id_with_origin_paths,
@@ -159,12 +132,10 @@ if __name__ == "__main__":
task=args.task,
device="cpu" if args.initialize_model_on_cpu else accelerator.device,
)
# ===== ModelLogger固定写法 =====
model_logger = ModelLogger(
args.output_path,
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
)
# ===== 任务路由(固定写法) =====
launcher_map = {
"sft:data_process": launch_data_process_task,
"sft": launch_training_task,