mi804
2026-04-23 17:31:34 +08:00
parent 394db06d86
commit a80fb84220
14 changed files with 99 additions and 243 deletions


@@ -3,12 +3,10 @@ ACE-Step Pipeline for DiffSynth-Studio.
 Text-to-Music generation pipeline using ACE-Step 1.5 model.
 """
-import re
-import torch
+import re, torch
 from typing import Optional, Dict, Any, List, Tuple
 from tqdm import tqdm
-import random
-import math
+import random, math
 import torch.nn.functional as F
 from einops import rearrange
@@ -39,7 +37,7 @@ class AceStepPipeline(BasePipeline):
         self.conditioner: AceStepConditionEncoder = None
         self.dit: AceStepDiTModel = None
         self.vae: AceStepVAE = None
-        self.tokenizer_model: AceStepTokenizer = None  # AceStepTokenizer (tokenizer + detokenizer)
+        self.tokenizer_model: AceStepTokenizer = None
         self.in_iteration_models = ("dit",)
         self.units = [
@@ -65,7 +63,6 @@ class AceStepPipeline(BasePipeline):
         silence_latent_config: ModelConfig = ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
         vram_limit: float = None,
     ):
-        """Load pipeline from pretrained checkpoints."""
         pipe = AceStepPipeline(device=device, torch_dtype=torch_dtype)
         model_pool = pipe.download_and_load_models(model_configs, vram_limit)
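
For reference, a minimal usage sketch of the loader above. The classmethod name `from_pretrained` and the import path are assumptions; only the `device`/`torch_dtype` parameters and the `ModelConfig` defaults appear in this diff:

    import torch
    # hypothetical import path; adjust to wherever AceStepPipeline lives in DiffSynth-Studio
    from diffsynth.pipelines.ace_step import AceStepPipeline

    pipe = AceStepPipeline.from_pretrained(device="cuda", torch_dtype=torch.bfloat16)
    audio = pipe(prompt="upbeat synth-pop with female vocals", bpm=120, keyscale="B minor")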
@@ -102,7 +99,7 @@ class AceStepPipeline(BasePipeline):
         reference_audios: List[torch.Tensor] = None,
         # Source audio
         src_audio: torch.Tensor = None,
-        denoising_strength: float = 1.0, # denoising_strength = 1 - cover_noise_strength
+        denoising_strength: float = 1.0,  # denoising_strength = 1 - cover_noise_strength
         audio_cover_strength: float = 1.0,
         # Audio codes
         audio_code_string: Optional[str] = None,
@@ -115,7 +112,7 @@ class AceStepPipeline(BasePipeline):
         bpm: Optional[int] = 100,
         keyscale: Optional[str] = "B minor",
         timesignature: Optional[str] = "4",
-        vocal_language: Optional[str] = 'unknown',
+        vocal_language: Optional[str] = "unknown",
         # Randomness
         seed: int = None,
         rand_device: str = "cpu",
@@ -126,10 +123,10 @@ class AceStepPipeline(BasePipeline):
         # Progress
         progress_bar_cmd=tqdm,
     ):
-        # 1. Scheduler
+        # Scheduler
         self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, denoising_strength=denoising_strength, shift=shift)
-        # 2. Three-dict inputs
+        # Parameters
         inputs_posi = {"prompt": prompt, "positive": True}
         inputs_nega = {"positive": False}
         inputs_shared = {
@@ -147,13 +144,12 @@ class AceStepPipeline(BasePipeline):
"shift": shift,
}
# 3. Unit 链执行
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
unit, self, inputs_shared, inputs_posi, inputs_nega
)
# 4. Denoise loop
# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
@@ -164,17 +160,17 @@ class AceStepPipeline(BasePipeline):
                 inputs_shared, inputs_posi, inputs_nega,
                 **models, timestep=timestep, progress_id=progress_id,
             )
-            inputs_shared["latents"] = self.step(self.scheduler, inpaint_mask=inputs_shared.get("denoise_mask", None), input_latents=inputs_shared.get("src_latents", None),
-                progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
+            inputs_shared["latents"] = self.step(
+                self.scheduler, inpaint_mask=inputs_shared.get("denoise_mask", None), input_latents=inputs_shared.get("src_latents", None),
+                progress_id=progress_id, noise_pred=noise_pred, **inputs_shared,
+            )
-        # 5. VAE decode
+        # Decode
         self.load_models_to_device(['vae'])
         # DiT output is [B, T, 64] (channels-last), VAE expects [B, 64, T] (channels-first)
         latents = inputs_shared["latents"].transpose(1, 2)
         vae_output = self.vae.decode(latents)
-        # VAE returns OobleckDecoderOutput with .sample attribute
-        audio_output = vae_output.sample if hasattr(vae_output, 'sample') else vae_output
-        audio_output = self.normalize_audio(audio_output, target_db=-1.0)
+        audio_output = self.normalize_audio(vae_output, target_db=-1.0)
         audio = self.output_audio_format_check(audio_output)
         self.load_models_to_device([])
         return audio
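
The `self.step` call above delegates to the scheduler. A sketch of the kind of update it performs, assuming a flow-matching Euler scheduler (the `shift` parameter suggests this family); the function name, signature, and mask convention below are illustrative, not the real `BasePipeline.step`:

    def euler_step_with_mask(latents, noise_pred, sigma, sigma_next,
                             inpaint_mask=None, src_latents=None):
        # Euler update along the predicted velocity field
        latents = latents + (sigma_next - sigma) * noise_pred
        if inpaint_mask is not None and src_latents is not None:
            # pin un-denoised regions to the source latents (inpainting)
            latents = inpaint_mask * latents + (1 - inpaint_mask) * src_latents
        return latents

Given `return audio * gain` in the next hunk, `normalize_audio` presumably computes a peak-based gain, e.g. `gain = 10 ** (target_db / 20) / audio.abs().max()`, so with `target_db=-1.0` the loudest sample lands at -1 dBFS.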
@@ -188,7 +184,9 @@ class AceStepPipeline(BasePipeline):
         return audio * gain
 
     def switch_noncover_condition(self, inputs_shared, inputs_posi, inputs_nega, progress_id):
-        if inputs_shared["task_type"] != "cover" or inputs_shared["audio_cover_strength"] >= 1.0 or inputs_shared.get("shared_noncover", None) is None:
+        if inputs_shared["task_type"] != "cover" or inputs_shared["audio_cover_strength"] >= 1.0:
             return
+        if inputs_shared.get("shared_noncover", None) is None:
+            return
         cover_steps = int(len(self.scheduler.timesteps) * inputs_shared["audio_cover_strength"])
         if progress_id >= cover_steps:
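
Worked example of the switch point above, using the formula from the hunk (the step count and strength are made up for illustration):

    num_steps, audio_cover_strength = 8, 0.75
    cover_steps = int(num_steps * audio_cover_strength)  # 6
    # progress_id 0..5 -> cover condition; progress_id 6..7 -> non-cover condition
    switched = [progress_id >= cover_steps for progress_id in range(num_steps)]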
@@ -312,7 +310,10 @@ class AceStepUnit_ReferenceAudioEmbedder(PipelineUnit):
     def process(self, pipe, reference_audios):
         if reference_audios is not None:
             pipe.load_models_to_device(['vae'])
-            reference_audios = [self.process_reference_audio(reference_audio).to(dtype=pipe.torch_dtype, device=pipe.device) for reference_audio in reference_audios]
+            reference_audios = [
+                self.process_reference_audio(reference_audio).to(dtype=pipe.torch_dtype, device=pipe.device)
+                for reference_audio in reference_audios
+            ]
             reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, [reference_audios])
         else:
             reference_audios = [[torch.zeros(2, 30 * pipe.vae.sampling_rate).to(dtype=pipe.torch_dtype, device=pipe.device)]]
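
The `else` branch builds 30 seconds of stereo silence as a stand-in reference when none is given. A quick shape check (the 44.1 kHz rate is an assumption; the real value comes from `pipe.vae.sampling_rate`):

    import torch

    sampling_rate = 44100  # assumed; read from pipe.vae.sampling_rate in practice
    silent_ref = torch.zeros(2, 30 * sampling_rate)  # (channels=2, samples=1323000)
    assert silent_ref.shape == (2, 1_323_000)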
@@ -357,7 +358,6 @@ class AceStepUnit_ReferenceAudioEmbedder(PipelineUnit):
-
 
 class AceStepUnit_ConditionEmbedder(PipelineUnit):
     def __init__(self):
         super().__init__(
             take_over=True,
@@ -378,7 +378,9 @@ class AceStepUnit_ConditionEmbedder(PipelineUnit):
         inputs_posi["encoder_hidden_states"] = encoder_hidden_states
         inputs_posi["encoder_attention_mask"] = encoder_attention_mask
         if inputs_shared["cfg_scale"] != 1.0:
-            inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states).to(dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device)
+            inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states).to(
+                dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device,
+            )
             inputs_nega["encoder_attention_mask"] = encoder_attention_mask
         if inputs_shared["task_type"] == "cover" and inputs_shared["audio_cover_strength"] < 1.0:
             hidden_states_noncover = AceStepUnit_PromptEmbedder().process(
@@ -396,7 +398,9 @@ class AceStepUnit_ConditionEmbedder(PipelineUnit):
             inputs_shared["posi_noncover"] = {"encoder_hidden_states": encoder_hidden_states_noncover, "encoder_attention_mask": encoder_attention_mask_noncover}
             if inputs_shared["cfg_scale"] != 1.0:
                 inputs_shared["nega_noncover"] = {
-                    "encoder_hidden_states": pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states_noncover).to(dtype=encoder_hidden_states_noncover.dtype, device=encoder_hidden_states_noncover.device),
+                    "encoder_hidden_states": pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states_noncover).to(
+                        dtype=encoder_hidden_states_noncover.dtype, device=encoder_hidden_states_noncover.device,
+                    ),
                     "encoder_attention_mask": encoder_attention_mask_noncover,
                 }
         return inputs_shared, inputs_posi, inputs_nega
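
`null_condition_emb` supplies the unconditional branch for classifier-free guidance, which is why both assignments are guarded by `cfg_scale != 1.0`: at scale 1.0 the combine reduces to the positive branch alone, so the negative embeddings are never needed. The standard combine, as a sketch (the actual mixing happens inside the denoise loop, not in this unit):

    def cfg_combine(noise_pred_posi, noise_pred_nega, cfg_scale):
        # cfg_scale == 1.0 returns noise_pred_posi exactly
        return noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)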
@@ -483,7 +487,7 @@ class AceStepUnit_ContextLatentBuilder(PipelineUnit):
             # use audio_code_string to get src_latents.
             pipe.load_models_to_device(self.onload_model_names)
             code_ids = self._parse_audio_code_string(audio_code_string)
-            quantizer = pipe.tokenizer_model.tokenizer.quantizer
+            quantizer = pipe.tokenizer_model.tokenizer.quantizer.to(device=pipe.device)
             indices = torch.tensor(code_ids, device=quantizer.codebooks.device, dtype=torch.long).unsqueeze(0).unsqueeze(-1)
             codes = quantizer.get_codes_from_indices(indices)
             quantized = codes.sum(dim=0).to(pipe.torch_dtype).to(pipe.device)
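
For context, a simplified sketch of the residual-VQ lookup that `get_codes_from_indices` performs: each quantizer level maps its index to a codebook vector, the levels are stacked on a leading quantizer dimension, and `.sum(dim=0)` (as in the hunk above) collapses the residual levels into the quantized latent. The shapes and layout below are assumptions; the real implementation lives in the tokenizer's quantizer:

    import torch

    def codes_from_indices(codebooks: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
        # codebooks: [num_quantizers, codebook_size, dim]; indices: [num_quantizers, T]
        codes = torch.stack([codebooks[q][indices[q]] for q in range(codebooks.shape[0])])
        return codes.sum(dim=0)  # [T, dim]: residual levels sum to the quantized latent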