This commit is contained in:
mi804
2026-04-23 17:31:34 +08:00
parent 394db06d86
commit a80fb84220
14 changed files with 99 additions and 243 deletions

View File

@@ -1,31 +1,15 @@
import torch, os, argparse, accelerate, warnings, torchaudio
import os
import torch
import math
import argparse
import accelerate
from diffsynth.core import UnifiedDataset
from diffsynth.core.data.operators import ToAbsolutePath, RouteByType, DataProcessingOperator, LoadPureAudioWithTorchaudio
from diffsynth.core.data.operators import ToAbsolutePath, LoadPureAudioWithTorchaudio
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.diffusion import *
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class LoadAceStepAudio(DataProcessingOperator):
    """Operator that reads an audio file into a waveform tensor.

    The waveform is resampled to ``target_sr`` when the file's native sample
    rate differs, and a single-channel waveform is duplicated into two
    channels. Waveforms with any other channel count are returned as loaded.
    """

    def __init__(self, target_sr=48000):
        # Sample rate that every loaded waveform is converted to.
        self.target_sr = target_sr

    def __call__(self, data: str):
        """Load the audio file at path ``data``.

        Returns the processed waveform, or ``None`` (after emitting a
        warning) if any step of loading or processing raises.
        """
        try:
            audio, native_sr = torchaudio.load(data)
            needs_resample = native_sr != self.target_sr
            if needs_resample:
                to_target_rate = torchaudio.transforms.Resample(
                    orig_freq=native_sr,
                    new_freq=self.target_sr,
                )
                audio = to_target_rate(audio)
            is_mono = audio.shape[0] == 1
            if is_mono:
                # Duplicate the single channel so the result has two channels.
                audio = audio.repeat(2, 1)
        except Exception as e:
            # Best-effort loader: report the failure and signal it with None
            # instead of aborting the caller.
            warnings.warn(f"Cannot load audio from {data}: {e}")
            return None
        return audio
class AceStepTrainingModule(DiffusionTrainingModule):
def __init__(
self,
@@ -43,17 +27,15 @@ class AceStepTrainingModule(DiffusionTrainingModule):
task="sft",
):
super().__init__()
# ===== Parse model configs (standard boilerplate) =====
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
# ===== Tokenizer config =====
text_tokenizer_config = self.parse_path_or_model_id(tokenizer_path, default_value=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"))
silence_latent_config = self.parse_path_or_model_id(silence_latent_path, default_value=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"))
# ===== Build the pipeline =====
self.pipe = AceStepPipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, text_tokenizer_config=text_tokenizer_config, silence_latent_config=silence_latent_config)
# ===== Split pipeline units (standard boilerplate) =====
self.pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16, device=device, model_configs=model_configs,
text_tokenizer_config=text_tokenizer_config, silence_latent_config=silence_latent_config,
)
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
# ===== Switch to training mode (standard boilerplate) =====
self.switch_pipe_to_training_mode(
self.pipe, trainable_models,
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
@@ -61,13 +43,11 @@ class AceStepTrainingModule(DiffusionTrainingModule):
task=task,
)
# ===== Other settings (standard boilerplate) =====
self.use_gradient_checkpointing = use_gradient_checkpointing
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
self.fp8_models = fp8_models
self.task = task
# ===== Task-mode routing (standard boilerplate) =====
self.task_to_loss = {
"sft:data_process": lambda pipe, *args: args,
"sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
@@ -78,11 +58,8 @@ class AceStepTrainingModule(DiffusionTrainingModule):
inputs_posi = {"prompt": data["prompt"], "positive": True}
inputs_nega = {"positive": False}
duration = math.floor(data['audio'][0].shape[1] / data['audio'][1]) if data.get("audio") is not None else data.get("duration", 60)
# ===== Shared inputs =====
inputs_shared = {
# ===== Core field mapping =====
"input_audio": data["audio"],
# ===== Metadata required by the audio generation task =====
"lyrics": data["lyrics"],
"task_type": "text2music",
"duration": duration,
@@ -90,18 +67,15 @@ class AceStepTrainingModule(DiffusionTrainingModule):
"keyscale": data.get("keyscale", "C major"),
"timesignature": data.get("timesignature", "4"),
"vocal_language": data.get("vocal_language", "unknown"),
# ===== Framework control parameters (standard boilerplate) =====
"cfg_scale": 1,
"rand_device": self.pipe.device,
"use_gradient_checkpointing": self.use_gradient_checkpointing,
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
}
# ===== Extra field injection: dataset column names configured via --extra_inputs (standard boilerplate) =====
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
return inputs_shared, inputs_posi, inputs_nega
def forward(self, data, inputs=None):
# ===== Standard implementation; do not modify (standard boilerplate) =====
if inputs is None: inputs = self.get_pipeline_inputs(data)
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
for unit in self.pipe.units:
@@ -122,12 +96,10 @@ def ace_step_parser():
if __name__ == "__main__":
parser = ace_step_parser()
args = parser.parse_args()
# ===== Accelerator config (standard boilerplate) =====
accelerator = accelerate.Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
)
# ===== Dataset definition =====
dataset = UnifiedDataset(
base_path=args.dataset_base_path,
metadata_path=args.dataset_metadata_path,
@@ -135,10 +107,11 @@ if __name__ == "__main__":
data_file_keys=args.data_file_keys.split(","),
main_data_operator=None,
special_operator_map={
"audio": ToAbsolutePath(args.dataset_base_path) >> LoadPureAudioWithTorchaudio(target_sample_rate=48000),
"audio": ToAbsolutePath(args.dataset_base_path) >> LoadPureAudioWithTorchaudio(
target_sample_rate=48000,
),
},
)
# ===== TrainingModule =====
model = AceStepTrainingModule(
model_paths=args.model_paths,
model_id_with_origin_paths=args.model_id_with_origin_paths,
@@ -159,12 +132,10 @@ if __name__ == "__main__":
task=args.task,
device="cpu" if args.initialize_model_on_cpu else accelerator.device,
)
# ===== ModelLogger (standard boilerplate) =====
model_logger = ModelLogger(
args.output_path,
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
)
# ===== Task routing (standard boilerplate) =====
launcher_map = {
"sft:data_process": launch_data_process_task,
"sft": launch_training_task,