This commit is contained in:
mi804
2026-04-23 16:52:59 +08:00
parent 1186379139
commit 394db06d86
7 changed files with 212 additions and 20 deletions

View File

@@ -106,6 +106,9 @@ class AceStepPipeline(BasePipeline):
audio_cover_strength: float = 1.0,
# Audio codes
audio_code_string: Optional[str] = None,
# Inpainting
repainting_ranges: Optional[List[Tuple[float, float]]] = None,
repainting_strength: float = 1.0,
# Shape
duration: int = 60,
# Audio Meta
@@ -134,9 +137,8 @@ class AceStepPipeline(BasePipeline):
"lyrics": lyrics,
"task_type": task_type,
"reference_audios": reference_audios,
"src_audio": src_audio,
"audio_cover_strength": audio_cover_strength,
"audio_code_string": audio_code_string,
"src_audio": src_audio, "audio_cover_strength": audio_cover_strength, "audio_code_string": audio_code_string,
"repainting_ranges": repainting_ranges, "repainting_strength": repainting_strength,
"duration": duration,
"bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "vocal_language": vocal_language,
"seed": seed,
@@ -162,9 +164,8 @@ class AceStepPipeline(BasePipeline):
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id,
)
inputs_shared["latents"] = self.step(
self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared
)
inputs_shared["latents"] = self.step(self.scheduler, inpaint_mask=inputs_shared.get("denoise_mask", None), input_latents=inputs_shared.get("src_latents", None),
progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
# 5. VAE 解码
self.load_models_to_device(['vae'])
@@ -201,12 +202,17 @@ class AceStepUnit_TaskTypeChecker(PipelineUnit):
"""Check and compute sequence length from duration."""
def __init__(self):
super().__init__(
input_params=("task_type",),
input_params=("task_type", "src_audio", "repainting_ranges", "audio_code_string"),
output_params=("task_type",),
)
def process(self, pipe, task_type):
def process(self, pipe, task_type, src_audio, repainting_ranges, audio_code_string):
assert task_type in ["text2music", "cover", "repaint"], f"Unsupported task_type: {task_type}"
if task_type == "cover":
assert (src_audio is not None) or (audio_code_string is not None), "For cover task, either src_audio or audio_code_string must be provided."
elif task_type == "repaint":
assert src_audio is not None, "For repaint task, src_audio must be provided."
assert repainting_ranges is not None and len(repainting_ranges) > 0, "For repaint task, inpainting_ranges must be provided and non-empty."
return {}
@@ -399,7 +405,7 @@ class AceStepUnit_ConditionEmbedder(PipelineUnit):
class AceStepUnit_ContextLatentBuilder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("duration", "src_audio", "audio_code_string"),
input_params=("duration", "src_audio", "audio_code_string", "task_type", "repainting_ranges", "repainting_strength"),
output_params=("context_latents", "src_latents", "chunk_masks", "attention_mask"),
onload_model_names=("vae", "tokenizer_model",),
)
@@ -435,9 +441,46 @@ class AceStepUnit_ContextLatentBuilder(PipelineUnit):
raise ValueError(f"Invalid audio_code_string format: {e}")
return codes
def process(self, pipe, duration, src_audio, audio_code_string):
def pad_src_audio(self, pipe, src_audio, task_type, repainting_ranges):
if task_type != "repaint" or repainting_ranges is None:
return src_audio, repainting_ranges, None, None
min_left = min([start for start, end in repainting_ranges])
max_right = max([end for start, end in repainting_ranges])
total_length = src_audio.shape[-1] // pipe.vae.sampling_rate
pad_left = max(0, -min_left)
pad_right = max(0, max_right - total_length)
if pad_left > 0 or pad_right > 0:
padding_frames_left, padding_frames_right = pad_left * pipe.vae.sampling_rate, pad_right * pipe.vae.sampling_rate
src_audio = F.pad(src_audio, (padding_frames_left, padding_frames_right), value=0.0)
repainting_ranges = [(start + pad_left, end + pad_left) for start, end in repainting_ranges]
return src_audio, repainting_ranges, pad_left, pad_right
def parse_repaint_masks(self, pipe, src_latents, task_type, repainting_ranges, repainting_strength, pad_left, pad_right):
if task_type != "repaint" or repainting_ranges is None:
return None, src_latents
# let repainting area be repainting_strength, non-repainting area be 0.0, and blend at the boundary with cf_frames.
max_latent_length = src_latents.shape[1]
denoise_mask = torch.zeros((1, max_latent_length, 1), dtype=pipe.torch_dtype, device=pipe.device)
for start, end in repainting_ranges:
start_frame = start * pipe.vae.sampling_rate // 1920
end_frame = end * pipe.vae.sampling_rate // 1920
denoise_mask[:, start_frame:end_frame, :] = repainting_strength
# set padding areas to 1.0 (full repaint) to avoid artifacts at the boundaries caused by padding
pad_left_frames = pad_left * pipe.vae.sampling_rate // 1920
pad_right_frames = pad_right * pipe.vae.sampling_rate // 1920
denoise_mask[:, :pad_left_frames, :] = 1
denoise_mask[:, max_latent_length - pad_right_frames:, :] = 1
silent_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
src_latents = src_latents * (1 - denoise_mask) + silent_latents * denoise_mask
return denoise_mask, src_latents
def process(self, pipe, duration, src_audio, audio_code_string, task_type=None, repainting_ranges=None, repainting_strength=None):
# get src_latents from audio_code_string > src_audio > silence
source_latents = None
denoise_mask = None
if audio_code_string is not None:
# use audio_cede_string to get src_latents.
pipe.load_models_to_device(self.onload_model_names)
code_ids = self._parse_audio_code_string(audio_code_string)
quantizer = pipe.tokenizer_model.tokenizer.quantizer
@@ -448,33 +491,42 @@ class AceStepUnit_ContextLatentBuilder(PipelineUnit):
src_latents = pipe.tokenizer_model.detokenizer(quantized).to(pipe.device)
max_latent_length = src_latents.shape[1]
elif src_audio is not None:
# use src_audio to get src_latents.
pipe.load_models_to_device(self.onload_model_names)
src_audio = src_audio.unsqueeze(0) if src_audio.dim() == 2 else src_audio
src_audio = torch.clamp(src_audio, -1.0, 1.0)
src_audio, repainting_ranges, pad_left, pad_right = self.pad_src_audio(pipe, src_audio, task_type, repainting_ranges)
src_latents = pipe.vae.encode(src_audio.to(dtype=pipe.torch_dtype, device=pipe.device)).transpose(1, 2)
lm_hints_5Hz = self.tokenize(pipe.tokenizer_model.tokenizer, src_latents, pipe.silence_latent, pipe.tokenizer_model.tokenizer.pool_window_size)
src_latents = pipe.tokenizer_model.detokenizer(lm_hints_5Hz)
source_latents = src_latents # cache for potential use in audio inpainting tasks
denoise_mask, src_latents = self.parse_repaint_masks(pipe, src_latents, task_type, repainting_ranges, repainting_strength, pad_left, pad_right)
if task_type == "cover":
lm_hints_5Hz = self.tokenize(pipe.tokenizer_model.tokenizer, src_latents, pipe.silence_latent, pipe.tokenizer_model.tokenizer.pool_window_size)
src_latents = pipe.tokenizer_model.detokenizer(lm_hints_5Hz)
max_latent_length = src_latents.shape[1]
else:
# use silence latents.
max_latent_length = int(duration * pipe.sample_rate // 1920)
src_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
chunk_masks = torch.ones((1, max_latent_length, src_latents.shape[-1]), dtype=torch.bool, device=pipe.device)
attention_mask = torch.ones((1, max_latent_length), device=src_latents.device, dtype=pipe.torch_dtype)
context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
return {"context_latents": context_latents, "attention_mask": attention_mask}
return {"context_latents": context_latents, "attention_mask": attention_mask, "src_latents": source_latents, "denoise_mask": denoise_mask}
class AceStepUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("context_latents", "seed", "rand_device"),
input_params=("context_latents", "seed", "rand_device", "src_latents"),
output_params=("noise",),
)
def process(self, pipe, context_latents, seed, rand_device):
def process(self, pipe, context_latents, seed, rand_device, src_latents):
src_latents_shape = (context_latents.shape[0], context_latents.shape[1], context_latents.shape[-1] // 2)
noise = pipe.generate_noise(src_latents_shape, seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
noise = pipe.scheduler.add_noise(context_latents[:, :, :src_latents_shape[-1]], noise, timestep=pipe.scheduler.timesteps[0])
if src_latents is not None:
noise = pipe.scheduler.add_noise(src_latents, noise, timestep=pipe.scheduler.timesteps[0])
return {"noise": noise}
@@ -502,7 +554,6 @@ class AceStepUnit_InputAudioEmbedder(PipelineUnit):
return {"input_latents": input_latents}
def model_fn_ace_step(
dit: AceStepDiTModel,
latents=None,