Add reference audio input support

This commit is contained in:
mi804
2026-04-22 19:16:04 +08:00
parent c53c813c12
commit f2e3427566
2 changed files with 25 additions and 85 deletions

View File

@@ -506,7 +506,6 @@ class AceStepTimbreEncoder(nn.Module):
for layer_idx in range(num_timbre_encoder_hidden_layers)
])
def unpack_timbre_embeddings(self, timbre_embs_packed, refer_audio_order_mask):
N, d = timbre_embs_packed.shape
device = timbre_embs_packed.device

View File

@@ -8,6 +8,7 @@ import torch
from typing import Optional, Dict, Any, List, Tuple
from tqdm import tqdm
import random
import math
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
@@ -293,107 +294,47 @@ class AceStepUnit_ReferenceAudioEmbedder(PipelineUnit):
)
def process(self, pipe, reference_audios):
"""Encode optional reference audios into packed latents for conditioning.

Returns a dict with "reference_latents" (packed VAE latents) and
"refer_audio_order_mask" (maps each latent back to its batch index),
as produced by ``self.infer_refer_latent``.

NOTE(review): this span is rendered from a diff and appears to mix
removed and added lines (the ``raise NotImplementedError`` / ``pass``
branch next to the working implementation); reconcile against the
actual post-commit file before editing.
"""
if reference_audios is not None:
# Latents are produced by the VAE encoder, so it must be resident first.
pipe.load_models_to_device(['vae'])
if reference_audios is not None and len(reference_audios) > 0:
raise NotImplementedError("Reference audio embedding is not implemented yet.")
# TODO: implement reference audio embedding using VAE encode, and generate refer_audio_order_mask
pass
# Crop/tile each reference audio to a fixed 30 s clip, then move to pipe's dtype/device.
reference_audios = [self.process_reference_audio(reference_audio).to(dtype=pipe.torch_dtype, device=pipe.device) for reference_audio in reference_audios]
# Wrapped in an extra list: infer_refer_latent expects List[List[Tensor]] (one inner list per batch item).
reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, [reference_audios])
else:
# No reference given: substitute 30 s of stereo silence so downstream shapes stay valid.
# assumes pipe.vae.sampling_rate is in Hz — TODO confirm
reference_audios = [[torch.zeros(2, 30 * pipe.vae.sampling_rate).to(dtype=pipe.torch_dtype, device=pipe.device)]]
reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, reference_audios)
return {"reference_latents": reference_latents, "refer_audio_order_mask": refer_audio_order_mask}
# def process_reference_audio(self, reference_audios) -> Optional[torch.Tensor]:
# try:
# audio_np, sr = _read_audio_file(audio_file)
# audio = self._numpy_to_channels_first(audio_np)
# logger.debug(
# f"[process_reference_audio] Reference audio shape: {audio.shape}"
# )
# logger.debug(f"[process_reference_audio] Reference audio sample rate: {sr}")
# logger.debug(
# f"[process_reference_audio] Reference audio duration: {audio.shape[-1] / sr:.6f} seconds"
# )
# audio = self._normalize_audio_to_stereo_48k(audio, sr)
# if self.is_silence(audio):
# return None
# target_frames = 30 * 48000
# segment_frames = 10 * 48000
# if audio.shape[-1] < target_frames:
# repeat_times = math.ceil(target_frames / audio.shape[-1])
# audio = audio.repeat(1, repeat_times)
# total_frames = audio.shape[-1]
# segment_size = total_frames // 3
# front_start = random.randint(0, max(0, segment_size - segment_frames))
# front_audio = audio[:, front_start : front_start + segment_frames]
# middle_start = segment_size + random.randint(
# 0, max(0, segment_size - segment_frames)
# )
# middle_audio = audio[:, middle_start : middle_start + segment_frames]
# back_start = 2 * segment_size + random.randint(
# 0, max(0, (total_frames - 2 * segment_size) - segment_frames)
# )
# back_audio = audio[:, back_start : back_start + segment_frames]
# return torch.cat([front_audio, middle_audio, back_audio], dim=-1)
def process_reference_audio(self, audio, sample_rate: int = 48000) -> Optional[torch.Tensor]:
    """Build a fixed 30 s reference clip from an arbitrary-length waveform.

    The waveform is tiled until it covers at least 30 s, then three 10 s
    segments are sampled (via ``random.randint``) from the front, middle,
    and back thirds and concatenated. When the tiled audio is exactly
    30 s long every random range collapses to 0, making the result
    deterministic.

    Args:
        audio: Waveform tensor of shape ``(channels, T)`` or
            ``(1, channels, T)`` (the leading singleton is squeezed).
        sample_rate: Sample rate in Hz used to size the 30 s / 10 s
            windows. Defaults to 48000, matching the previous hard-coded
            behavior.

    Returns:
        Tensor of shape ``(1, channels, 30 * sample_rate)``.

    Raises:
        ValueError: If ``audio`` has zero frames (previously this raised
            an opaque ``ZeroDivisionError`` inside ``math.ceil``).
    """
    if audio.ndim == 3 and audio.shape[0] == 1:
        audio = audio.squeeze(0)
    if audio.shape[-1] == 0:
        raise ValueError("reference audio is empty (zero frames)")
    target_frames = 30 * sample_rate
    segment_frames = 10 * sample_rate
    # Tile short clips so the three thirds below each hold >= 0 valid starts.
    if audio.shape[-1] < target_frames:
        repeat_times = math.ceil(target_frames / audio.shape[-1])
        audio = audio.repeat(1, repeat_times)
    total_frames = audio.shape[-1]
    segment_size = total_frames // 3
    # One random 10 s window from each third; max(0, ...) guards the
    # degenerate case where a third is not longer than a segment.
    front_start = random.randint(0, max(0, segment_size - segment_frames))
    front_audio = audio[:, front_start:front_start + segment_frames]
    middle_start = segment_size + random.randint(0, max(0, segment_size - segment_frames))
    middle_audio = audio[:, middle_start:middle_start + segment_frames]
    back_start = 2 * segment_size + random.randint(0, max(0, (total_frames - 2 * segment_size) - segment_frames))
    back_audio = audio[:, back_start:back_start + segment_frames]
    # Re-add a leading batch dim for the caller.
    return torch.cat([front_audio, middle_audio, back_audio], dim=-1).unsqueeze(0)
