From a80fb8422005c02f9b4960eddab33e6b749e25a9 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Thu, 23 Apr 2026 17:31:34 +0800 Subject: [PATCH] style --- diffsynth/configs/model_configs.py | 19 ++++--- diffsynth/pipelines/ace_step.py | 52 ++++++++++-------- .../ace_step_conditioner.py | 37 +------------ .../state_dict_converters/ace_step_dit.py | 35 +----------- .../state_dict_converters/ace_step_lm.py | 55 ------------------- .../ace_step_text_encoder.py | 26 +-------- .../ace_step_tokenizer.py | 21 +------ .../ace_step/model_inference/Ace-Step1.5.py | 6 ++ .../acestep-v15-base-CoverTask.py | 8 +++ .../acestep-v15-base-RepaintTask.py | 8 +++ .../model_inference_low_vram/Ace-Step1.5.py | 6 ++ .../acestep-v15-base-CoverTask.py | 8 +++ .../acestep-v15-base-RepaintTask.py | 8 +++ examples/ace_step/model_training/train.py | 53 ++++-------------- 14 files changed, 99 insertions(+), 243 deletions(-) delete mode 100644 diffsynth/utils/state_dict_converters/ace_step_lm.py diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index ccaa146..9c26a3c 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -925,7 +925,7 @@ ace_step_series = [ "model_hash": "ba29d8bddbb6ace65675f6a757a13c00", "model_name": "ace_step_dit", "model_class": "diffsynth.models.ace_step_dit.AceStepDiTModel", - "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.ace_step_dit_converter", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.AceStepDiTModelStateDictConverter", }, # === XL DiT variants (32 layers, hidden_size=2560) === # Covers: xl-base, xl-sft, xl-turbo @@ -934,7 +934,7 @@ ace_step_series = [ "model_hash": "3a28a410c2246f125153ef792d8bc828", "model_name": "ace_step_dit", "model_class": "diffsynth.models.ace_step_dit.AceStepDiTModel", - "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.ace_step_dit_converter", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.AceStepDiTModelStateDictConverter", "extra_kwargs": { "hidden_size": 2560, "intermediate_size": 9728, @@ -952,7 +952,7 @@ ace_step_series = [ "model_hash": "ba29d8bddbb6ace65675f6a757a13c00", "model_name": "ace_step_conditioner", "model_class": "diffsynth.models.ace_step_conditioner.AceStepConditionEncoder", - "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.ace_step_conditioner_converter", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.AceStepConditionEncoderStateDictConverter", }, # === XL Conditioner (same architecture, but checkpoint includes XL decoder → different file hash) === { @@ -960,7 +960,7 @@ ace_step_series = [ "model_hash": "3a28a410c2246f125153ef792d8bc828", "model_name": "ace_step_conditioner", "model_class": "diffsynth.models.ace_step_conditioner.AceStepConditionEncoder", - "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.ace_step_conditioner_converter", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.AceStepConditionEncoderStateDictConverter", }, # === Qwen3-Embedding (text encoder) === { @@ -968,7 +968,7 @@ ace_step_series = [ "model_hash": "3509bea17b0e8cffc3dd4a15cc7899d0", "model_name": "ace_step_text_encoder", "model_class": "diffsynth.models.ace_step_text_encoder.AceStepTextEncoder", - "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_text_encoder.ace_step_text_encoder_converter", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_text_encoder.AceStepTextEncoderStateDictConverter", }, # === VAE (AutoencoderOobleck CNN) === { @@ -983,7 +983,7 @@ ace_step_series = [ "model_hash": "ba29d8bddbb6ace65675f6a757a13c00", "model_name": "ace_step_tokenizer", "model_class": "diffsynth.models.ace_step_tokenizer.AceStepTokenizer", - "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_tokenizer.ace_step_tokenizer_converter", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_tokenizer.AceStepTokenizerStateDictConverter", }, # === XL Tokenizer (XL models share same tokenizer architecture) === { @@ -991,8 +991,11 @@ ace_step_series = [ "model_hash": "3a28a410c2246f125153ef792d8bc828", "model_name": "ace_step_tokenizer", "model_class": "diffsynth.models.ace_step_tokenizer.AceStepTokenizer", - "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_tokenizer.ace_step_tokenizer_converter", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_tokenizer.AceStepTokenizerStateDictConverter", }, ] -MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series + z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series + ace_step_series +MODEL_CONFIGS = ( + qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series + + z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series + ace_step_series +) diff --git a/diffsynth/pipelines/ace_step.py b/diffsynth/pipelines/ace_step.py index 317a533..4d9971a 100644 --- a/diffsynth/pipelines/ace_step.py +++ b/diffsynth/pipelines/ace_step.py @@ -3,12 +3,10 @@ ACE-Step Pipeline for DiffSynth-Studio. Text-to-Music generation pipeline using ACE-Step 1.5 model. """ -import re -import torch +import re, torch from typing import Optional, Dict, Any, List, Tuple from tqdm import tqdm -import random -import math +import random, math import torch.nn.functional as F from einops import rearrange @@ -39,7 +37,7 @@ class AceStepPipeline(BasePipeline): self.conditioner: AceStepConditionEncoder = None self.dit: AceStepDiTModel = None self.vae: AceStepVAE = None - self.tokenizer_model: AceStepTokenizer = None # AceStepTokenizer (tokenizer + detokenizer) + self.tokenizer_model: AceStepTokenizer = None self.in_iteration_models = ("dit",) self.units = [ @@ -65,7 +63,6 @@ class AceStepPipeline(BasePipeline): silence_latent_config: ModelConfig = ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"), vram_limit: float = None, ): - """Load pipeline from pretrained checkpoints.""" pipe = AceStepPipeline(device=device, torch_dtype=torch_dtype) model_pool = pipe.download_and_load_models(model_configs, vram_limit) @@ -102,7 +99,7 @@ class AceStepPipeline(BasePipeline): reference_audios: List[torch.Tensor] = None, # Source audio src_audio: torch.Tensor = None, - denoising_strength: float = 1.0, # denoising_strength = 1 - cover_noise_strength + denoising_strength: float = 1.0, # denoising_strength = 1 - cover_noise_strength audio_cover_strength: float = 1.0, # Audio codes audio_code_string: Optional[str] = None, @@ -115,7 +112,7 @@ class AceStepPipeline(BasePipeline): bpm: Optional[int] = 100, keyscale: Optional[str] = "B minor", timesignature: Optional[str] = "4", - vocal_language: Optional[str] = 'unknown', + vocal_language: Optional[str] = "unknown", # Randomness seed: int = None, rand_device: str = "cpu", @@ -126,10 +123,10 @@ class AceStepPipeline(BasePipeline): # Progress progress_bar_cmd=tqdm, ): - # 1. Scheduler + # Scheduler self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, denoising_strength=denoising_strength, shift=shift) - # 2. 三字典输入 + # Parameters inputs_posi = {"prompt": prompt, "positive": True} inputs_nega = {"positive": False} inputs_shared = { @@ -147,13 +144,12 @@ class AceStepPipeline(BasePipeline): "shift": shift, } - # 3. Unit 链执行 for unit in self.units: inputs_shared, inputs_posi, inputs_nega = self.unit_runner( unit, self, inputs_shared, inputs_posi, inputs_nega ) - # 4. Denoise loop + # Denoise self.load_models_to_device(self.in_iteration_models) models = {name: getattr(self, name) for name in self.in_iteration_models} for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): @@ -164,17 +160,17 @@ class AceStepPipeline(BasePipeline): inputs_shared, inputs_posi, inputs_nega, **models, timestep=timestep, progress_id=progress_id, ) - inputs_shared["latents"] = self.step(self.scheduler, inpaint_mask=inputs_shared.get("denoise_mask", None), input_latents=inputs_shared.get("src_latents", None), - progress_id=progress_id, noise_pred=noise_pred, **inputs_shared) + inputs_shared["latents"] = self.step( + self.scheduler, inpaint_mask=inputs_shared.get("denoise_mask", None), input_latents=inputs_shared.get("src_latents", None), + progress_id=progress_id, noise_pred=noise_pred, **inputs_shared, + ) - # 5. VAE 解码 + # Decode self.load_models_to_device(['vae']) # DiT output is [B, T, 64] (channels-last), VAE expects [B, 64, T] (channels-first) latents = inputs_shared["latents"].transpose(1, 2) vae_output = self.vae.decode(latents) - # VAE returns OobleckDecoderOutput with .sample attribute - audio_output = vae_output.sample if hasattr(vae_output, 'sample') else vae_output - audio_output = self.normalize_audio(audio_output, target_db=-1.0) + audio_output = self.normalize_audio(vae_output, target_db=-1.0) audio = self.output_audio_format_check(audio_output) self.load_models_to_device([]) return audio @@ -188,7 +184,9 @@ class AceStepPipeline(BasePipeline): return audio * gain def switch_noncover_condition(self, inputs_shared, inputs_posi, inputs_nega, progress_id): - if inputs_shared["task_type"] != "cover" or inputs_shared["audio_cover_strength"] >= 1.0 or inputs_shared.get("shared_noncover", None) is None: + if inputs_shared["task_type"] != "cover" or inputs_shared["audio_cover_strength"] >= 1.0: + return + if inputs_shared.get("shared_noncover", None) is None: return cover_steps = int(len(self.scheduler.timesteps) * inputs_shared["audio_cover_strength"]) if progress_id >= cover_steps: @@ -312,7 +310,10 @@ class AceStepUnit_ReferenceAudioEmbedder(PipelineUnit): def process(self, pipe, reference_audios): if reference_audios is not None: pipe.load_models_to_device(['vae']) - reference_audios = [self.process_reference_audio(reference_audio).to(dtype=pipe.torch_dtype, device=pipe.device) for reference_audio in reference_audios] + reference_audios = [ + self.process_reference_audio(reference_audio).to(dtype=pipe.torch_dtype, device=pipe.device) + for reference_audio in reference_audios + ] reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, [reference_audios]) else: reference_audios = [[torch.zeros(2, 30 * pipe.vae.sampling_rate).to(dtype=pipe.torch_dtype, device=pipe.device)]] @@ -357,7 +358,6 @@ class AceStepUnit_ReferenceAudioEmbedder(PipelineUnit): class AceStepUnit_ConditionEmbedder(PipelineUnit): - def __init__(self): super().__init__( take_over=True, @@ -378,7 +378,9 @@ class AceStepUnit_ConditionEmbedder(PipelineUnit): inputs_posi["encoder_hidden_states"] = encoder_hidden_states inputs_posi["encoder_attention_mask"] = encoder_attention_mask if inputs_shared["cfg_scale"] != 1.0: - inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states).to(dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device) + inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states).to( + dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device, + ) inputs_nega["encoder_attention_mask"] = encoder_attention_mask if inputs_shared["task_type"] == "cover" and inputs_shared["audio_cover_strength"] < 1.0: hidden_states_noncover = AceStepUnit_PromptEmbedder().process( @@ -396,7 +398,9 @@ class AceStepUnit_ConditionEmbedder(PipelineUnit): inputs_shared["posi_noncover"] = {"encoder_hidden_states": encoder_hidden_states_noncover, "encoder_attention_mask": encoder_attention_mask_noncover} if inputs_shared["cfg_scale"] != 1.0: inputs_shared["nega_noncover"] = { - "encoder_hidden_states": pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states_noncover).to(dtype=encoder_hidden_states_noncover.dtype, device=encoder_hidden_states_noncover.device), + "encoder_hidden_states": pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states_noncover).to( + dtype=encoder_hidden_states_noncover.dtype, device=encoder_hidden_states_noncover.device, + ), "encoder_attention_mask": encoder_attention_mask_noncover, } return inputs_shared, inputs_posi, inputs_nega @@ -483,7 +487,7 @@ class AceStepUnit_ContextLatentBuilder(PipelineUnit): # use audio_cede_string to get src_latents. pipe.load_models_to_device(self.onload_model_names) code_ids = self._parse_audio_code_string(audio_code_string) - quantizer = pipe.tokenizer_model.tokenizer.quantizer + quantizer = pipe.tokenizer_model.tokenizer.quantizer.to(device=pipe.device) indices = torch.tensor(code_ids, device=quantizer.codebooks.device, dtype=torch.long).unsqueeze(0).unsqueeze(-1) codes = quantizer.get_codes_from_indices(indices) quantized = codes.sum(dim=0).to(pipe.torch_dtype).to(pipe.device) diff --git a/diffsynth/utils/state_dict_converters/ace_step_conditioner.py b/diffsynth/utils/state_dict_converters/ace_step_conditioner.py index b6984b8..041a405 100644 --- a/diffsynth/utils/state_dict_converters/ace_step_conditioner.py +++ b/diffsynth/utils/state_dict_converters/ace_step_conditioner.py @@ -1,38 +1,4 @@ -""" -State dict converter for ACE-Step Conditioner model. - -The original checkpoint stores all model weights in a single file -(nested in AceStepConditionGenerationModel). The Conditioner weights are -prefixed with 'encoder.'. - -This converter extracts only keys starting with 'encoder.' and strips -the prefix to match the standalone AceStepConditionEncoder in DiffSynth. -""" - - -def ace_step_conditioner_converter(state_dict): - """ - Convert ACE-Step Conditioner checkpoint keys to DiffSynth format. - - 参数 state_dict 是 DiskMap 类型。 - 遍历时,key 是 key 名,state_dict[key] 获取实际值。 - - Original checkpoint contains all model weights under prefixes: - - decoder.* (DiT) - - encoder.* (Conditioner) - - tokenizer.* (Audio Tokenizer) - - detokenizer.* (Audio Detokenizer) - - null_condition_emb (CFG null embedding) - - This extracts only 'encoder.' keys and strips the prefix. - - Example mapping: - encoder.lyric_encoder.layers.0.self_attn.q_proj.weight -> lyric_encoder.layers.0.self_attn.q_proj.weight - encoder.attention_pooler.layers.0.self_attn.q_proj.weight -> attention_pooler.layers.0.self_attn.q_proj.weight - encoder.timbre_encoder.layers.0.self_attn.q_proj.weight -> timbre_encoder.layers.0.self_attn.q_proj.weight - encoder.audio_tokenizer.audio_acoustic_proj.weight -> audio_tokenizer.audio_acoustic_proj.weight - encoder.detokenizer.layers.0.self_attn.q_proj.weight -> detokenizer.layers.0.self_attn.q_proj.weight - """ +def AceStepConditionEncoderStateDictConverter(state_dict): new_state_dict = {} prefix = "encoder." @@ -41,7 +7,6 @@ def ace_step_conditioner_converter(state_dict): new_key = key[len(prefix):] new_state_dict[new_key] = state_dict[key] - # Extract null_condition_emb from top level (used for CFG negative condition) if "null_condition_emb" in state_dict: new_state_dict["null_condition_emb"] = state_dict["null_condition_emb"] diff --git a/diffsynth/utils/state_dict_converters/ace_step_dit.py b/diffsynth/utils/state_dict_converters/ace_step_dit.py index 758462c..d5f7cf6 100644 --- a/diffsynth/utils/state_dict_converters/ace_step_dit.py +++ b/diffsynth/utils/state_dict_converters/ace_step_dit.py @@ -1,37 +1,4 @@ -""" -State dict converter for ACE-Step DiT model. - -The original checkpoint stores all model weights in a single file -(nested in AceStepConditionGenerationModel). The DiT weights are -prefixed with 'decoder.'. - -This converter extracts only keys starting with 'decoder.' and strips -the prefix to match the standalone AceStepDiTModel in DiffSynth. -""" - - -def ace_step_dit_converter(state_dict): - """ - Convert ACE-Step DiT checkpoint keys to DiffSynth format. - - 参数 state_dict 是 DiskMap 类型。 - 遍历时,key 是 key 名,state_dict[key] 获取实际值。 - - Original checkpoint contains all model weights under prefixes: - - decoder.* (DiT) - - encoder.* (Conditioner) - - tokenizer.* (Audio Tokenizer) - - detokenizer.* (Audio Detokenizer) - - null_condition_emb (CFG null embedding) - - This extracts only 'decoder.' keys and strips the prefix. - - Example mapping: - decoder.layers.0.self_attn.q_proj.weight -> layers.0.self_attn.q_proj.weight - decoder.proj_in.0.linear_1.weight -> proj_in.0.linear_1.weight - decoder.time_embed.linear_1.weight -> time_embed.linear_1.weight - decoder.rotary_emb.inv_freq -> rotary_emb.inv_freq - """ +def AceStepDiTModelStateDictConverter(state_dict): new_state_dict = {} prefix = "decoder." diff --git a/diffsynth/utils/state_dict_converters/ace_step_lm.py b/diffsynth/utils/state_dict_converters/ace_step_lm.py deleted file mode 100644 index 2067cb1..0000000 --- a/diffsynth/utils/state_dict_converters/ace_step_lm.py +++ /dev/null @@ -1,55 +0,0 @@ -""" -State dict converter for ACE-Step LLM (Qwen3-based). - -The safetensors file stores Qwen3 model weights. Different checkpoints -may have different key formats: -- Qwen3ForCausalLM format: model.embed_tokens.weight, model.layers.0.* -- Qwen3Model format: embed_tokens.weight, layers.0.* - -Qwen3ForCausalLM wraps a .model attribute (Qwen3Model), so its -state_dict() has keys: - model.model.embed_tokens.weight - model.model.layers.0.self_attn.q_proj.weight - model.model.norm.weight - model.lm_head.weight (tied to model.model.embed_tokens) - -This converter normalizes all keys to the Qwen3ForCausalLM format. - -Example mapping: - model.embed_tokens.weight -> model.model.embed_tokens.weight - embed_tokens.weight -> model.model.embed_tokens.weight - model.layers.0.self_attn.q_proj.weight -> model.model.layers.0.self_attn.q_proj.weight - layers.0.self_attn.q_proj.weight -> model.model.layers.0.self_attn.q_proj.weight - model.norm.weight -> model.model.norm.weight - norm.weight -> model.model.norm.weight -""" - - -def ace_step_lm_converter(state_dict): - """ - Convert ACE-Step LLM checkpoint keys to match Qwen3ForCausalLM state dict. - - 参数 state_dict 是 DiskMap 类型。 - 遍历时,key 是 key 名,state_dict[key] 获取实际值。 - """ - new_state_dict = {} - model_prefix = "model." - nested_prefix = "model.model." - - for key in state_dict: - if key.startswith(nested_prefix): - # Already has model.model., keep as is - new_key = key - elif key.startswith(model_prefix): - # Has model., add another model. - new_key = "model." + key - else: - # No prefix, add model.model. - new_key = "model.model." + key - new_state_dict[new_key] = state_dict[key] - - # Handle tied word embeddings: lm_head.weight shares with embed_tokens - if "model.model.embed_tokens.weight" in new_state_dict: - new_state_dict["model.lm_head.weight"] = new_state_dict["model.model.embed_tokens.weight"] - - return new_state_dict diff --git a/diffsynth/utils/state_dict_converters/ace_step_text_encoder.py b/diffsynth/utils/state_dict_converters/ace_step_text_encoder.py index de0b6c7..4ed1c01 100644 --- a/diffsynth/utils/state_dict_converters/ace_step_text_encoder.py +++ b/diffsynth/utils/state_dict_converters/ace_step_text_encoder.py @@ -1,28 +1,4 @@ -""" -State dict converter for ACE-Step Text Encoder (Qwen3-Embedding-0.6B). - -The safetensors stores Qwen3Model weights with keys: - embed_tokens.weight - layers.0.self_attn.q_proj.weight - norm.weight - -AceStepTextEncoder wraps a .model attribute (Qwen3Model), so its -state_dict() has keys with 'model.' prefix: - model.embed_tokens.weight - model.layers.0.self_attn.q_proj.weight - model.norm.weight - -This converter adds 'model.' prefix to match the nested structure. -""" - - -def ace_step_text_encoder_converter(state_dict): - """ - Convert ACE-Step Text Encoder checkpoint keys to match Qwen3Model wrapped state dict. - - 参数 state_dict 是 DiskMap 类型。 - 遍历时,key 是 key 名,state_dict[key] 获取实际值。 - """ +def AceStepTextEncoderStateDictConverter(state_dict): new_state_dict = {} prefix = "model." nested_prefix = "model.model." diff --git a/diffsynth/utils/state_dict_converters/ace_step_tokenizer.py b/diffsynth/utils/state_dict_converters/ace_step_tokenizer.py index d4cb2ba..66e014c 100644 --- a/diffsynth/utils/state_dict_converters/ace_step_tokenizer.py +++ b/diffsynth/utils/state_dict_converters/ace_step_tokenizer.py @@ -1,23 +1,4 @@ -""" -State dict converter for ACE-Step Tokenizer model. - -The original checkpoint stores tokenizer and detokenizer weights at the top level: -- tokenizer.* (AceStepAudioTokenizer: audio_acoustic_proj, attention_pooler, quantizer) -- detokenizer.* (AudioTokenDetokenizer: embed_tokens, layers, proj_out) - -These map directly to the AceStepTokenizer class which wraps both as -self.tokenizer and self.detokenizer submodules. -""" - - -def ace_step_tokenizer_converter(state_dict): - """ - Convert ACE-Step Tokenizer checkpoint keys to DiffSynth format. - - The checkpoint keys `tokenizer.*` and `detokenizer.*` already match - the DiffSynth AceStepTokenizer module structure (self.tokenizer, self.detokenizer). - No key remapping needed — just extract the relevant keys. - """ +def AceStepTokenizerStateDictConverter(state_dict): new_state_dict = {} for key in state_dict: diff --git a/examples/ace_step/model_inference/Ace-Step1.5.py b/examples/ace_step/model_inference/Ace-Step1.5.py index ae41f11..ff40d88 100644 --- a/examples/ace_step/model_inference/Ace-Step1.5.py +++ b/examples/ace_step/model_inference/Ace-Step1.5.py @@ -1,5 +1,6 @@ from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig from diffsynth.utils.data.audio import save_audio +from modelscope import dataset_snapshot_download import torch pipe = AceStepPipeline.from_pretrained( @@ -29,6 +30,11 @@ audio = pipe( save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo.wav") # input audio codes as reference +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/diffsynth_example_dataset", + local_dir="data/diffsynth_example_dataset", + allow_file_pattern="ace_step/Ace-Step1.5/audio_codes_input.txt", +) with open("data/diffsynth_example_dataset/ace_step/Ace-Step1.5/audio_codes_input.txt", "r") as f: audio_code_string = f.read().strip() diff --git a/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py b/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py index 3f55aa4..9b8a8a2 100644 --- a/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py +++ b/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py @@ -1,5 +1,6 @@ from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig from diffsynth.utils.data.audio import save_audio, read_audio +from modelscope import dataset_snapshot_download import torch pipe = AceStepPipeline.from_pretrained( @@ -15,6 +16,13 @@ pipe = AceStepPipeline.from_pretrained( prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating." lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]' + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/diffsynth_example_dataset", + local_dir="data/diffsynth_example_dataset", + allow_file_pattern="ace_step/acestep-v15-base/audio.wav", +) + src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate) # audio_cover_strength controls the steps of doing cover tasks. [0, num_inference_steps * audio_cover_strength] steps will be cover steps, and the rest will be regular text-to-music generation steps. # denoising_strength controls how the output audio is influenced by the source audio in cover tasks. diff --git a/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py b/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py index 4915245..68a0a36 100644 --- a/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py +++ b/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py @@ -1,5 +1,6 @@ from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig from diffsynth.utils.data.audio import save_audio, read_audio +from modelscope import dataset_snapshot_download import torch pipe = AceStepPipeline.from_pretrained( @@ -15,6 +16,13 @@ pipe = AceStepPipeline.from_pretrained( prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating." lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]' + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/diffsynth_example_dataset", + local_dir="data/diffsynth_example_dataset", + allow_file_pattern="ace_step/acestep-v15-base/audio.wav", +) + src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate) # repainting_ranges are in seconds, and will be converted to frames internally in the pipeline. The negative value in repainting_ranges means the padding from the start of the audio. # For example, repainting_ranges=[(-10, 30), (160, 200)] means we want to repaint the audio from -10s to 30s (with 10s padding before the start) and from 160s to 200s. The non-existent parts will be padded with silence. diff --git a/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py b/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py index 4bc2e5e..0160bcf 100644 --- a/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py +++ b/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py @@ -6,6 +6,7 @@ Turbo model: uses num_inference_steps=8, cfg_scale=1.0. """ from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig from diffsynth.utils.data.audio import save_audio +from modelscope import dataset_snapshot_download import torch @@ -49,6 +50,11 @@ audio = pipe( save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-low-vram.wav") # input audio codes as reference +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/diffsynth_example_dataset", + local_dir="data/diffsynth_example_dataset", + allow_file_pattern="ace_step/Ace-Step1.5/audio_codes_input.txt", +) with open("data/diffsynth_example_dataset/ace_step/Ace-Step1.5/audio_codes_input.txt", "r") as f: audio_code_string = f.read().strip() diff --git a/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py b/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py index f16a4bd..2ae06fe 100644 --- a/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py +++ b/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py @@ -1,5 +1,6 @@ from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig from diffsynth.utils.data.audio import save_audio, read_audio +from modelscope import dataset_snapshot_download import torch vram_config = { @@ -27,6 +28,13 @@ pipe = AceStepPipeline.from_pretrained( prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating." lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]' + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/diffsynth_example_dataset", + local_dir="data/diffsynth_example_dataset", + allow_file_pattern="ace_step/acestep-v15-base/audio.wav", +) + src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate) # audio_cover_strength controls the steps of doing cover tasks. [0, num_inference_steps * audio_cover_strength] steps will be cover steps, and the rest will be regular text-to-music generation steps. # denoising_strength controls how the output audio is influenced by the source audio in cover tasks. diff --git a/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py b/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py index 42a3c2b..6cbe107 100644 --- a/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py +++ b/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py @@ -1,5 +1,6 @@ from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig from diffsynth.utils.data.audio import save_audio, read_audio +from modelscope import dataset_snapshot_download import torch vram_config = { @@ -27,6 +28,13 @@ pipe = AceStepPipeline.from_pretrained( prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating." lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]' + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/diffsynth_example_dataset", + local_dir="data/diffsynth_example_dataset", + allow_file_pattern="ace_step/acestep-v15-base/audio.wav", +) + src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate) # repainting_ranges are in seconds, and will be converted to frames internally in the pipeline. The negative value in repainting_ranges means the padding from the start of the audio. # For example, repainting_ranges=[(-10, 30), (160, 200)] means we want to repaint the audio from -10s to 30s (with 10s padding before the start) and from 160s to 200s. The non-existent parts will be padded with silence. diff --git a/examples/ace_step/model_training/train.py b/examples/ace_step/model_training/train.py index a24da2c..21d57dd 100644 --- a/examples/ace_step/model_training/train.py +++ b/examples/ace_step/model_training/train.py @@ -1,31 +1,15 @@ -import torch, os, argparse, accelerate, warnings, torchaudio +import os +import torch import math +import argparse +import accelerate from diffsynth.core import UnifiedDataset -from diffsynth.core.data.operators import ToAbsolutePath, RouteByType, DataProcessingOperator, LoadPureAudioWithTorchaudio +from diffsynth.core.data.operators import ToAbsolutePath, LoadPureAudioWithTorchaudio from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig from diffsynth.diffusion import * os.environ["TOKENIZERS_PARALLELISM"] = "false" -class LoadAceStepAudio(DataProcessingOperator): - """Load audio file and return waveform tensor [2, T] at 48kHz.""" - def __init__(self, target_sr=48000): - self.target_sr = target_sr - - def __call__(self, data: str): - try: - waveform, sample_rate = torchaudio.load(data) - if sample_rate != self.target_sr: - resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.target_sr) - waveform = resampler(waveform) - if waveform.shape[0] == 1: - waveform = waveform.repeat(2, 1) - return waveform - except Exception as e: - warnings.warn(f"Cannot load audio from {data}: {e}") - return None - - class AceStepTrainingModule(DiffusionTrainingModule): def __init__( self, @@ -43,17 +27,15 @@ class AceStepTrainingModule(DiffusionTrainingModule): task="sft", ): super().__init__() - # ===== 解析模型配置(固定写法) ===== model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device) - # ===== Tokenizer 配置 ===== text_tokenizer_config = self.parse_path_or_model_id(tokenizer_path, default_value=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/")) silence_latent_config = self.parse_path_or_model_id(silence_latent_path, default_value=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt")) - # ===== 构建 Pipeline ===== - self.pipe = AceStepPipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, text_tokenizer_config=text_tokenizer_config, silence_latent_config=silence_latent_config) - # ===== 拆分 Pipeline Units(固定写法) ===== + self.pipe = AceStepPipeline.from_pretrained( + torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, + text_tokenizer_config=text_tokenizer_config, silence_latent_config=silence_latent_config, + ) self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model) - # ===== 切换到训练模式(固定写法) ===== self.switch_pipe_to_training_mode( self.pipe, trainable_models, lora_base_model, lora_target_modules, lora_rank, lora_checkpoint, @@ -61,13 +43,11 @@ class AceStepTrainingModule(DiffusionTrainingModule): task=task, ) - # ===== 其他配置(固定写法) ===== self.use_gradient_checkpointing = use_gradient_checkpointing self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else [] self.fp8_models = fp8_models self.task = task - # ===== 任务模式路由(固定写法) ===== self.task_to_loss = { "sft:data_process": lambda pipe, *args: args, "sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi), @@ -78,11 +58,8 @@ class AceStepTrainingModule(DiffusionTrainingModule): inputs_posi = {"prompt": data["prompt"], "positive": True} inputs_nega = {"positive": False} duration = math.floor(data['audio'][0].shape[1] / data['audio'][1]) if data.get("audio") is not None else data.get("duration", 60) - # ===== 共享参数 ===== inputs_shared = { - # ===== 核心字段映射 ===== "input_audio": data["audio"], - # ===== 音频生成任务所需元数据 ===== "lyrics": data["lyrics"], "task_type": "text2music", "duration": duration, @@ -90,18 +67,15 @@ class AceStepTrainingModule(DiffusionTrainingModule): "keyscale": data.get("keyscale", "C major"), "timesignature": data.get("timesignature", "4"), "vocal_language": data.get("vocal_language", "unknown"), - # ===== 框架控制参数(固定写法) ===== "cfg_scale": 1, "rand_device": self.pipe.device, "use_gradient_checkpointing": self.use_gradient_checkpointing, "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload, } - # ===== 额外字段注入:通过 --extra_inputs 配置的数据集列名(固定写法) ===== inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared) return inputs_shared, inputs_posi, inputs_nega def forward(self, data, inputs=None): - # ===== 标准实现,不要修改(固定写法) ===== if inputs is None: inputs = self.get_pipeline_inputs(data) inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype) for unit in self.pipe.units: @@ -122,12 +96,10 @@ def ace_step_parser(): if __name__ == "__main__": parser = ace_step_parser() args = parser.parse_args() - # ===== Accelerator 配置(固定写法) ===== accelerator = accelerate.Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)], ) - # ===== 数据集定义 ===== dataset = UnifiedDataset( base_path=args.dataset_base_path, metadata_path=args.dataset_metadata_path, @@ -135,10 +107,11 @@ if __name__ == "__main__": data_file_keys=args.data_file_keys.split(","), main_data_operator=None, special_operator_map={ - "audio": ToAbsolutePath(args.dataset_base_path) >> LoadPureAudioWithTorchaudio(target_sample_rate=48000), + "audio": ToAbsolutePath(args.dataset_base_path) >> LoadPureAudioWithTorchaudio( + target_sample_rate=48000, + ), }, ) - # ===== TrainingModule ===== model = AceStepTrainingModule( model_paths=args.model_paths, model_id_with_origin_paths=args.model_id_with_origin_paths, @@ -159,12 +132,10 @@ if __name__ == "__main__": task=args.task, device="cpu" if args.initialize_model_on_cpu else accelerator.device, ) - # ===== ModelLogger(固定写法) ===== model_logger = ModelLogger( args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt, ) - # ===== 任务路由(固定写法) ===== launcher_map = { "sft:data_process": launch_data_process_task, "sft": launch_training_task,