mirror of https://github.com/modelscope/DiffSynth-Studio.git
reference audio input
@@ -8,6 +8,7 @@ import torch
 from typing import Optional, Dict, Any, List, Tuple
 from tqdm import tqdm
 import random
+import math

 from ..core.device.npu_compatible_device import get_device_type
 from ..diffusion import FlowMatchScheduler
@@ -43,7 +44,7 @@ class AceStepPipeline(BasePipeline):
            AceStepUnit_PromptEmbedder(),
+           AceStepUnit_ReferenceAudioEmbedder(),
            AceStepUnit_ConditionEmbedder(),
            AceStepUnit_AudioCodeDecoder(),
-           AceStepUnit_AudioCodeDecoder(),
            AceStepUnit_ContextLatentBuilder(),
            AceStepUnit_NoiseInitializer(),
            AceStepUnit_InputAudioEmbedder(),
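For readers new to the codebase: this hunk wires AceStepUnit_ReferenceAudioEmbedder into the pipeline's unit list (and drops what looks like a duplicated AceStepUnit_AudioCodeDecoder). The AceStepUnit_* objects are pipeline units executed in list order, each contributing entries such as reference_latents to a shared set of intermediates. A toy model of that pattern, assuming dict-merging semantics like the process method below returns; the real BasePipeline plumbing in DiffSynth-Studio may differ:

class Unit:
    def process(self, pipe, **inputs):
        return {}

class ReferenceAudioEmbedder(Unit):
    def process(self, pipe, **inputs):
        # Toy stand-in: the real unit VAE-encodes reference audio into latents.
        return {"reference_latents": "latents", "refer_audio_order_mask": "mask"}

class Pipeline:
    def __init__(self, units):
        self.units = units

    def __call__(self, **inputs):
        for unit in self.units:              # units run in list order
            inputs.update(unit.process(self, **inputs))
        return inputs

out = Pipeline([ReferenceAudioEmbedder()])(prompt="upbeat synth pop")
assert "reference_latents" in out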
@@ -293,107 +294,47 @@ class AceStepUnit_ReferenceAudioEmbedder(PipelineUnit):
         )

     def process(self, pipe, reference_audios):
-        pipe.load_models_to_device(['vae'])
-        if reference_audios is not None and len(reference_audios) > 0:
-            raise NotImplementedError("Reference audio embedding is not implemented yet.")
-            # TODO: implement reference audio embedding using VAE encode, and generate refer_audio_order_mask
-        pass
+        if reference_audios is not None:
+            pipe.load_models_to_device(['vae'])
+            reference_audios = [self.process_reference_audio(reference_audio).to(dtype=pipe.torch_dtype, device=pipe.device) for reference_audio in reference_audios]
+            reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, [reference_audios])
+        else:
+            reference_audios = [[torch.zeros(2, 30 * pipe.vae.sampling_rate).to(dtype=pipe.torch_dtype, device=pipe.device)]]
+            reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, reference_audios)
+        return {"reference_latents": reference_latents, "refer_audio_order_mask": refer_audio_order_mask}
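A note on the nesting in process: the non-None branch wraps the processed clips once more ([reference_audios]) before calling infer_refer_latent, and the silent fallback builds the nested list directly, so both paths hand over a List[List[Tensor]] whose outer list indexes batch items and whose inner list holds that item's reference clips. A minimal illustration of that contract (shapes only; the 30 s stereo 48 kHz layout mirrors the code above, everything else is illustrative):

import torch

# Outer list = batch items, inner list = reference clips for that item.
# After process_reference_audio each clip is (1, 2, 30 * 48000); the silent
# fallback stays (2, 30 * sampling_rate).
refer_audioss = [
    [torch.randn(1, 2, 48000 * 30), torch.randn(1, 2, 48000 * 30)],  # item 0: two processed clips
    [torch.zeros(2, 48000 * 30)],                                     # item 1: silent fallback
]
for clips in refer_audioss:
    assert all(clip.shape[-1] == 48000 * 30 for clip in clips)        # 30 s at 48 kHz either way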
-    # def process_reference_audio(self, reference_audios) -> Optional[torch.Tensor]:
-
-    #     try:
-    #         audio_np, sr = _read_audio_file(audio_file)
-    #         audio = self._numpy_to_channels_first(audio_np)
-
-    #         logger.debug(
-    #             f"[process_reference_audio] Reference audio shape: {audio.shape}"
-    #         )
-    #         logger.debug(f"[process_reference_audio] Reference audio sample rate: {sr}")
-    #         logger.debug(
-    #             f"[process_reference_audio] Reference audio duration: {audio.shape[-1] / sr:.6f} seconds"
-    #         )
-
-    #         audio = self._normalize_audio_to_stereo_48k(audio, sr)
-    #         if self.is_silence(audio):
-    #             return None
-
-    #         target_frames = 30 * 48000
-    #         segment_frames = 10 * 48000
-
-    #         if audio.shape[-1] < target_frames:
-    #             repeat_times = math.ceil(target_frames / audio.shape[-1])
-    #             audio = audio.repeat(1, repeat_times)
-
-    #         total_frames = audio.shape[-1]
-    #         segment_size = total_frames // 3
-
-    #         front_start = random.randint(0, max(0, segment_size - segment_frames))
-    #         front_audio = audio[:, front_start : front_start + segment_frames]
-
-    #         middle_start = segment_size + random.randint(
-    #             0, max(0, segment_size - segment_frames)
-    #         )
-    #         middle_audio = audio[:, middle_start : middle_start + segment_frames]
-
-    #         back_start = 2 * segment_size + random.randint(
-    #             0, max(0, (total_frames - 2 * segment_size) - segment_frames)
-    #         )
-    #         back_audio = audio[:, back_start : back_start + segment_frames]
-
-    #         return torch.cat([front_audio, middle_audio, back_audio], dim=-1)
+    def process_reference_audio(self, audio) -> Optional[torch.Tensor]:
+        if audio.ndim == 3 and audio.shape[0] == 1:
+            audio = audio.squeeze(0)
+        target_frames = 30 * 48000
+        segment_frames = 10 * 48000
+        if audio.shape[-1] < target_frames:
+            repeat_times = math.ceil(target_frames / audio.shape[-1])
+            audio = audio.repeat(1, repeat_times)
+        total_frames = audio.shape[-1]
+        segment_size = total_frames // 3
+        front_start = random.randint(0, max(0, segment_size - segment_frames))
+        front_audio = audio[:, front_start:front_start + segment_frames]
+        middle_start = segment_size + random.randint(0, max(0, segment_size - segment_frames))
+        middle_audio = audio[:, middle_start:middle_start + segment_frames]
+        back_start = 2 * segment_size + random.randint(0, max(0, (total_frames - 2 * segment_size) - segment_frames))
+        back_audio = audio[:, back_start:back_start + segment_frames]
+        return torch.cat([front_audio, middle_audio, back_audio], dim=-1).unsqueeze(0)
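The sampling scheme above takes one 10 s window from each third of a (looped-up-to-)30 s clip and concatenates them into a fixed 30 s, 48 kHz summary. A minimal sketch of the same idea on a dummy tensor (the 48000 sample rate and 30 s / 10 s constants mirror the diff; the rest is illustrative):

import math, random, torch

sr = 48000
audio = torch.randn(2, sr * 12)           # 12 s stereo clip, shorter than the 30 s target
target, seg = 30 * sr, 10 * sr

if audio.shape[-1] < target:              # loop short clips until they cover 30 s
    audio = audio.repeat(1, math.ceil(target / audio.shape[-1]))

third = audio.shape[-1] // 3
starts = [
    random.randint(0, max(0, third - seg)),                                      # front third
    third + random.randint(0, max(0, third - seg)),                              # middle third
    2 * third + random.randint(0, max(0, (audio.shape[-1] - 2 * third) - seg)),  # back third
]
summary = torch.cat([audio[:, s:s + seg] for s in starts], dim=-1)
assert summary.shape == (2, 30 * sr)      # always exactly 30 s of stereo audio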
     def infer_refer_latent(self, pipe, refer_audioss: List[List[torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
         """Infer packed reference-audio latents and order mask."""
         refer_audio_order_mask = []
         refer_audio_latents = []

-        def _normalize_audio_2d(a: torch.Tensor) -> torch.Tensor:
-            if not isinstance(a, torch.Tensor):
-                raise TypeError(f"refer_audio must be a torch.Tensor, got {type(a)!r}")
-            if a.dim() == 3 and a.shape[0] == 1:
-                a = a.squeeze(0)
-            if a.dim() == 1:
-                a = a.unsqueeze(0)
-            if a.dim() != 2:
-                raise ValueError(f"refer_audio must be 1D/2D/3D(1,2,T); got shape={tuple(a.shape)}")
-            if a.shape[0] == 1:
-                a = torch.cat([a, a], dim=0)
-            return a[:2]
-
-        def _ensure_latent_3d(z: torch.Tensor) -> torch.Tensor:
-            if z.dim() == 4 and z.shape[0] == 1:
-                z = z.squeeze(0)
-            if z.dim() == 2:
-                z = z.unsqueeze(0)
-            return z
-
-        refer_encode_cache: Dict[int, torch.Tensor] = {}
         for batch_idx, refer_audios in enumerate(refer_audioss):
             if len(refer_audios) == 1 and torch.all(refer_audios[0] == 0.0):
-                refer_audio_latent = _ensure_latent_3d(pipe.silence_latent[:, :750, :])
+                refer_audio_latent = pipe.silence_latent[:, :750, :]
                 refer_audio_latents.append(refer_audio_latent)
                 refer_audio_order_mask.append(batch_idx)
             else:
+                # TODO: check
                 for refer_audio in refer_audios:
-                    cache_key = refer_audio.data_ptr()
-                    if cache_key in refer_encode_cache:
-                        refer_audio_latent = refer_encode_cache[cache_key].clone()
-                    else:
-                        refer_audio = _normalize_audio_2d(refer_audio)
-                        refer_audio_latent = pipe.vae.encode(refer_audio)
-                        refer_audio_latent = refer_audio_latent.to(dtype=pipe.torch_dtype, device=pipe.device)
-                        if refer_audio_latent.dim() == 2:
-                            refer_audio_latent = refer_audio_latent.unsqueeze(0)
-                        refer_audio_latent = _ensure_latent_3d(refer_audio_latent.transpose(1, 2))
-                        refer_encode_cache[cache_key] = refer_audio_latent
+                    refer_audio_latent = pipe.vae.encode(refer_audio).transpose(1, 2).to(dtype=pipe.torch_dtype, device=pipe.device)
                     refer_audio_latents.append(refer_audio_latent)
                     refer_audio_order_mask.append(batch_idx)

         refer_audio_latents = torch.cat(refer_audio_latents, dim=0)
         refer_audio_order_mask = torch.tensor(refer_audio_order_mask, device=pipe.device, dtype=torch.long)
         return refer_audio_latents, refer_audio_order_mask
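One thing worth spelling out about the returned pair: refer_audio_order_mask records, for every row packed into refer_audio_latents, which batch item that latent came from, so a downstream unit can scatter the flattened latents back per sample. A hedged sketch of that contract with a fake encoder (names and shapes are illustrative, not the repo's API):

import torch

def fake_encode(audio):                      # stand-in for pipe.vae.encode(...)
    return torch.randn(1, 750, 64)           # (1, latent_frames, channels)

refer_audioss = [
    [torch.randn(2, 48000 * 30)] * 2,        # batch item 0: two reference clips
    [torch.randn(2, 48000 * 30)],            # batch item 1: one reference clip
]

latents, order_mask = [], []
for batch_idx, clips in enumerate(refer_audioss):
    for clip in clips:
        latents.append(fake_encode(clip))
        order_mask.append(batch_idx)

latents = torch.cat(latents, dim=0)          # (3, 750, 64): clips flattened across the batch
order_mask = torch.tensor(order_mask)        # tensor([0, 0, 1]) maps each row to its sample
rows_for_item_0 = latents[order_mask == 0]   # recover the two latents belonging to item 0
assert rows_for_item_0.shape[0] == 2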