def infer_refer_latent(self, pipe, refer_audioss: List[List[torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
"""Infer packed reference-audio latents and order mask.

For each batch item's list of reference audios, VAE-encode each audio
(or substitute ``pipe.silence_latent`` for an all-zero placeholder) and
pack every latent into one tensor along dim 0, together with a long
tensor mapping each packed latent back to its batch index.

NOTE(review): this span is rendered from a diff and contains both the
pre- and post-commit lines (duplicate latent assignments below, and two
competing encode paths); reconcile against the actual file before
editing logic.
"""
refer_audio_order_mask = []
refer_audio_latents = []
# Normalize any accepted audio layout to a strict (2, T) stereo tensor.
def _normalize_audio_2d(a: torch.Tensor) -> torch.Tensor:
if not isinstance(a, torch.Tensor):
raise TypeError(f"refer_audio must be a torch.Tensor, got {type(a)!r}")
if a.dim() == 3 and a.shape[0] == 1:
a = a.squeeze(0)
if a.dim() == 1:
a = a.unsqueeze(0)
if a.dim() != 2:
raise ValueError(f"refer_audio must be 1D/2D/3D(1,2,T); got shape={tuple(a.shape)}")
if a.shape[0] == 1:
# Duplicate a mono channel to fake stereo.
a = torch.cat([a, a], dim=0)
return a[:2]
# Coerce a latent to 3D (batch, time, dim) by squeezing/unsqueezing the batch axis.
def _ensure_latent_3d(z: torch.Tensor) -> torch.Tensor:
if z.dim() == 4 and z.shape[0] == 1:
z = z.squeeze(0)
if z.dim() == 2:
z = z.unsqueeze(0)
return z
# Cache keyed on tensor identity (data_ptr) to skip re-encoding shared tensors.
refer_encode_cache: Dict[int, torch.Tensor] = {}
for batch_idx, refer_audios in enumerate(refer_audioss):
# A single all-zero audio is the "no reference" placeholder: use the
# precomputed silence latent instead of encoding.
# assumes silence_latent is (1, >=750, dim) and 750 frames ~ 30 s — TODO confirm
if len(refer_audios) == 1 and torch.all(refer_audios[0] == 0.0):
# NOTE(review): the next two lines both assign refer_audio_latent —
# diff residue; only one survives in the real file.
refer_audio_latent = _ensure_latent_3d(pipe.silence_latent[:, :750, :])
refer_audio_latent = pipe.silence_latent[:, :750, :]
refer_audio_latents.append(refer_audio_latent)
refer_audio_order_mask.append(batch_idx)
else:
# TODO: check
for refer_audio in refer_audios:
cache_key = refer_audio.data_ptr()
if cache_key in refer_encode_cache:
refer_audio_latent = refer_encode_cache[cache_key].clone()
else:
refer_audio = _normalize_audio_2d(refer_audio)
refer_audio_latent = pipe.vae.encode(refer_audio)
refer_audio_latent = refer_audio_latent.to(dtype=pipe.torch_dtype, device=pipe.device)
if refer_audio_latent.dim() == 2:
refer_audio_latent = refer_audio_latent.unsqueeze(0)
# transpose(1, 2): VAE emits (batch, dim, time); downstream wants (batch, time, dim) — presumably; verify against vae.encode
refer_audio_latent = _ensure_latent_3d(refer_audio_latent.transpose(1, 2))
refer_encode_cache[cache_key] = refer_audio_latent
# NOTE(review): this direct-encode line duplicates the cached path
# above — diff residue; only one path survives in the real file.
refer_audio_latent = pipe.vae.encode(refer_audio).transpose(1, 2).to(dtype=pipe.torch_dtype, device=pipe.device)
refer_audio_latents.append(refer_audio_latent)
refer_audio_order_mask.append(batch_idx)
# Pack all latents along dim 0; mask[i] is the batch index of latent i.
refer_audio_latents = torch.cat(refer_audio_latents, dim=0)
refer_audio_order_mask = torch.tensor(refer_audio_order_mask, device=pipe.device, dtype=torch.long)
return refer_audio_latents, refer_audio_order_mask