mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-24 15:06:17 +00:00
style
This commit is contained in:
@@ -925,7 +925,7 @@ ace_step_series = [
|
|||||||
"model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
|
"model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
|
||||||
"model_name": "ace_step_dit",
|
"model_name": "ace_step_dit",
|
||||||
"model_class": "diffsynth.models.ace_step_dit.AceStepDiTModel",
|
"model_class": "diffsynth.models.ace_step_dit.AceStepDiTModel",
|
||||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.ace_step_dit_converter",
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.AceStepDiTModelStateDictConverter",
|
||||||
},
|
},
|
||||||
# === XL DiT variants (32 layers, hidden_size=2560) ===
|
# === XL DiT variants (32 layers, hidden_size=2560) ===
|
||||||
# Covers: xl-base, xl-sft, xl-turbo
|
# Covers: xl-base, xl-sft, xl-turbo
|
||||||
@@ -934,7 +934,7 @@ ace_step_series = [
|
|||||||
"model_hash": "3a28a410c2246f125153ef792d8bc828",
|
"model_hash": "3a28a410c2246f125153ef792d8bc828",
|
||||||
"model_name": "ace_step_dit",
|
"model_name": "ace_step_dit",
|
||||||
"model_class": "diffsynth.models.ace_step_dit.AceStepDiTModel",
|
"model_class": "diffsynth.models.ace_step_dit.AceStepDiTModel",
|
||||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.ace_step_dit_converter",
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.AceStepDiTModelStateDictConverter",
|
||||||
"extra_kwargs": {
|
"extra_kwargs": {
|
||||||
"hidden_size": 2560,
|
"hidden_size": 2560,
|
||||||
"intermediate_size": 9728,
|
"intermediate_size": 9728,
|
||||||
@@ -952,7 +952,7 @@ ace_step_series = [
|
|||||||
"model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
|
"model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
|
||||||
"model_name": "ace_step_conditioner",
|
"model_name": "ace_step_conditioner",
|
||||||
"model_class": "diffsynth.models.ace_step_conditioner.AceStepConditionEncoder",
|
"model_class": "diffsynth.models.ace_step_conditioner.AceStepConditionEncoder",
|
||||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.ace_step_conditioner_converter",
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.AceStepConditionEncoderStateDictConverter",
|
||||||
},
|
},
|
||||||
# === XL Conditioner (same architecture, but checkpoint includes XL decoder → different file hash) ===
|
# === XL Conditioner (same architecture, but checkpoint includes XL decoder → different file hash) ===
|
||||||
{
|
{
|
||||||
@@ -960,7 +960,7 @@ ace_step_series = [
|
|||||||
"model_hash": "3a28a410c2246f125153ef792d8bc828",
|
"model_hash": "3a28a410c2246f125153ef792d8bc828",
|
||||||
"model_name": "ace_step_conditioner",
|
"model_name": "ace_step_conditioner",
|
||||||
"model_class": "diffsynth.models.ace_step_conditioner.AceStepConditionEncoder",
|
"model_class": "diffsynth.models.ace_step_conditioner.AceStepConditionEncoder",
|
||||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.ace_step_conditioner_converter",
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.AceStepConditionEncoderStateDictConverter",
|
||||||
},
|
},
|
||||||
# === Qwen3-Embedding (text encoder) ===
|
# === Qwen3-Embedding (text encoder) ===
|
||||||
{
|
{
|
||||||
@@ -968,7 +968,7 @@ ace_step_series = [
|
|||||||
"model_hash": "3509bea17b0e8cffc3dd4a15cc7899d0",
|
"model_hash": "3509bea17b0e8cffc3dd4a15cc7899d0",
|
||||||
"model_name": "ace_step_text_encoder",
|
"model_name": "ace_step_text_encoder",
|
||||||
"model_class": "diffsynth.models.ace_step_text_encoder.AceStepTextEncoder",
|
"model_class": "diffsynth.models.ace_step_text_encoder.AceStepTextEncoder",
|
||||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_text_encoder.ace_step_text_encoder_converter",
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_text_encoder.AceStepTextEncoderStateDictConverter",
|
||||||
},
|
},
|
||||||
# === VAE (AutoencoderOobleck CNN) ===
|
# === VAE (AutoencoderOobleck CNN) ===
|
||||||
{
|
{
|
||||||
@@ -983,7 +983,7 @@ ace_step_series = [
|
|||||||
"model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
|
"model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
|
||||||
"model_name": "ace_step_tokenizer",
|
"model_name": "ace_step_tokenizer",
|
||||||
"model_class": "diffsynth.models.ace_step_tokenizer.AceStepTokenizer",
|
"model_class": "diffsynth.models.ace_step_tokenizer.AceStepTokenizer",
|
||||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_tokenizer.ace_step_tokenizer_converter",
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_tokenizer.AceStepTokenizerStateDictConverter",
|
||||||
},
|
},
|
||||||
# === XL Tokenizer (XL models share same tokenizer architecture) ===
|
# === XL Tokenizer (XL models share same tokenizer architecture) ===
|
||||||
{
|
{
|
||||||
@@ -991,8 +991,11 @@ ace_step_series = [
|
|||||||
"model_hash": "3a28a410c2246f125153ef792d8bc828",
|
"model_hash": "3a28a410c2246f125153ef792d8bc828",
|
||||||
"model_name": "ace_step_tokenizer",
|
"model_name": "ace_step_tokenizer",
|
||||||
"model_class": "diffsynth.models.ace_step_tokenizer.AceStepTokenizer",
|
"model_class": "diffsynth.models.ace_step_tokenizer.AceStepTokenizer",
|
||||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_tokenizer.ace_step_tokenizer_converter",
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_tokenizer.AceStepTokenizerStateDictConverter",
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series + z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series + ace_step_series
|
MODEL_CONFIGS = (
|
||||||
|
qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series
|
||||||
|
+ z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series + ace_step_series
|
||||||
|
)
|
||||||
|
|||||||
@@ -3,12 +3,10 @@ ACE-Step Pipeline for DiffSynth-Studio.
|
|||||||
|
|
||||||
Text-to-Music generation pipeline using ACE-Step 1.5 model.
|
Text-to-Music generation pipeline using ACE-Step 1.5 model.
|
||||||
"""
|
"""
|
||||||
import re
|
import re, torch
|
||||||
import torch
|
|
||||||
from typing import Optional, Dict, Any, List, Tuple
|
from typing import Optional, Dict, Any, List, Tuple
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import random
|
import random, math
|
||||||
import math
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
@@ -39,7 +37,7 @@ class AceStepPipeline(BasePipeline):
|
|||||||
self.conditioner: AceStepConditionEncoder = None
|
self.conditioner: AceStepConditionEncoder = None
|
||||||
self.dit: AceStepDiTModel = None
|
self.dit: AceStepDiTModel = None
|
||||||
self.vae: AceStepVAE = None
|
self.vae: AceStepVAE = None
|
||||||
self.tokenizer_model: AceStepTokenizer = None # AceStepTokenizer (tokenizer + detokenizer)
|
self.tokenizer_model: AceStepTokenizer = None
|
||||||
|
|
||||||
self.in_iteration_models = ("dit",)
|
self.in_iteration_models = ("dit",)
|
||||||
self.units = [
|
self.units = [
|
||||||
@@ -65,7 +63,6 @@ class AceStepPipeline(BasePipeline):
|
|||||||
silence_latent_config: ModelConfig = ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
|
silence_latent_config: ModelConfig = ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
|
||||||
vram_limit: float = None,
|
vram_limit: float = None,
|
||||||
):
|
):
|
||||||
"""Load pipeline from pretrained checkpoints."""
|
|
||||||
pipe = AceStepPipeline(device=device, torch_dtype=torch_dtype)
|
pipe = AceStepPipeline(device=device, torch_dtype=torch_dtype)
|
||||||
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
||||||
|
|
||||||
@@ -102,7 +99,7 @@ class AceStepPipeline(BasePipeline):
|
|||||||
reference_audios: List[torch.Tensor] = None,
|
reference_audios: List[torch.Tensor] = None,
|
||||||
# Source audio
|
# Source audio
|
||||||
src_audio: torch.Tensor = None,
|
src_audio: torch.Tensor = None,
|
||||||
denoising_strength: float = 1.0, # denoising_strength = 1 - cover_noise_strength
|
denoising_strength: float = 1.0, # denoising_strength = 1 - cover_noise_strength
|
||||||
audio_cover_strength: float = 1.0,
|
audio_cover_strength: float = 1.0,
|
||||||
# Audio codes
|
# Audio codes
|
||||||
audio_code_string: Optional[str] = None,
|
audio_code_string: Optional[str] = None,
|
||||||
@@ -115,7 +112,7 @@ class AceStepPipeline(BasePipeline):
|
|||||||
bpm: Optional[int] = 100,
|
bpm: Optional[int] = 100,
|
||||||
keyscale: Optional[str] = "B minor",
|
keyscale: Optional[str] = "B minor",
|
||||||
timesignature: Optional[str] = "4",
|
timesignature: Optional[str] = "4",
|
||||||
vocal_language: Optional[str] = 'unknown',
|
vocal_language: Optional[str] = "unknown",
|
||||||
# Randomness
|
# Randomness
|
||||||
seed: int = None,
|
seed: int = None,
|
||||||
rand_device: str = "cpu",
|
rand_device: str = "cpu",
|
||||||
@@ -126,10 +123,10 @@ class AceStepPipeline(BasePipeline):
|
|||||||
# Progress
|
# Progress
|
||||||
progress_bar_cmd=tqdm,
|
progress_bar_cmd=tqdm,
|
||||||
):
|
):
|
||||||
# 1. Scheduler
|
# Scheduler
|
||||||
self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, denoising_strength=denoising_strength, shift=shift)
|
self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, denoising_strength=denoising_strength, shift=shift)
|
||||||
|
|
||||||
# 2. 三字典输入
|
# Parameters
|
||||||
inputs_posi = {"prompt": prompt, "positive": True}
|
inputs_posi = {"prompt": prompt, "positive": True}
|
||||||
inputs_nega = {"positive": False}
|
inputs_nega = {"positive": False}
|
||||||
inputs_shared = {
|
inputs_shared = {
|
||||||
@@ -147,13 +144,12 @@ class AceStepPipeline(BasePipeline):
|
|||||||
"shift": shift,
|
"shift": shift,
|
||||||
}
|
}
|
||||||
|
|
||||||
# 3. Unit 链执行
|
|
||||||
for unit in self.units:
|
for unit in self.units:
|
||||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
|
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
|
||||||
unit, self, inputs_shared, inputs_posi, inputs_nega
|
unit, self, inputs_shared, inputs_posi, inputs_nega
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4. Denoise loop
|
# Denoise
|
||||||
self.load_models_to_device(self.in_iteration_models)
|
self.load_models_to_device(self.in_iteration_models)
|
||||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||||
@@ -164,17 +160,17 @@ class AceStepPipeline(BasePipeline):
|
|||||||
inputs_shared, inputs_posi, inputs_nega,
|
inputs_shared, inputs_posi, inputs_nega,
|
||||||
**models, timestep=timestep, progress_id=progress_id,
|
**models, timestep=timestep, progress_id=progress_id,
|
||||||
)
|
)
|
||||||
inputs_shared["latents"] = self.step(self.scheduler, inpaint_mask=inputs_shared.get("denoise_mask", None), input_latents=inputs_shared.get("src_latents", None),
|
inputs_shared["latents"] = self.step(
|
||||||
progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
|
self.scheduler, inpaint_mask=inputs_shared.get("denoise_mask", None), input_latents=inputs_shared.get("src_latents", None),
|
||||||
|
progress_id=progress_id, noise_pred=noise_pred, **inputs_shared,
|
||||||
|
)
|
||||||
|
|
||||||
# 5. VAE 解码
|
# Decode
|
||||||
self.load_models_to_device(['vae'])
|
self.load_models_to_device(['vae'])
|
||||||
# DiT output is [B, T, 64] (channels-last), VAE expects [B, 64, T] (channels-first)
|
# DiT output is [B, T, 64] (channels-last), VAE expects [B, 64, T] (channels-first)
|
||||||
latents = inputs_shared["latents"].transpose(1, 2)
|
latents = inputs_shared["latents"].transpose(1, 2)
|
||||||
vae_output = self.vae.decode(latents)
|
vae_output = self.vae.decode(latents)
|
||||||
# VAE returns OobleckDecoderOutput with .sample attribute
|
audio_output = self.normalize_audio(vae_output, target_db=-1.0)
|
||||||
audio_output = vae_output.sample if hasattr(vae_output, 'sample') else vae_output
|
|
||||||
audio_output = self.normalize_audio(audio_output, target_db=-1.0)
|
|
||||||
audio = self.output_audio_format_check(audio_output)
|
audio = self.output_audio_format_check(audio_output)
|
||||||
self.load_models_to_device([])
|
self.load_models_to_device([])
|
||||||
return audio
|
return audio
|
||||||
@@ -188,7 +184,9 @@ class AceStepPipeline(BasePipeline):
|
|||||||
return audio * gain
|
return audio * gain
|
||||||
|
|
||||||
def switch_noncover_condition(self, inputs_shared, inputs_posi, inputs_nega, progress_id):
|
def switch_noncover_condition(self, inputs_shared, inputs_posi, inputs_nega, progress_id):
|
||||||
if inputs_shared["task_type"] != "cover" or inputs_shared["audio_cover_strength"] >= 1.0 or inputs_shared.get("shared_noncover", None) is None:
|
if inputs_shared["task_type"] != "cover" or inputs_shared["audio_cover_strength"] >= 1.0:
|
||||||
|
return
|
||||||
|
if inputs_shared.get("shared_noncover", None) is None:
|
||||||
return
|
return
|
||||||
cover_steps = int(len(self.scheduler.timesteps) * inputs_shared["audio_cover_strength"])
|
cover_steps = int(len(self.scheduler.timesteps) * inputs_shared["audio_cover_strength"])
|
||||||
if progress_id >= cover_steps:
|
if progress_id >= cover_steps:
|
||||||
@@ -312,7 +310,10 @@ class AceStepUnit_ReferenceAudioEmbedder(PipelineUnit):
|
|||||||
def process(self, pipe, reference_audios):
|
def process(self, pipe, reference_audios):
|
||||||
if reference_audios is not None:
|
if reference_audios is not None:
|
||||||
pipe.load_models_to_device(['vae'])
|
pipe.load_models_to_device(['vae'])
|
||||||
reference_audios = [self.process_reference_audio(reference_audio).to(dtype=pipe.torch_dtype, device=pipe.device) for reference_audio in reference_audios]
|
reference_audios = [
|
||||||
|
self.process_reference_audio(reference_audio).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||||
|
for reference_audio in reference_audios
|
||||||
|
]
|
||||||
reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, [reference_audios])
|
reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, [reference_audios])
|
||||||
else:
|
else:
|
||||||
reference_audios = [[torch.zeros(2, 30 * pipe.vae.sampling_rate).to(dtype=pipe.torch_dtype, device=pipe.device)]]
|
reference_audios = [[torch.zeros(2, 30 * pipe.vae.sampling_rate).to(dtype=pipe.torch_dtype, device=pipe.device)]]
|
||||||
@@ -357,7 +358,6 @@ class AceStepUnit_ReferenceAudioEmbedder(PipelineUnit):
|
|||||||
|
|
||||||
|
|
||||||
class AceStepUnit_ConditionEmbedder(PipelineUnit):
|
class AceStepUnit_ConditionEmbedder(PipelineUnit):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
take_over=True,
|
take_over=True,
|
||||||
@@ -378,7 +378,9 @@ class AceStepUnit_ConditionEmbedder(PipelineUnit):
|
|||||||
inputs_posi["encoder_hidden_states"] = encoder_hidden_states
|
inputs_posi["encoder_hidden_states"] = encoder_hidden_states
|
||||||
inputs_posi["encoder_attention_mask"] = encoder_attention_mask
|
inputs_posi["encoder_attention_mask"] = encoder_attention_mask
|
||||||
if inputs_shared["cfg_scale"] != 1.0:
|
if inputs_shared["cfg_scale"] != 1.0:
|
||||||
inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states).to(dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device)
|
inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states).to(
|
||||||
|
dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device,
|
||||||
|
)
|
||||||
inputs_nega["encoder_attention_mask"] = encoder_attention_mask
|
inputs_nega["encoder_attention_mask"] = encoder_attention_mask
|
||||||
if inputs_shared["task_type"] == "cover" and inputs_shared["audio_cover_strength"] < 1.0:
|
if inputs_shared["task_type"] == "cover" and inputs_shared["audio_cover_strength"] < 1.0:
|
||||||
hidden_states_noncover = AceStepUnit_PromptEmbedder().process(
|
hidden_states_noncover = AceStepUnit_PromptEmbedder().process(
|
||||||
@@ -396,7 +398,9 @@ class AceStepUnit_ConditionEmbedder(PipelineUnit):
|
|||||||
inputs_shared["posi_noncover"] = {"encoder_hidden_states": encoder_hidden_states_noncover, "encoder_attention_mask": encoder_attention_mask_noncover}
|
inputs_shared["posi_noncover"] = {"encoder_hidden_states": encoder_hidden_states_noncover, "encoder_attention_mask": encoder_attention_mask_noncover}
|
||||||
if inputs_shared["cfg_scale"] != 1.0:
|
if inputs_shared["cfg_scale"] != 1.0:
|
||||||
inputs_shared["nega_noncover"] = {
|
inputs_shared["nega_noncover"] = {
|
||||||
"encoder_hidden_states": pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states_noncover).to(dtype=encoder_hidden_states_noncover.dtype, device=encoder_hidden_states_noncover.device),
|
"encoder_hidden_states": pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states_noncover).to(
|
||||||
|
dtype=encoder_hidden_states_noncover.dtype, device=encoder_hidden_states_noncover.device,
|
||||||
|
),
|
||||||
"encoder_attention_mask": encoder_attention_mask_noncover,
|
"encoder_attention_mask": encoder_attention_mask_noncover,
|
||||||
}
|
}
|
||||||
return inputs_shared, inputs_posi, inputs_nega
|
return inputs_shared, inputs_posi, inputs_nega
|
||||||
@@ -483,7 +487,7 @@ class AceStepUnit_ContextLatentBuilder(PipelineUnit):
|
|||||||
# use audio_cede_string to get src_latents.
|
# use audio_cede_string to get src_latents.
|
||||||
pipe.load_models_to_device(self.onload_model_names)
|
pipe.load_models_to_device(self.onload_model_names)
|
||||||
code_ids = self._parse_audio_code_string(audio_code_string)
|
code_ids = self._parse_audio_code_string(audio_code_string)
|
||||||
quantizer = pipe.tokenizer_model.tokenizer.quantizer
|
quantizer = pipe.tokenizer_model.tokenizer.quantizer.to(device=pipe.device)
|
||||||
indices = torch.tensor(code_ids, device=quantizer.codebooks.device, dtype=torch.long).unsqueeze(0).unsqueeze(-1)
|
indices = torch.tensor(code_ids, device=quantizer.codebooks.device, dtype=torch.long).unsqueeze(0).unsqueeze(-1)
|
||||||
codes = quantizer.get_codes_from_indices(indices)
|
codes = quantizer.get_codes_from_indices(indices)
|
||||||
quantized = codes.sum(dim=0).to(pipe.torch_dtype).to(pipe.device)
|
quantized = codes.sum(dim=0).to(pipe.torch_dtype).to(pipe.device)
|
||||||
|
|||||||
@@ -1,38 +1,4 @@
|
|||||||
"""
|
def AceStepConditionEncoderStateDictConverter(state_dict):
|
||||||
State dict converter for ACE-Step Conditioner model.
|
|
||||||
|
|
||||||
The original checkpoint stores all model weights in a single file
|
|
||||||
(nested in AceStepConditionGenerationModel). The Conditioner weights are
|
|
||||||
prefixed with 'encoder.'.
|
|
||||||
|
|
||||||
This converter extracts only keys starting with 'encoder.' and strips
|
|
||||||
the prefix to match the standalone AceStepConditionEncoder in DiffSynth.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def ace_step_conditioner_converter(state_dict):
|
|
||||||
"""
|
|
||||||
Convert ACE-Step Conditioner checkpoint keys to DiffSynth format.
|
|
||||||
|
|
||||||
参数 state_dict 是 DiskMap 类型。
|
|
||||||
遍历时,key 是 key 名,state_dict[key] 获取实际值。
|
|
||||||
|
|
||||||
Original checkpoint contains all model weights under prefixes:
|
|
||||||
- decoder.* (DiT)
|
|
||||||
- encoder.* (Conditioner)
|
|
||||||
- tokenizer.* (Audio Tokenizer)
|
|
||||||
- detokenizer.* (Audio Detokenizer)
|
|
||||||
- null_condition_emb (CFG null embedding)
|
|
||||||
|
|
||||||
This extracts only 'encoder.' keys and strips the prefix.
|
|
||||||
|
|
||||||
Example mapping:
|
|
||||||
encoder.lyric_encoder.layers.0.self_attn.q_proj.weight -> lyric_encoder.layers.0.self_attn.q_proj.weight
|
|
||||||
encoder.attention_pooler.layers.0.self_attn.q_proj.weight -> attention_pooler.layers.0.self_attn.q_proj.weight
|
|
||||||
encoder.timbre_encoder.layers.0.self_attn.q_proj.weight -> timbre_encoder.layers.0.self_attn.q_proj.weight
|
|
||||||
encoder.audio_tokenizer.audio_acoustic_proj.weight -> audio_tokenizer.audio_acoustic_proj.weight
|
|
||||||
encoder.detokenizer.layers.0.self_attn.q_proj.weight -> detokenizer.layers.0.self_attn.q_proj.weight
|
|
||||||
"""
|
|
||||||
new_state_dict = {}
|
new_state_dict = {}
|
||||||
prefix = "encoder."
|
prefix = "encoder."
|
||||||
|
|
||||||
@@ -41,7 +7,6 @@ def ace_step_conditioner_converter(state_dict):
|
|||||||
new_key = key[len(prefix):]
|
new_key = key[len(prefix):]
|
||||||
new_state_dict[new_key] = state_dict[key]
|
new_state_dict[new_key] = state_dict[key]
|
||||||
|
|
||||||
# Extract null_condition_emb from top level (used for CFG negative condition)
|
|
||||||
if "null_condition_emb" in state_dict:
|
if "null_condition_emb" in state_dict:
|
||||||
new_state_dict["null_condition_emb"] = state_dict["null_condition_emb"]
|
new_state_dict["null_condition_emb"] = state_dict["null_condition_emb"]
|
||||||
|
|
||||||
|
|||||||
@@ -1,37 +1,4 @@
|
|||||||
"""
|
def AceStepDiTModelStateDictConverter(state_dict):
|
||||||
State dict converter for ACE-Step DiT model.
|
|
||||||
|
|
||||||
The original checkpoint stores all model weights in a single file
|
|
||||||
(nested in AceStepConditionGenerationModel). The DiT weights are
|
|
||||||
prefixed with 'decoder.'.
|
|
||||||
|
|
||||||
This converter extracts only keys starting with 'decoder.' and strips
|
|
||||||
the prefix to match the standalone AceStepDiTModel in DiffSynth.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def ace_step_dit_converter(state_dict):
|
|
||||||
"""
|
|
||||||
Convert ACE-Step DiT checkpoint keys to DiffSynth format.
|
|
||||||
|
|
||||||
参数 state_dict 是 DiskMap 类型。
|
|
||||||
遍历时,key 是 key 名,state_dict[key] 获取实际值。
|
|
||||||
|
|
||||||
Original checkpoint contains all model weights under prefixes:
|
|
||||||
- decoder.* (DiT)
|
|
||||||
- encoder.* (Conditioner)
|
|
||||||
- tokenizer.* (Audio Tokenizer)
|
|
||||||
- detokenizer.* (Audio Detokenizer)
|
|
||||||
- null_condition_emb (CFG null embedding)
|
|
||||||
|
|
||||||
This extracts only 'decoder.' keys and strips the prefix.
|
|
||||||
|
|
||||||
Example mapping:
|
|
||||||
decoder.layers.0.self_attn.q_proj.weight -> layers.0.self_attn.q_proj.weight
|
|
||||||
decoder.proj_in.0.linear_1.weight -> proj_in.0.linear_1.weight
|
|
||||||
decoder.time_embed.linear_1.weight -> time_embed.linear_1.weight
|
|
||||||
decoder.rotary_emb.inv_freq -> rotary_emb.inv_freq
|
|
||||||
"""
|
|
||||||
new_state_dict = {}
|
new_state_dict = {}
|
||||||
prefix = "decoder."
|
prefix = "decoder."
|
||||||
|
|
||||||
|
|||||||
@@ -1,55 +0,0 @@
|
|||||||
"""
|
|
||||||
State dict converter for ACE-Step LLM (Qwen3-based).
|
|
||||||
|
|
||||||
The safetensors file stores Qwen3 model weights. Different checkpoints
|
|
||||||
may have different key formats:
|
|
||||||
- Qwen3ForCausalLM format: model.embed_tokens.weight, model.layers.0.*
|
|
||||||
- Qwen3Model format: embed_tokens.weight, layers.0.*
|
|
||||||
|
|
||||||
Qwen3ForCausalLM wraps a .model attribute (Qwen3Model), so its
|
|
||||||
state_dict() has keys:
|
|
||||||
model.model.embed_tokens.weight
|
|
||||||
model.model.layers.0.self_attn.q_proj.weight
|
|
||||||
model.model.norm.weight
|
|
||||||
model.lm_head.weight (tied to model.model.embed_tokens)
|
|
||||||
|
|
||||||
This converter normalizes all keys to the Qwen3ForCausalLM format.
|
|
||||||
|
|
||||||
Example mapping:
|
|
||||||
model.embed_tokens.weight -> model.model.embed_tokens.weight
|
|
||||||
embed_tokens.weight -> model.model.embed_tokens.weight
|
|
||||||
model.layers.0.self_attn.q_proj.weight -> model.model.layers.0.self_attn.q_proj.weight
|
|
||||||
layers.0.self_attn.q_proj.weight -> model.model.layers.0.self_attn.q_proj.weight
|
|
||||||
model.norm.weight -> model.model.norm.weight
|
|
||||||
norm.weight -> model.model.norm.weight
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def ace_step_lm_converter(state_dict):
|
|
||||||
"""
|
|
||||||
Convert ACE-Step LLM checkpoint keys to match Qwen3ForCausalLM state dict.
|
|
||||||
|
|
||||||
参数 state_dict 是 DiskMap 类型。
|
|
||||||
遍历时,key 是 key 名,state_dict[key] 获取实际值。
|
|
||||||
"""
|
|
||||||
new_state_dict = {}
|
|
||||||
model_prefix = "model."
|
|
||||||
nested_prefix = "model.model."
|
|
||||||
|
|
||||||
for key in state_dict:
|
|
||||||
if key.startswith(nested_prefix):
|
|
||||||
# Already has model.model., keep as is
|
|
||||||
new_key = key
|
|
||||||
elif key.startswith(model_prefix):
|
|
||||||
# Has model., add another model.
|
|
||||||
new_key = "model." + key
|
|
||||||
else:
|
|
||||||
# No prefix, add model.model.
|
|
||||||
new_key = "model.model." + key
|
|
||||||
new_state_dict[new_key] = state_dict[key]
|
|
||||||
|
|
||||||
# Handle tied word embeddings: lm_head.weight shares with embed_tokens
|
|
||||||
if "model.model.embed_tokens.weight" in new_state_dict:
|
|
||||||
new_state_dict["model.lm_head.weight"] = new_state_dict["model.model.embed_tokens.weight"]
|
|
||||||
|
|
||||||
return new_state_dict
|
|
||||||
@@ -1,28 +1,4 @@
|
|||||||
"""
|
def AceStepTextEncoderStateDictConverter(state_dict):
|
||||||
State dict converter for ACE-Step Text Encoder (Qwen3-Embedding-0.6B).
|
|
||||||
|
|
||||||
The safetensors stores Qwen3Model weights with keys:
|
|
||||||
embed_tokens.weight
|
|
||||||
layers.0.self_attn.q_proj.weight
|
|
||||||
norm.weight
|
|
||||||
|
|
||||||
AceStepTextEncoder wraps a .model attribute (Qwen3Model), so its
|
|
||||||
state_dict() has keys with 'model.' prefix:
|
|
||||||
model.embed_tokens.weight
|
|
||||||
model.layers.0.self_attn.q_proj.weight
|
|
||||||
model.norm.weight
|
|
||||||
|
|
||||||
This converter adds 'model.' prefix to match the nested structure.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def ace_step_text_encoder_converter(state_dict):
|
|
||||||
"""
|
|
||||||
Convert ACE-Step Text Encoder checkpoint keys to match Qwen3Model wrapped state dict.
|
|
||||||
|
|
||||||
参数 state_dict 是 DiskMap 类型。
|
|
||||||
遍历时,key 是 key 名,state_dict[key] 获取实际值。
|
|
||||||
"""
|
|
||||||
new_state_dict = {}
|
new_state_dict = {}
|
||||||
prefix = "model."
|
prefix = "model."
|
||||||
nested_prefix = "model.model."
|
nested_prefix = "model.model."
|
||||||
|
|||||||
@@ -1,23 +1,4 @@
|
|||||||
"""
|
def AceStepTokenizerStateDictConverter(state_dict):
|
||||||
State dict converter for ACE-Step Tokenizer model.
|
|
||||||
|
|
||||||
The original checkpoint stores tokenizer and detokenizer weights at the top level:
|
|
||||||
- tokenizer.* (AceStepAudioTokenizer: audio_acoustic_proj, attention_pooler, quantizer)
|
|
||||||
- detokenizer.* (AudioTokenDetokenizer: embed_tokens, layers, proj_out)
|
|
||||||
|
|
||||||
These map directly to the AceStepTokenizer class which wraps both as
|
|
||||||
self.tokenizer and self.detokenizer submodules.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def ace_step_tokenizer_converter(state_dict):
|
|
||||||
"""
|
|
||||||
Convert ACE-Step Tokenizer checkpoint keys to DiffSynth format.
|
|
||||||
|
|
||||||
The checkpoint keys `tokenizer.*` and `detokenizer.*` already match
|
|
||||||
the DiffSynth AceStepTokenizer module structure (self.tokenizer, self.detokenizer).
|
|
||||||
No key remapping needed — just extract the relevant keys.
|
|
||||||
"""
|
|
||||||
new_state_dict = {}
|
new_state_dict = {}
|
||||||
|
|
||||||
for key in state_dict:
|
for key in state_dict:
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||||
from diffsynth.utils.data.audio import save_audio
|
from diffsynth.utils.data.audio import save_audio
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
pipe = AceStepPipeline.from_pretrained(
|
pipe = AceStepPipeline.from_pretrained(
|
||||||
@@ -29,6 +30,11 @@ audio = pipe(
|
|||||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo.wav")
|
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo.wav")
|
||||||
|
|
||||||
# input audio codes as reference
|
# input audio codes as reference
|
||||||
|
dataset_snapshot_download(
|
||||||
|
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
|
||||||
|
local_dir="data/diffsynth_example_dataset",
|
||||||
|
allow_file_pattern="ace_step/Ace-Step1.5/audio_codes_input.txt",
|
||||||
|
)
|
||||||
with open("data/diffsynth_example_dataset/ace_step/Ace-Step1.5/audio_codes_input.txt", "r") as f:
|
with open("data/diffsynth_example_dataset/ace_step/Ace-Step1.5/audio_codes_input.txt", "r") as f:
|
||||||
audio_code_string = f.read().strip()
|
audio_code_string = f.read().strip()
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||||
from diffsynth.utils.data.audio import save_audio, read_audio
|
from diffsynth.utils.data.audio import save_audio, read_audio
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
pipe = AceStepPipeline.from_pretrained(
|
pipe = AceStepPipeline.from_pretrained(
|
||||||
@@ -15,6 +16,13 @@ pipe = AceStepPipeline.from_pretrained(
|
|||||||
|
|
||||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||||
|
|
||||||
|
dataset_snapshot_download(
|
||||||
|
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
|
||||||
|
local_dir="data/diffsynth_example_dataset",
|
||||||
|
allow_file_pattern="ace_step/acestep-v15-base/audio.wav",
|
||||||
|
)
|
||||||
|
|
||||||
src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
|
src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
|
||||||
# audio_cover_strength controls the steps of doing cover tasks. [0, num_inference_steps * audio_cover_strength] steps will be cover steps, and the rest will be regular text-to-music generation steps.
|
# audio_cover_strength controls the steps of doing cover tasks. [0, num_inference_steps * audio_cover_strength] steps will be cover steps, and the rest will be regular text-to-music generation steps.
|
||||||
# denoising_strength controls how the output audio is influenced by the source audio in cover tasks.
|
# denoising_strength controls how the output audio is influenced by the source audio in cover tasks.
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||||
from diffsynth.utils.data.audio import save_audio, read_audio
|
from diffsynth.utils.data.audio import save_audio, read_audio
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
pipe = AceStepPipeline.from_pretrained(
|
pipe = AceStepPipeline.from_pretrained(
|
||||||
@@ -15,6 +16,13 @@ pipe = AceStepPipeline.from_pretrained(
|
|||||||
|
|
||||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||||
|
|
||||||
|
dataset_snapshot_download(
|
||||||
|
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
|
||||||
|
local_dir="data/diffsynth_example_dataset",
|
||||||
|
allow_file_pattern="ace_step/acestep-v15-base/audio.wav",
|
||||||
|
)
|
||||||
|
|
||||||
src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
|
src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
|
||||||
# repainting_ranges are in seconds, and will be converted to frames internally in the pipeline. The negative value in repainting_ranges means the padding from the start of the audio.
|
# repainting_ranges are in seconds, and will be converted to frames internally in the pipeline. The negative value in repainting_ranges means the padding from the start of the audio.
|
||||||
# For example, repainting_ranges=[(-10, 30), (160, 200)] means we want to repaint the audio from -10s to 30s (with 10s padding before the start) and from 160s to 200s. The non-existent parts will be padded with silence.
|
# For example, repainting_ranges=[(-10, 30), (160, 200)] means we want to repaint the audio from -10s to 30s (with 10s padding before the start) and from 160s to 200s. The non-existent parts will be padded with silence.
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ Turbo model: uses num_inference_steps=8, cfg_scale=1.0.
|
|||||||
"""
|
"""
|
||||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||||
from diffsynth.utils.data.audio import save_audio
|
from diffsynth.utils.data.audio import save_audio
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
@@ -49,6 +50,11 @@ audio = pipe(
|
|||||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-low-vram.wav")
|
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-low-vram.wav")
|
||||||
|
|
||||||
# input audio codes as reference
|
# input audio codes as reference
|
||||||
|
dataset_snapshot_download(
|
||||||
|
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
|
||||||
|
local_dir="data/diffsynth_example_dataset",
|
||||||
|
allow_file_pattern="ace_step/Ace-Step1.5/audio_codes_input.txt",
|
||||||
|
)
|
||||||
with open("data/diffsynth_example_dataset/ace_step/Ace-Step1.5/audio_codes_input.txt", "r") as f:
|
with open("data/diffsynth_example_dataset/ace_step/Ace-Step1.5/audio_codes_input.txt", "r") as f:
|
||||||
audio_code_string = f.read().strip()
|
audio_code_string = f.read().strip()
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||||
from diffsynth.utils.data.audio import save_audio, read_audio
|
from diffsynth.utils.data.audio import save_audio, read_audio
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
vram_config = {
|
vram_config = {
|
||||||
@@ -27,6 +28,13 @@ pipe = AceStepPipeline.from_pretrained(
|
|||||||
|
|
||||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||||
|
|
||||||
|
dataset_snapshot_download(
|
||||||
|
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
|
||||||
|
local_dir="data/diffsynth_example_dataset",
|
||||||
|
allow_file_pattern="ace_step/acestep-v15-base/audio.wav",
|
||||||
|
)
|
||||||
|
|
||||||
src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
|
src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
|
||||||
# audio_cover_strength controls the steps of doing cover tasks. [0, num_inference_steps * audio_cover_strength] steps will be cover steps, and the rest will be regular text-to-music generation steps.
|
# audio_cover_strength controls the steps of doing cover tasks. [0, num_inference_steps * audio_cover_strength] steps will be cover steps, and the rest will be regular text-to-music generation steps.
|
||||||
# denoising_strength controls how the output audio is influenced by the source audio in cover tasks.
|
# denoising_strength controls how the output audio is influenced by the source audio in cover tasks.
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||||
from diffsynth.utils.data.audio import save_audio, read_audio
|
from diffsynth.utils.data.audio import save_audio, read_audio
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
vram_config = {
|
vram_config = {
|
||||||
@@ -27,6 +28,13 @@ pipe = AceStepPipeline.from_pretrained(
|
|||||||
|
|
||||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||||
|
|
||||||
|
dataset_snapshot_download(
|
||||||
|
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
|
||||||
|
local_dir="data/diffsynth_example_dataset",
|
||||||
|
allow_file_pattern="ace_step/acestep-v15-base/audio.wav",
|
||||||
|
)
|
||||||
|
|
||||||
src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
|
src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
|
||||||
# repainting_ranges are in seconds, and will be converted to frames internally in the pipeline. The negative value in repainting_ranges means the padding from the start of the audio.
|
# repainting_ranges are in seconds, and will be converted to frames internally in the pipeline. The negative value in repainting_ranges means the padding from the start of the audio.
|
||||||
# For example, repainting_ranges=[(-10, 30), (160, 200)] means we want to repaint the audio from -10s to 30s (with 10s padding before the start) and from 160s to 200s. The non-existent parts will be padded with silence.
|
# For example, repainting_ranges=[(-10, 30), (160, 200)] means we want to repaint the audio from -10s to 30s (with 10s padding before the start) and from 160s to 200s. The non-existent parts will be padded with silence.
|
||||||
|
|||||||
@@ -1,31 +1,15 @@
|
|||||||
import torch, os, argparse, accelerate, warnings, torchaudio
|
import os
|
||||||
|
import torch
|
||||||
import math
|
import math
|
||||||
|
import argparse
|
||||||
|
import accelerate
|
||||||
from diffsynth.core import UnifiedDataset
|
from diffsynth.core import UnifiedDataset
|
||||||
from diffsynth.core.data.operators import ToAbsolutePath, RouteByType, DataProcessingOperator, LoadPureAudioWithTorchaudio
|
from diffsynth.core.data.operators import ToAbsolutePath, LoadPureAudioWithTorchaudio
|
||||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||||
from diffsynth.diffusion import *
|
from diffsynth.diffusion import *
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
|
||||||
class LoadAceStepAudio(DataProcessingOperator):
|
|
||||||
"""Load audio file and return waveform tensor [2, T] at 48kHz."""
|
|
||||||
def __init__(self, target_sr=48000):
|
|
||||||
self.target_sr = target_sr
|
|
||||||
|
|
||||||
def __call__(self, data: str):
|
|
||||||
try:
|
|
||||||
waveform, sample_rate = torchaudio.load(data)
|
|
||||||
if sample_rate != self.target_sr:
|
|
||||||
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.target_sr)
|
|
||||||
waveform = resampler(waveform)
|
|
||||||
if waveform.shape[0] == 1:
|
|
||||||
waveform = waveform.repeat(2, 1)
|
|
||||||
return waveform
|
|
||||||
except Exception as e:
|
|
||||||
warnings.warn(f"Cannot load audio from {data}: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class AceStepTrainingModule(DiffusionTrainingModule):
|
class AceStepTrainingModule(DiffusionTrainingModule):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -43,17 +27,15 @@ class AceStepTrainingModule(DiffusionTrainingModule):
|
|||||||
task="sft",
|
task="sft",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# ===== 解析模型配置(固定写法) =====
|
|
||||||
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
|
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
|
||||||
# ===== Tokenizer 配置 =====
|
|
||||||
text_tokenizer_config = self.parse_path_or_model_id(tokenizer_path, default_value=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"))
|
text_tokenizer_config = self.parse_path_or_model_id(tokenizer_path, default_value=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"))
|
||||||
silence_latent_config = self.parse_path_or_model_id(silence_latent_path, default_value=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"))
|
silence_latent_config = self.parse_path_or_model_id(silence_latent_path, default_value=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"))
|
||||||
# ===== 构建 Pipeline =====
|
self.pipe = AceStepPipeline.from_pretrained(
|
||||||
self.pipe = AceStepPipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, text_tokenizer_config=text_tokenizer_config, silence_latent_config=silence_latent_config)
|
torch_dtype=torch.bfloat16, device=device, model_configs=model_configs,
|
||||||
# ===== 拆分 Pipeline Units(固定写法) =====
|
text_tokenizer_config=text_tokenizer_config, silence_latent_config=silence_latent_config,
|
||||||
|
)
|
||||||
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
|
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
|
||||||
|
|
||||||
# ===== 切换到训练模式(固定写法) =====
|
|
||||||
self.switch_pipe_to_training_mode(
|
self.switch_pipe_to_training_mode(
|
||||||
self.pipe, trainable_models,
|
self.pipe, trainable_models,
|
||||||
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
|
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
|
||||||
@@ -61,13 +43,11 @@ class AceStepTrainingModule(DiffusionTrainingModule):
|
|||||||
task=task,
|
task=task,
|
||||||
)
|
)
|
||||||
|
|
||||||
# ===== 其他配置(固定写法) =====
|
|
||||||
self.use_gradient_checkpointing = use_gradient_checkpointing
|
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||||
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||||
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
||||||
self.fp8_models = fp8_models
|
self.fp8_models = fp8_models
|
||||||
self.task = task
|
self.task = task
|
||||||
# ===== 任务模式路由(固定写法) =====
|
|
||||||
self.task_to_loss = {
|
self.task_to_loss = {
|
||||||
"sft:data_process": lambda pipe, *args: args,
|
"sft:data_process": lambda pipe, *args: args,
|
||||||
"sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
|
"sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
|
||||||
@@ -78,11 +58,8 @@ class AceStepTrainingModule(DiffusionTrainingModule):
|
|||||||
inputs_posi = {"prompt": data["prompt"], "positive": True}
|
inputs_posi = {"prompt": data["prompt"], "positive": True}
|
||||||
inputs_nega = {"positive": False}
|
inputs_nega = {"positive": False}
|
||||||
duration = math.floor(data['audio'][0].shape[1] / data['audio'][1]) if data.get("audio") is not None else data.get("duration", 60)
|
duration = math.floor(data['audio'][0].shape[1] / data['audio'][1]) if data.get("audio") is not None else data.get("duration", 60)
|
||||||
# ===== 共享参数 =====
|
|
||||||
inputs_shared = {
|
inputs_shared = {
|
||||||
# ===== 核心字段映射 =====
|
|
||||||
"input_audio": data["audio"],
|
"input_audio": data["audio"],
|
||||||
# ===== 音频生成任务所需元数据 =====
|
|
||||||
"lyrics": data["lyrics"],
|
"lyrics": data["lyrics"],
|
||||||
"task_type": "text2music",
|
"task_type": "text2music",
|
||||||
"duration": duration,
|
"duration": duration,
|
||||||
@@ -90,18 +67,15 @@ class AceStepTrainingModule(DiffusionTrainingModule):
|
|||||||
"keyscale": data.get("keyscale", "C major"),
|
"keyscale": data.get("keyscale", "C major"),
|
||||||
"timesignature": data.get("timesignature", "4"),
|
"timesignature": data.get("timesignature", "4"),
|
||||||
"vocal_language": data.get("vocal_language", "unknown"),
|
"vocal_language": data.get("vocal_language", "unknown"),
|
||||||
# ===== 框架控制参数(固定写法) =====
|
|
||||||
"cfg_scale": 1,
|
"cfg_scale": 1,
|
||||||
"rand_device": self.pipe.device,
|
"rand_device": self.pipe.device,
|
||||||
"use_gradient_checkpointing": self.use_gradient_checkpointing,
|
"use_gradient_checkpointing": self.use_gradient_checkpointing,
|
||||||
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
|
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
|
||||||
}
|
}
|
||||||
# ===== 额外字段注入:通过 --extra_inputs 配置的数据集列名(固定写法) =====
|
|
||||||
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
|
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
|
||||||
return inputs_shared, inputs_posi, inputs_nega
|
return inputs_shared, inputs_posi, inputs_nega
|
||||||
|
|
||||||
def forward(self, data, inputs=None):
|
def forward(self, data, inputs=None):
|
||||||
# ===== 标准实现,不要修改(固定写法) =====
|
|
||||||
if inputs is None: inputs = self.get_pipeline_inputs(data)
|
if inputs is None: inputs = self.get_pipeline_inputs(data)
|
||||||
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
|
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
|
||||||
for unit in self.pipe.units:
|
for unit in self.pipe.units:
|
||||||
@@ -122,12 +96,10 @@ def ace_step_parser():
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = ace_step_parser()
|
parser = ace_step_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
# ===== Accelerator 配置(固定写法) =====
|
|
||||||
accelerator = accelerate.Accelerator(
|
accelerator = accelerate.Accelerator(
|
||||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||||
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
|
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
|
||||||
)
|
)
|
||||||
# ===== 数据集定义 =====
|
|
||||||
dataset = UnifiedDataset(
|
dataset = UnifiedDataset(
|
||||||
base_path=args.dataset_base_path,
|
base_path=args.dataset_base_path,
|
||||||
metadata_path=args.dataset_metadata_path,
|
metadata_path=args.dataset_metadata_path,
|
||||||
@@ -135,10 +107,11 @@ if __name__ == "__main__":
|
|||||||
data_file_keys=args.data_file_keys.split(","),
|
data_file_keys=args.data_file_keys.split(","),
|
||||||
main_data_operator=None,
|
main_data_operator=None,
|
||||||
special_operator_map={
|
special_operator_map={
|
||||||
"audio": ToAbsolutePath(args.dataset_base_path) >> LoadPureAudioWithTorchaudio(target_sample_rate=48000),
|
"audio": ToAbsolutePath(args.dataset_base_path) >> LoadPureAudioWithTorchaudio(
|
||||||
|
target_sample_rate=48000,
|
||||||
|
),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
# ===== TrainingModule =====
|
|
||||||
model = AceStepTrainingModule(
|
model = AceStepTrainingModule(
|
||||||
model_paths=args.model_paths,
|
model_paths=args.model_paths,
|
||||||
model_id_with_origin_paths=args.model_id_with_origin_paths,
|
model_id_with_origin_paths=args.model_id_with_origin_paths,
|
||||||
@@ -159,12 +132,10 @@ if __name__ == "__main__":
|
|||||||
task=args.task,
|
task=args.task,
|
||||||
device="cpu" if args.initialize_model_on_cpu else accelerator.device,
|
device="cpu" if args.initialize_model_on_cpu else accelerator.device,
|
||||||
)
|
)
|
||||||
# ===== ModelLogger(固定写法) =====
|
|
||||||
model_logger = ModelLogger(
|
model_logger = ModelLogger(
|
||||||
args.output_path,
|
args.output_path,
|
||||||
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
|
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
|
||||||
)
|
)
|
||||||
# ===== 任务路由(固定写法) =====
|
|
||||||
launcher_map = {
|
launcher_map = {
|
||||||
"sft:data_process": launch_data_process_task,
|
"sft:data_process": launch_data_process_task,
|
||||||
"sft": launch_training_task,
|
"sft": launch_training_task,
|
||||||
|
|||||||
Reference in New Issue
Block a user