diff --git a/diffsynth/models/ace_step_conditioner.py b/diffsynth/models/ace_step_conditioner.py
index 76cc502..1279d79 100644
--- a/diffsynth/models/ace_step_conditioner.py
+++ b/diffsynth/models/ace_step_conditioner.py
@@ -506,7 +506,6 @@ class AceStepTimbreEncoder(nn.Module):
             for layer_idx in range(num_timbre_encoder_hidden_layers)
         ])
-
    def unpack_timbre_embeddings(self, timbre_embs_packed, refer_audio_order_mask):
        N, d = timbre_embs_packed.shape
        device = timbre_embs_packed.device
diff --git a/diffsynth/pipelines/ace_step.py b/diffsynth/pipelines/ace_step.py
index 71d180c..4730592 100644
--- a/diffsynth/pipelines/ace_step.py
+++ b/diffsynth/pipelines/ace_step.py
@@ -8,6 +8,7 @@ import torch
 from typing import Optional, Dict, Any, List, Tuple
 from tqdm import tqdm
 import random
+import math
 from ..core.device.npu_compatible_device import get_device_type
 from ..diffusion import FlowMatchScheduler
@@ -43,7 +44,7 @@ class AceStepPipeline(BasePipeline):
             AceStepUnit_PromptEmbedder(),
             AceStepUnit_ReferenceAudioEmbedder(),
             AceStepUnit_ConditionEmbedder(),
-            AceStepUnit_AudioCodeDecoder(), 
+            AceStepUnit_AudioCodeDecoder(),
             AceStepUnit_ContextLatentBuilder(),
             AceStepUnit_NoiseInitializer(),
             AceStepUnit_InputAudioEmbedder(),
@@ -293,107 +294,47 @@ class AceStepUnit_ReferenceAudioEmbedder(PipelineUnit):
         )
 
     def process(self, pipe, reference_audios):
-        pipe.load_models_to_device(['vae'])
-        if reference_audios is not None and len(reference_audios) > 0:
-            raise NotImplementedError("Reference audio embedding is not implemented yet.")
-            # TODO: implement reference audio embedding using VAE encode, and generate refer_audio_order_mask
-            pass
+        if reference_audios is not None and len(reference_audios) > 0:
+            pipe.load_models_to_device(['vae'])
+            reference_audios = [self.process_reference_audio(reference_audio).to(dtype=pipe.torch_dtype, device=pipe.device) for reference_audio in reference_audios]
+            reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, [reference_audios])
         else:
             reference_audios = [[torch.zeros(2, 30 * pipe.vae.sampling_rate).to(dtype=pipe.torch_dtype, device=pipe.device)]]
             reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, reference_audios)
         return {"reference_latents": reference_latents, "refer_audio_order_mask": refer_audio_order_mask}
 
-    # def process_reference_audio(self, reference_audios) -> Optional[torch.Tensor]:
-
-    #     try:
-    #         audio_np, sr = _read_audio_file(audio_file)
-    #         audio = self._numpy_to_channels_first(audio_np)
-
-    #         logger.debug(
-    #             f"[process_reference_audio] Reference audio shape: {audio.shape}"
-    #         )
-    #         logger.debug(f"[process_reference_audio] Reference audio sample rate: {sr}")
-    #         logger.debug(
-    #             f"[process_reference_audio] Reference audio duration: {audio.shape[-1] / sr:.6f} seconds"
-    #         )
-
-    #         audio = self._normalize_audio_to_stereo_48k(audio, sr)
-    #         if self.is_silence(audio):
-    #             return None
-
-    #         target_frames = 30 * 48000
-    #         segment_frames = 10 * 48000
-
-    #         if audio.shape[-1] < target_frames:
-    #             repeat_times = math.ceil(target_frames / audio.shape[-1])
-    #             audio = audio.repeat(1, repeat_times)
-
-    #         total_frames = audio.shape[-1]
-    #         segment_size = total_frames // 3
-
-    #         front_start = random.randint(0, max(0, segment_size - segment_frames))
-    #         front_audio = audio[:, front_start : front_start + segment_frames]
-
-    #         middle_start = segment_size + random.randint(
-    #             0, max(0, segment_size - segment_frames)
-    #         )
-    #         middle_audio = audio[:, middle_start : middle_start + segment_frames]
-
-    #         back_start = 2 * segment_size + random.randint(
-    #             0, max(0, (total_frames - 2 * segment_size) - segment_frames)
-    #         )
-    #         back_audio = audio[:, back_start : back_start + segment_frames]
-
-    #         return torch.cat([front_audio, middle_audio, back_audio], dim=-1)
+    def process_reference_audio(self, audio) -> Optional[torch.Tensor]:
+        if audio.ndim == 3 and audio.shape[0] == 1:
+            audio = audio.squeeze(0)
+        target_frames = 30 * 48000
+        segment_frames = 10 * 48000
+        if audio.shape[-1] < target_frames:
+            repeat_times = math.ceil(target_frames / audio.shape[-1])
+            audio = audio.repeat(1, repeat_times)
+        total_frames = audio.shape[-1]
+        segment_size = total_frames // 3
+        front_start = random.randint(0, max(0, segment_size - segment_frames))
+        front_audio = audio[:, front_start:front_start + segment_frames]
+        middle_start = segment_size + random.randint(0, max(0, segment_size - segment_frames))
+        middle_audio = audio[:, middle_start:middle_start + segment_frames]
+        back_start = 2 * segment_size + random.randint(0, max(0, (total_frames - 2 * segment_size) - segment_frames))
+        back_audio = audio[:, back_start:back_start + segment_frames]
+        return torch.cat([front_audio, middle_audio, back_audio], dim=-1).unsqueeze(0)
 
     def infer_refer_latent(self, pipe, refer_audioss: List[List[torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
         """Infer packed reference-audio latents and order mask."""
         refer_audio_order_mask = []
         refer_audio_latents = []
-
-        def _normalize_audio_2d(a: torch.Tensor) -> torch.Tensor:
-            if not isinstance(a, torch.Tensor):
-                raise TypeError(f"refer_audio must be a torch.Tensor, got {type(a)!r}")
-            if a.dim() == 3 and a.shape[0] == 1:
-                a = a.squeeze(0)
-            if a.dim() == 1:
-                a = a.unsqueeze(0)
-            if a.dim() != 2:
-                raise ValueError(f"refer_audio must be 1D/2D/3D(1,2,T); got shape={tuple(a.shape)}")
-            if a.shape[0] == 1:
-                a = torch.cat([a, a], dim=0)
-            return a[:2]
-
-        def _ensure_latent_3d(z: torch.Tensor) -> torch.Tensor:
-            if z.dim() == 4 and z.shape[0] == 1:
-                z = z.squeeze(0)
-            if z.dim() == 2:
-                z = z.unsqueeze(0)
-            return z
-
-        refer_encode_cache: Dict[int, torch.Tensor] = {}
         for batch_idx, refer_audios in enumerate(refer_audioss):
             if len(refer_audios) == 1 and torch.all(refer_audios[0] == 0.0):
-                refer_audio_latent = _ensure_latent_3d(pipe.silence_latent[:, :750, :])
+                refer_audio_latent = pipe.silence_latent[:, :750, :]
                 refer_audio_latents.append(refer_audio_latent)
                 refer_audio_order_mask.append(batch_idx)
             else:
-                # TODO: check
                 for refer_audio in refer_audios:
-                    cache_key = refer_audio.data_ptr()
-                    if cache_key in refer_encode_cache:
-                        refer_audio_latent = refer_encode_cache[cache_key].clone()
-                    else:
-                        refer_audio = _normalize_audio_2d(refer_audio)
-                        refer_audio_latent = pipe.vae.encode(refer_audio)
-                        refer_audio_latent = refer_audio_latent.to(dtype=pipe.torch_dtype, device=pipe.device)
-                        if refer_audio_latent.dim() == 2:
-                            refer_audio_latent = refer_audio_latent.unsqueeze(0)
-                        refer_audio_latent = _ensure_latent_3d(refer_audio_latent.transpose(1, 2))
-                        refer_encode_cache[cache_key] = refer_audio_latent
+                    refer_audio_latent = pipe.vae.encode(refer_audio).transpose(1, 2).to(dtype=pipe.torch_dtype, device=pipe.device)
                     refer_audio_latents.append(refer_audio_latent)
                     refer_audio_order_mask.append(batch_idx)
-
        refer_audio_latents = torch.cat(refer_audio_latents, dim=0)
        refer_audio_order_mask = torch.tensor(refer_audio_order_mask, device=pipe.device, dtype=torch.long)
        return refer_audio_latents, refer_audio_order_mask
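
Reviewer note: for reference, a minimal standalone sketch of the three-segment cropping that the new process_reference_audio performs. It assumes a stereo 48 kHz input of shape (2, T); the name crop_three_segments is illustrative and not part of the patch.

    import math
    import random
    import torch

    def crop_three_segments(audio: torch.Tensor) -> torch.Tensor:
        # Tile the clip until it covers at least 30 s, then take one random
        # 10 s crop from each third (front/middle/back) and concatenate them.
        sr = 48000
        target_frames = 30 * sr       # 30 s total reference length
        segment_frames = 10 * sr      # 10 s per crop
        if audio.shape[-1] < target_frames:
            audio = audio.repeat(1, math.ceil(target_frames / audio.shape[-1]))
        total = audio.shape[-1]
        third = total // 3
        starts = [
            random.randint(0, max(0, third - segment_frames)),
            third + random.randint(0, max(0, third - segment_frames)),
            2 * third + random.randint(0, max(0, (total - 2 * third) - segment_frames)),
        ]
        crops = [audio[:, s:s + segment_frames] for s in starts]
        return torch.cat(crops, dim=-1)   # always (2, 30 * 48000)

    # A 12 s clip is tiled to 36 s before cropping, so the output is still 30 s:
    clip = torch.randn(2, 12 * 48000)
    assert crop_three_segments(clip).shape == (2, 30 * 48000)

This keeps the VAE input fixed at 30 s regardless of the source clip length, which presumably matches the 750-frame slice of pipe.silence_latent used on the no-reference path of infer_refer_latent.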