reference audio input

mi804
2026-04-22 19:16:04 +08:00
parent c53c813c12
commit f2e3427566
2 changed files with 25 additions and 85 deletions

View File

@@ -506,7 +506,6 @@ class AceStepTimbreEncoder(nn.Module):
             for layer_idx in range(num_timbre_encoder_hidden_layers)
         ])
-
     def unpack_timbre_embeddings(self, timbre_embs_packed, refer_audio_order_mask):
         N, d = timbre_embs_packed.shape
         device = timbre_embs_packed.device
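The body of unpack_timbre_embeddings is not shown in this hunk. For orientation, here is a minimal sketch of the pattern its name and arguments suggest, assuming the packed tensor is (N, d) with one row per reference audio and refer_audio_order_mask[i] giving the batch index row i belongs to; the helper name unpack_by_order_mask and the zero-padding scheme are illustrative assumptions, not the repository's implementation.

```python
import torch

def unpack_by_order_mask(embs_packed: torch.Tensor, order_mask: torch.Tensor):
    """Group packed (N, d) embeddings into a padded (B, max_refs, d) tensor.

    order_mask[i] is the batch index that packed row i belongs to.
    Returns the padded tensor and a boolean mask marking valid rows.
    (Hypothetical helper; the actual method body is not part of this diff.)
    """
    N, d = embs_packed.shape
    device = embs_packed.device
    batch_size = int(order_mask.max().item()) + 1
    counts = torch.bincount(order_mask, minlength=batch_size)
    max_refs = int(counts.max().item())
    out = torch.zeros(batch_size, max_refs, d, device=device, dtype=embs_packed.dtype)
    valid = torch.zeros(batch_size, max_refs, device=device, dtype=torch.bool)
    for b in range(batch_size):
        rows = embs_packed[order_mask == b]          # all references belonging to batch item b
        out[b, :rows.shape[0]] = rows
        valid[b, :rows.shape[0]] = True
    return out, valid

# Example: 3 reference embeddings for a batch of 2 (two for item 0, one for item 1).
packed = torch.randn(3, 8)
mask = torch.tensor([0, 0, 1])
grouped, valid = unpack_by_order_mask(packed, mask)
print(grouped.shape, valid)  # torch.Size([2, 2, 8]), [[True, True], [True, False]]
```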

View File

@@ -8,6 +8,7 @@ import torch
 from typing import Optional, Dict, Any, List, Tuple
 from tqdm import tqdm
 import random
+import math
 from ..core.device.npu_compatible_device import get_device_type
 from ..diffusion import FlowMatchScheduler
@@ -43,7 +44,7 @@ class AceStepPipeline(BasePipeline):
             AceStepUnit_PromptEmbedder(),
             AceStepUnit_ReferenceAudioEmbedder(),
             AceStepUnit_ConditionEmbedder(),
             AceStepUnit_AudioCodeDecoder(),
             AceStepUnit_ContextLatentBuilder(),
             AceStepUnit_NoiseInitializer(),
             AceStepUnit_InputAudioEmbedder(),
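The hunk above lists the processing units the pipeline runs in sequence. Each AceStepUnit_* follows the interface visible later in this diff: process(self, pipe, ...) receives the pipeline object plus selected inputs and returns a dict that is merged into shared state for later units. A minimal sketch of that chaining pattern; the toy unit classes and the run helper below are assumptions for illustration, not the project's actual base classes.

```python
from typing import Any, Dict, List

class PipelineUnit:
    """Illustrative stand-in: each unit consumes shared state and returns new entries."""
    def process(self, pipe, state: Dict[str, Any]) -> Dict[str, Any]:
        raise NotImplementedError

class UnitA(PipelineUnit):
    def process(self, pipe, state):
        return {"a": 1}

class UnitB(PipelineUnit):
    def process(self, pipe, state):
        # Later units can read what earlier units produced.
        return {"b": state["a"] + 1}

def run(pipe, units: List[PipelineUnit]) -> Dict[str, Any]:
    state: Dict[str, Any] = {}
    for unit in units:
        state.update(unit.process(pipe, state))
    return state

print(run(pipe=None, units=[UnitA(), UnitB()]))  # {'a': 1, 'b': 2}
```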
@@ -293,107 +294,47 @@ class AceStepUnit_ReferenceAudioEmbedder(PipelineUnit):
         )
 
     def process(self, pipe, reference_audios):
-        pipe.load_models_to_device(['vae'])
-        if reference_audios is not None and len(reference_audios) > 0:
-            raise NotImplementedError("Reference audio embedding is not implemented yet.")
-            # TODO: implement reference audio embedding using VAE encode, and generate refer_audio_order_mask
-            pass
+        if reference_audios is not None:
+            pipe.load_models_to_device(['vae'])
+            reference_audios = [self.process_reference_audio(reference_audio).to(dtype=pipe.torch_dtype, device=pipe.device) for reference_audio in reference_audios]
+            reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, [reference_audios])
         else:
             reference_audios = [[torch.zeros(2, 30 * pipe.vae.sampling_rate).to(dtype=pipe.torch_dtype, device=pipe.device)]]
             reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, reference_audios)
         return {"reference_latents": reference_latents, "refer_audio_order_mask": refer_audio_order_mask}
 
-    # def process_reference_audio(self, reference_audios) -> Optional[torch.Tensor]:
-    #     try:
-    #         audio_np, sr = _read_audio_file(audio_file)
-    #         audio = self._numpy_to_channels_first(audio_np)
-    #         logger.debug(
-    #             f"[process_reference_audio] Reference audio shape: {audio.shape}"
-    #         )
-    #         logger.debug(f"[process_reference_audio] Reference audio sample rate: {sr}")
-    #         logger.debug(
-    #             f"[process_reference_audio] Reference audio duration: {audio.shape[-1] / sr:.6f} seconds"
-    #         )
-    #         audio = self._normalize_audio_to_stereo_48k(audio, sr)
-    #         if self.is_silence(audio):
-    #             return None
-    #         target_frames = 30 * 48000
-    #         segment_frames = 10 * 48000
-    #         if audio.shape[-1] < target_frames:
-    #             repeat_times = math.ceil(target_frames / audio.shape[-1])
-    #             audio = audio.repeat(1, repeat_times)
-    #         total_frames = audio.shape[-1]
-    #         segment_size = total_frames // 3
-    #         front_start = random.randint(0, max(0, segment_size - segment_frames))
-    #         front_audio = audio[:, front_start : front_start + segment_frames]
-    #         middle_start = segment_size + random.randint(
-    #             0, max(0, segment_size - segment_frames)
-    #         )
-    #         middle_audio = audio[:, middle_start : middle_start + segment_frames]
-    #         back_start = 2 * segment_size + random.randint(
-    #             0, max(0, (total_frames - 2 * segment_size) - segment_frames)
-    #         )
-    #         back_audio = audio[:, back_start : back_start + segment_frames]
-    #         return torch.cat([front_audio, middle_audio, back_audio], dim=-1)
+    def process_reference_audio(self, audio) -> Optional[torch.Tensor]:
+        if audio.ndim == 3 and audio.shape[0] == 1:
+            audio = audio.squeeze(0)
+        target_frames = 30 * 48000
+        segment_frames = 10 * 48000
+        if audio.shape[-1] < target_frames:
+            repeat_times = math.ceil(target_frames / audio.shape[-1])
+            audio = audio.repeat(1, repeat_times)
+        total_frames = audio.shape[-1]
+        segment_size = total_frames // 3
+        front_start = random.randint(0, max(0, segment_size - segment_frames))
+        front_audio = audio[:, front_start:front_start + segment_frames]
+        middle_start = segment_size + random.randint(0, max(0, segment_size - segment_frames))
+        middle_audio = audio[:, middle_start:middle_start + segment_frames]
+        back_start = 2 * segment_size + random.randint(0, max(0, (total_frames - 2 * segment_size) - segment_frames))
+        back_audio = audio[:, back_start:back_start + segment_frames]
+        return torch.cat([front_audio, middle_audio, back_audio], dim=-1).unsqueeze(0)
 
     def infer_refer_latent(self, pipe, refer_audioss: List[List[torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
         """Infer packed reference-audio latents and order mask."""
         refer_audio_order_mask = []
         refer_audio_latents = []
-        def _normalize_audio_2d(a: torch.Tensor) -> torch.Tensor:
-            if not isinstance(a, torch.Tensor):
-                raise TypeError(f"refer_audio must be a torch.Tensor, got {type(a)!r}")
-            if a.dim() == 3 and a.shape[0] == 1:
-                a = a.squeeze(0)
-            if a.dim() == 1:
-                a = a.unsqueeze(0)
-            if a.dim() != 2:
-                raise ValueError(f"refer_audio must be 1D/2D/3D(1,2,T); got shape={tuple(a.shape)}")
-            if a.shape[0] == 1:
-                a = torch.cat([a, a], dim=0)
-            return a[:2]
-        def _ensure_latent_3d(z: torch.Tensor) -> torch.Tensor:
-            if z.dim() == 4 and z.shape[0] == 1:
-                z = z.squeeze(0)
-            if z.dim() == 2:
-                z = z.unsqueeze(0)
-            return z
-        refer_encode_cache: Dict[int, torch.Tensor] = {}
         for batch_idx, refer_audios in enumerate(refer_audioss):
             if len(refer_audios) == 1 and torch.all(refer_audios[0] == 0.0):
-                refer_audio_latent = _ensure_latent_3d(pipe.silence_latent[:, :750, :])
+                refer_audio_latent = pipe.silence_latent[:, :750, :]
                 refer_audio_latents.append(refer_audio_latent)
                 refer_audio_order_mask.append(batch_idx)
             else:
-                # TODO: check
                 for refer_audio in refer_audios:
-                    cache_key = refer_audio.data_ptr()
-                    if cache_key in refer_encode_cache:
-                        refer_audio_latent = refer_encode_cache[cache_key].clone()
-                    else:
-                        refer_audio = _normalize_audio_2d(refer_audio)
-                        refer_audio_latent = pipe.vae.encode(refer_audio)
-                        refer_audio_latent = refer_audio_latent.to(dtype=pipe.torch_dtype, device=pipe.device)
-                        if refer_audio_latent.dim() == 2:
-                            refer_audio_latent = refer_audio_latent.unsqueeze(0)
-                        refer_audio_latent = _ensure_latent_3d(refer_audio_latent.transpose(1, 2))
-                        refer_encode_cache[cache_key] = refer_audio_latent
+                    refer_audio_latent = pipe.vae.encode(refer_audio).transpose(1, 2).to(dtype=pipe.torch_dtype, device=pipe.device)
                     refer_audio_latents.append(refer_audio_latent)
                     refer_audio_order_mask.append(batch_idx)
         refer_audio_latents = torch.cat(refer_audio_latents, dim=0)
         refer_audio_order_mask = torch.tensor(refer_audio_order_mask, device=pipe.device, dtype=torch.long)
         return refer_audio_latents, refer_audio_order_mask
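The new process_reference_audio normalizes every reference clip to a fixed 30 s at 48 kHz: the waveform is loop-repeated until it covers at least 30 s, split into thirds, and one random 10 s window is taken from each third. A standalone sketch of that sampling behaviour on a dummy waveform; the function name sample_three_segments is illustrative, the logic mirrors the added code above.

```python
import math
import random
import torch

def sample_three_segments(audio: torch.Tensor, sr: int = 48000) -> torch.Tensor:
    """Loop-pad a (channels, T) waveform to >= 30 s, then take one random
    10 s window from each third and concatenate them (same scheme as the diff)."""
    target_frames = 30 * sr
    segment_frames = 10 * sr
    if audio.shape[-1] < target_frames:
        audio = audio.repeat(1, math.ceil(target_frames / audio.shape[-1]))
    total = audio.shape[-1]
    third = total // 3
    starts = [
        random.randint(0, max(0, third - segment_frames)),
        third + random.randint(0, max(0, third - segment_frames)),
        2 * third + random.randint(0, max(0, (total - 2 * third) - segment_frames)),
    ]
    return torch.cat([audio[:, s:s + segment_frames] for s in starts], dim=-1)

# A 12 s stereo clip gets loop-padded and reduced to exactly 30 s of audio.
clip = torch.randn(2, 12 * 48000)
print(sample_three_segments(clip).shape)  # torch.Size([2, 1440000])
```

Because every reference collapses to the same fixed length, the VAE latents produced in infer_refer_latent can be concatenated along the batch dimension without padding, and refer_audio_order_mask records which batch item each packed latent belongs to.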