mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-24 15:06:17 +00:00
530 lines
25 KiB
Python
530 lines
25 KiB
Python
"""
|
|
ACE-Step Pipeline for DiffSynth-Studio.
|
|
|
|
Text-to-Music generation pipeline using ACE-Step 1.5 model.
|
|
"""
|
|
import re
|
|
import torch
|
|
from typing import Optional, Dict, Any, List, Tuple
|
|
from tqdm import tqdm
|
|
import random
|
|
import math
|
|
import torch.nn.functional as F
|
|
from einops import rearrange
|
|
|
|
from ..core.device.npu_compatible_device import get_device_type
|
|
from ..diffusion import FlowMatchScheduler
|
|
from ..core import ModelConfig
|
|
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
|
|
|
|
from ..models.ace_step_dit import AceStepDiTModel
|
|
from ..models.ace_step_conditioner import AceStepConditionEncoder
|
|
from ..models.ace_step_text_encoder import AceStepTextEncoder
|
|
from ..models.ace_step_vae import AceStepVAE
|
|
from ..models.ace_step_tokenizer import AceStepTokenizer
|
|
|
|
|
|
class AceStepPipeline(BasePipeline):
|
|
"""Pipeline for ACE-Step text-to-music generation."""
|
|
|
|
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
|
|
super().__init__(
|
|
device=device,
|
|
torch_dtype=torch_dtype,
|
|
height_division_factor=1,
|
|
width_division_factor=1,
|
|
)
|
|
self.scheduler = FlowMatchScheduler("ACE-Step")
|
|
self.text_encoder: AceStepTextEncoder = None
|
|
self.conditioner: AceStepConditionEncoder = None
|
|
self.dit: AceStepDiTModel = None
|
|
self.vae: AceStepVAE = None
|
|
self.tokenizer_model: AceStepTokenizer = None # AceStepTokenizer (tokenizer + detokenizer)
|
|
|
|
self.in_iteration_models = ("dit",)
|
|
self.units = [
|
|
AceStepUnit_TaskTypeChecker(),
|
|
AceStepUnit_PromptEmbedder(),
|
|
AceStepUnit_ReferenceAudioEmbedder(),
|
|
AceStepUnit_ContextLatentBuilder(),
|
|
AceStepUnit_ConditionEmbedder(),
|
|
AceStepUnit_NoiseInitializer(),
|
|
AceStepUnit_InputAudioEmbedder(),
|
|
]
|
|
self.model_fn = model_fn_ace_step
|
|
self.compilable_models = ["dit"]
|
|
|
|
self.sample_rate = 48000
|
|
|
|
@staticmethod
|
|
def from_pretrained(
|
|
torch_dtype: torch.dtype = torch.bfloat16,
|
|
device: str = get_device_type(),
|
|
model_configs: list[ModelConfig] = [],
|
|
text_tokenizer_config: ModelConfig = ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
|
silence_latent_config: ModelConfig = ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
|
|
vram_limit: float = None,
|
|
):
|
|
"""Load pipeline from pretrained checkpoints."""
|
|
pipe = AceStepPipeline(device=device, torch_dtype=torch_dtype)
|
|
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
|
|
|
pipe.text_encoder = model_pool.fetch_model("ace_step_text_encoder")
|
|
pipe.conditioner = model_pool.fetch_model("ace_step_conditioner")
|
|
pipe.dit = model_pool.fetch_model("ace_step_dit")
|
|
pipe.vae = model_pool.fetch_model("ace_step_vae")
|
|
pipe.vae.remove_weight_norm()
|
|
pipe.tokenizer_model = model_pool.fetch_model("ace_step_tokenizer")
|
|
|
|
if text_tokenizer_config is not None:
|
|
text_tokenizer_config.download_if_necessary()
|
|
from transformers import AutoTokenizer
|
|
pipe.tokenizer = AutoTokenizer.from_pretrained(text_tokenizer_config.path)
|
|
if silence_latent_config is not None:
|
|
silence_latent_config.download_if_necessary()
|
|
pipe.silence_latent = torch.load(silence_latent_config.path, weights_only=True).transpose(1, 2).to(dtype=pipe.torch_dtype, device=pipe.device)
|
|
|
|
# VRAM Management
|
|
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
|
return pipe
|
|
|
|
@torch.no_grad()
|
|
def __call__(
|
|
self,
|
|
# Prompt
|
|
prompt: str,
|
|
cfg_scale: float = 1.0,
|
|
# Lyrics
|
|
lyrics: str = "",
|
|
# Task type
|
|
task_type: Optional[str] = "text2music",
|
|
# Reference audio
|
|
reference_audios: List[torch.Tensor] = None,
|
|
# Source audio
|
|
src_audio: torch.Tensor = None,
|
|
denoising_strength: float = 1.0, # denoising_strength = 1 - cover_noise_strength
|
|
audio_cover_strength: float = 1.0,
|
|
# Audio codes
|
|
audio_code_string: Optional[str] = None,
|
|
# Shape
|
|
duration: int = 60,
|
|
# Audio Meta
|
|
bpm: Optional[int] = 100,
|
|
keyscale: Optional[str] = "B minor",
|
|
timesignature: Optional[str] = "4",
|
|
vocal_language: Optional[str] = 'unknown',
|
|
# Randomness
|
|
seed: int = None,
|
|
rand_device: str = "cpu",
|
|
# Steps
|
|
num_inference_steps: int = 8,
|
|
# Scheduler-specific parameters
|
|
shift: float = 1.0,
|
|
# Progress
|
|
progress_bar_cmd=tqdm,
|
|
):
|
|
# 1. Scheduler
|
|
self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, denoising_strength=denoising_strength, shift=shift)
|
|
|
|
# 2. 三字典输入
|
|
inputs_posi = {"prompt": prompt, "positive": True}
|
|
inputs_nega = {"positive": False}
|
|
inputs_shared = {
|
|
"cfg_scale": cfg_scale,
|
|
"lyrics": lyrics,
|
|
"task_type": task_type,
|
|
"reference_audios": reference_audios,
|
|
"src_audio": src_audio,
|
|
"audio_cover_strength": audio_cover_strength,
|
|
"audio_code_string": audio_code_string,
|
|
"duration": duration,
|
|
"bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "vocal_language": vocal_language,
|
|
"seed": seed,
|
|
"rand_device": rand_device,
|
|
"num_inference_steps": num_inference_steps,
|
|
"shift": shift,
|
|
}
|
|
|
|
# 3. Unit 链执行
|
|
for unit in self.units:
|
|
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
|
|
unit, self, inputs_shared, inputs_posi, inputs_nega
|
|
)
|
|
|
|
# 4. Denoise loop
|
|
self.load_models_to_device(self.in_iteration_models)
|
|
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
|
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
|
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
|
self.switch_noncover_condition(inputs_shared, inputs_posi, inputs_nega, progress_id)
|
|
noise_pred = self.cfg_guided_model_fn(
|
|
self.model_fn, cfg_scale,
|
|
inputs_shared, inputs_posi, inputs_nega,
|
|
**models, timestep=timestep, progress_id=progress_id,
|
|
)
|
|
inputs_shared["latents"] = self.step(
|
|
self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared
|
|
)
|
|
|
|
# 5. VAE 解码
|
|
self.load_models_to_device(['vae'])
|
|
# DiT output is [B, T, 64] (channels-last), VAE expects [B, 64, T] (channels-first)
|
|
latents = inputs_shared["latents"].transpose(1, 2)
|
|
vae_output = self.vae.decode(latents)
|
|
# VAE returns OobleckDecoderOutput with .sample attribute
|
|
audio_output = vae_output.sample if hasattr(vae_output, 'sample') else vae_output
|
|
audio_output = self.normalize_audio(audio_output, target_db=-1.0)
|
|
audio = self.output_audio_format_check(audio_output)
|
|
self.load_models_to_device([])
|
|
return audio
|
|
|
|
def normalize_audio(self, audio: torch.Tensor, target_db: float = -1.0) -> torch.Tensor:
|
|
peak = torch.max(torch.abs(audio))
|
|
if peak < 1e-6:
|
|
return audio
|
|
target_amp = 10 ** (target_db / 20.0)
|
|
gain = target_amp / peak
|
|
return audio * gain
|
|
|
|
def switch_noncover_condition(self, inputs_shared, inputs_posi, inputs_nega, progress_id):
|
|
if inputs_shared["task_type"] != "cover" or inputs_shared["audio_cover_strength"] >= 1.0 or inputs_shared.get("shared_noncover", None) is None:
|
|
return
|
|
cover_steps = int(len(self.scheduler.timesteps) * inputs_shared["audio_cover_strength"])
|
|
if progress_id >= cover_steps:
|
|
inputs_shared.update(inputs_shared.pop("shared_noncover", {}))
|
|
inputs_posi.update(inputs_shared.pop("posi_noncover", {}))
|
|
if inputs_shared["cfg_scale"] != 1.0:
|
|
inputs_nega.update(inputs_shared.pop("nega_noncover", {}))
|
|
|
|
|
|
class AceStepUnit_TaskTypeChecker(PipelineUnit):
|
|
"""Check and compute sequence length from duration."""
|
|
def __init__(self):
|
|
super().__init__(
|
|
input_params=("task_type",),
|
|
output_params=("task_type",),
|
|
)
|
|
|
|
def process(self, pipe, task_type):
|
|
assert task_type in ["text2music", "cover", "repaint"], f"Unsupported task_type: {task_type}"
|
|
return {}
|
|
|
|
|
|
class AceStepUnit_PromptEmbedder(PipelineUnit):
|
|
SFT_GEN_PROMPT = "# Instruction\n{}\n\n# Caption\n{}\n\n# Metas\n{}<|endoftext|>\n"
|
|
INSTRUCTION_MAP = {
|
|
"text2music": "Fill the audio semantic mask based on the given conditions:",
|
|
"cover": "Generate audio semantic tokens based on the given conditions:",
|
|
"repaint": "Repaint the mask area based on the given conditions:",
|
|
"extract": "Extract the {TRACK_NAME} track from the audio:",
|
|
"extract_default": "Extract the track from the audio:",
|
|
"lego": "Generate the {TRACK_NAME} track based on the audio context:",
|
|
"lego_default": "Generate the track based on the audio context:",
|
|
"complete": "Complete the input track with {TRACK_CLASSES}:",
|
|
"complete_default": "Complete the input track:",
|
|
}
|
|
LYRIC_PROMPT = "# Languages\n{}\n\n# Lyric\n{}<|endoftext|>"
|
|
|
|
def __init__(self):
|
|
super().__init__(
|
|
seperate_cfg=True,
|
|
input_params_posi={"prompt": "prompt", "positive": "positive"},
|
|
input_params_nega={"prompt": "prompt", "positive": "positive"},
|
|
input_params=("lyrics", "duration", "bpm", "keyscale", "timesignature", "vocal_language", "task_type"),
|
|
output_params=("text_hidden_states", "text_attention_mask", "lyric_hidden_states", "lyric_attention_mask"),
|
|
onload_model_names=("text_encoder",)
|
|
)
|
|
|
|
def _encode_text(self, pipe, text, max_length=256):
|
|
"""Encode text using Qwen3-Embedding → [B, T, 1024]."""
|
|
text_inputs = pipe.tokenizer(
|
|
text,
|
|
max_length=max_length,
|
|
truncation=True,
|
|
return_tensors="pt",
|
|
)
|
|
input_ids = text_inputs.input_ids.to(pipe.device)
|
|
attention_mask = text_inputs.attention_mask.bool().to(pipe.device)
|
|
hidden_states = pipe.text_encoder(input_ids, attention_mask)
|
|
return hidden_states, attention_mask
|
|
|
|
def _encode_lyrics(self, pipe, lyric_text, max_length=2048):
|
|
text_inputs = pipe.tokenizer(
|
|
lyric_text,
|
|
max_length=max_length,
|
|
truncation=True,
|
|
return_tensors="pt",
|
|
)
|
|
input_ids = text_inputs.input_ids.to(pipe.device)
|
|
attention_mask = text_inputs.attention_mask.bool().to(pipe.device)
|
|
hidden_states = pipe.text_encoder.model.embed_tokens(input_ids)
|
|
return hidden_states, attention_mask
|
|
|
|
def _dict_to_meta_string(self, meta_dict: Dict[str, Any]) -> str:
|
|
bpm = meta_dict.get("bpm", "N/A")
|
|
timesignature = meta_dict.get("timesignature", "N/A")
|
|
keyscale = meta_dict.get("keyscale", "N/A")
|
|
duration = meta_dict.get("duration", 30)
|
|
duration = f"{int(duration)} seconds"
|
|
return (
|
|
f"- bpm: {bpm}\n"
|
|
f"- timesignature: {timesignature}\n"
|
|
f"- keyscale: {keyscale}\n"
|
|
f"- duration: {duration}\n"
|
|
)
|
|
|
|
def process(self, pipe, prompt, positive, lyrics, duration, bpm, keyscale, timesignature, vocal_language, task_type):
|
|
if not positive:
|
|
return {}
|
|
pipe.load_models_to_device(['text_encoder'])
|
|
meta_dict = {"bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "duration": duration}
|
|
INSTRUCTION = self.INSTRUCTION_MAP.get(task_type, self.INSTRUCTION_MAP["text2music"])
|
|
prompt = self.SFT_GEN_PROMPT.format(INSTRUCTION, prompt, self._dict_to_meta_string(meta_dict))
|
|
text_hidden_states, text_attention_mask = self._encode_text(pipe, prompt, max_length=256)
|
|
|
|
lyric_text = self.LYRIC_PROMPT.format(vocal_language, lyrics)
|
|
lyric_hidden_states, lyric_attention_mask = self._encode_lyrics(pipe, lyric_text, max_length=2048)
|
|
|
|
# TODO: remove this
|
|
newtext = prompt + "\n\n" + lyric_text
|
|
return {
|
|
"text_hidden_states": text_hidden_states,
|
|
"text_attention_mask": text_attention_mask,
|
|
"lyric_hidden_states": lyric_hidden_states,
|
|
"lyric_attention_mask": lyric_attention_mask,
|
|
}
|
|
|
|
|
|
class AceStepUnit_ReferenceAudioEmbedder(PipelineUnit):
|
|
def __init__(self):
|
|
super().__init__(
|
|
input_params=("reference_audios",),
|
|
output_params=("reference_latents", "refer_audio_order_mask"),
|
|
onload_model_names=("vae",)
|
|
)
|
|
|
|
def process(self, pipe, reference_audios):
|
|
if reference_audios is not None:
|
|
pipe.load_models_to_device(['vae'])
|
|
reference_audios = [self.process_reference_audio(reference_audio).to(dtype=pipe.torch_dtype, device=pipe.device) for reference_audio in reference_audios]
|
|
reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, [reference_audios])
|
|
else:
|
|
reference_audios = [[torch.zeros(2, 30 * pipe.vae.sampling_rate).to(dtype=pipe.torch_dtype, device=pipe.device)]]
|
|
reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, reference_audios)
|
|
return {"reference_latents": reference_latents, "refer_audio_order_mask": refer_audio_order_mask}
|
|
|
|
def process_reference_audio(self, audio) -> Optional[torch.Tensor]:
|
|
if audio.ndim == 3 and audio.shape[0] == 1:
|
|
audio = audio.squeeze(0)
|
|
target_frames = 30 * 48000
|
|
segment_frames = 10 * 48000
|
|
if audio.shape[-1] < target_frames:
|
|
repeat_times = math.ceil(target_frames / audio.shape[-1])
|
|
audio = audio.repeat(1, repeat_times)
|
|
total_frames = audio.shape[-1]
|
|
segment_size = total_frames // 3
|
|
front_start = random.randint(0, max(0, segment_size - segment_frames))
|
|
front_audio = audio[:, front_start:front_start + segment_frames]
|
|
middle_start = segment_size + random.randint(0, max(0, segment_size - segment_frames))
|
|
middle_audio = audio[:, middle_start:middle_start + segment_frames]
|
|
back_start = 2 * segment_size + random.randint(0, max(0, (total_frames - 2 * segment_size) - segment_frames))
|
|
back_audio = audio[:, back_start:back_start + segment_frames]
|
|
return torch.cat([front_audio, middle_audio, back_audio], dim=-1).unsqueeze(0)
|
|
|
|
def infer_refer_latent(self, pipe, refer_audioss: List[List[torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
"""Infer packed reference-audio latents and order mask."""
|
|
refer_audio_order_mask = []
|
|
refer_audio_latents = []
|
|
for batch_idx, refer_audios in enumerate(refer_audioss):
|
|
if len(refer_audios) == 1 and torch.all(refer_audios[0] == 0.0):
|
|
refer_audio_latent = pipe.silence_latent[:, :750, :]
|
|
refer_audio_latents.append(refer_audio_latent)
|
|
refer_audio_order_mask.append(batch_idx)
|
|
else:
|
|
for refer_audio in refer_audios:
|
|
refer_audio_latent = pipe.vae.encode(refer_audio).transpose(1, 2).to(dtype=pipe.torch_dtype, device=pipe.device)
|
|
refer_audio_latents.append(refer_audio_latent)
|
|
refer_audio_order_mask.append(batch_idx)
|
|
refer_audio_latents = torch.cat(refer_audio_latents, dim=0)
|
|
refer_audio_order_mask = torch.tensor(refer_audio_order_mask, device=pipe.device, dtype=torch.long)
|
|
return refer_audio_latents, refer_audio_order_mask
|
|
|
|
|
|
class AceStepUnit_ConditionEmbedder(PipelineUnit):
|
|
|
|
def __init__(self):
|
|
super().__init__(
|
|
take_over=True,
|
|
output_params=("encoder_hidden_states", "encoder_attention_mask"),
|
|
onload_model_names=("conditioner",),
|
|
)
|
|
|
|
def process(self, pipe, inputs_shared, inputs_posi, inputs_nega):
|
|
pipe.load_models_to_device(['conditioner'])
|
|
encoder_hidden_states, encoder_attention_mask = pipe.conditioner(
|
|
text_hidden_states=inputs_posi.get("text_hidden_states", None),
|
|
text_attention_mask=inputs_posi.get("text_attention_mask", None),
|
|
lyric_hidden_states=inputs_posi.get("lyric_hidden_states", None),
|
|
lyric_attention_mask=inputs_posi.get("lyric_attention_mask", None),
|
|
reference_latents=inputs_shared.get("reference_latents", None),
|
|
refer_audio_order_mask=inputs_shared.get("refer_audio_order_mask", None),
|
|
)
|
|
inputs_posi["encoder_hidden_states"] = encoder_hidden_states
|
|
inputs_posi["encoder_attention_mask"] = encoder_attention_mask
|
|
if inputs_shared["cfg_scale"] != 1.0:
|
|
inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states).to(dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device)
|
|
inputs_nega["encoder_attention_mask"] = encoder_attention_mask
|
|
if inputs_shared["task_type"] == "cover" and inputs_shared["audio_cover_strength"] < 1.0:
|
|
hidden_states_noncover = AceStepUnit_PromptEmbedder().process(
|
|
pipe, inputs_posi["prompt"], True, inputs_shared["lyrics"], inputs_shared["duration"],
|
|
inputs_shared["bpm"], inputs_shared["keyscale"], inputs_shared["timesignature"],
|
|
inputs_shared["vocal_language"], "text2music")
|
|
encoder_hidden_states_noncover, encoder_attention_mask_noncover = pipe.conditioner(
|
|
**hidden_states_noncover,
|
|
reference_latents=inputs_shared.get("reference_latents", None),
|
|
refer_audio_order_mask=inputs_shared.get("refer_audio_order_mask", None),
|
|
)
|
|
duration = inputs_shared["context_latents"].shape[1] * 1920 / pipe.vae.sampling_rate
|
|
context_latents_noncover = AceStepUnit_ContextLatentBuilder().process(pipe, duration, None, None)["context_latents"]
|
|
inputs_shared["shared_noncover"] = {"context_latents": context_latents_noncover}
|
|
inputs_shared["posi_noncover"] = {"encoder_hidden_states": encoder_hidden_states_noncover, "encoder_attention_mask": encoder_attention_mask_noncover}
|
|
if inputs_shared["cfg_scale"] != 1.0:
|
|
inputs_shared["nega_noncover"] = {
|
|
"encoder_hidden_states": pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states_noncover).to(dtype=encoder_hidden_states_noncover.dtype, device=encoder_hidden_states_noncover.device),
|
|
"encoder_attention_mask": encoder_attention_mask_noncover,
|
|
}
|
|
return inputs_shared, inputs_posi, inputs_nega
|
|
|
|
|
|
class AceStepUnit_ContextLatentBuilder(PipelineUnit):
|
|
def __init__(self):
|
|
super().__init__(
|
|
input_params=("duration", "src_audio", "audio_code_string"),
|
|
output_params=("context_latents", "src_latents", "chunk_masks", "attention_mask"),
|
|
onload_model_names=("vae", "tokenizer_model",),
|
|
)
|
|
|
|
def _get_silence_latent_slice(self, pipe, length: int) -> torch.Tensor:
|
|
available = pipe.silence_latent.shape[1]
|
|
if length <= available:
|
|
return pipe.silence_latent[0, :length, :]
|
|
repeats = (length + available - 1) // available
|
|
tiled = pipe.silence_latent[0].repeat(repeats, 1)
|
|
return tiled[:length, :]
|
|
|
|
def tokenize(self, tokenizer, x, silence_latent, pool_window_size):
|
|
if x.shape[1] % pool_window_size != 0:
|
|
pad_len = pool_window_size - (x.shape[1] % pool_window_size)
|
|
x = torch.cat([x, silence_latent[:1,:pad_len].repeat(x.shape[0],1,1)], dim=1)
|
|
x = rearrange(x, 'n (t_patch p) d -> n t_patch p d', p=pool_window_size)
|
|
quantized, indices = tokenizer(x)
|
|
return quantized
|
|
|
|
@staticmethod
|
|
def _parse_audio_code_string(code_str: str) -> list:
|
|
"""Extract integer audio codes from tokens like <|audio_code_123|>."""
|
|
if not code_str:
|
|
return []
|
|
try:
|
|
codes = []
|
|
max_audio_code = 63999
|
|
for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str):
|
|
code_value = int(x)
|
|
codes.append(max(0, min(code_value, max_audio_code)))
|
|
except Exception as e:
|
|
raise ValueError(f"Invalid audio_code_string format: {e}")
|
|
return codes
|
|
|
|
def process(self, pipe, duration, src_audio, audio_code_string):
|
|
# get src_latents from audio_code_string > src_audio > silence
|
|
if audio_code_string is not None:
|
|
pipe.load_models_to_device(self.onload_model_names)
|
|
code_ids = self._parse_audio_code_string(audio_code_string)
|
|
quantizer = pipe.tokenizer_model.tokenizer.quantizer
|
|
indices = torch.tensor(code_ids, device=quantizer.codebooks.device, dtype=torch.long).unsqueeze(0).unsqueeze(-1)
|
|
codes = quantizer.get_codes_from_indices(indices)
|
|
quantized = codes.sum(dim=0).to(pipe.torch_dtype).to(pipe.device)
|
|
quantized = quantizer.project_out(quantized)
|
|
src_latents = pipe.tokenizer_model.detokenizer(quantized).to(pipe.device)
|
|
max_latent_length = src_latents.shape[1]
|
|
elif src_audio is not None:
|
|
pipe.load_models_to_device(self.onload_model_names)
|
|
src_audio = src_audio.unsqueeze(0) if src_audio.dim() == 2 else src_audio
|
|
src_audio = torch.clamp(src_audio, -1.0, 1.0)
|
|
src_latents = pipe.vae.encode(src_audio.to(dtype=pipe.torch_dtype, device=pipe.device)).transpose(1, 2)
|
|
lm_hints_5Hz = self.tokenize(pipe.tokenizer_model.tokenizer, src_latents, pipe.silence_latent, pipe.tokenizer_model.tokenizer.pool_window_size)
|
|
src_latents = pipe.tokenizer_model.detokenizer(lm_hints_5Hz)
|
|
max_latent_length = src_latents.shape[1]
|
|
else:
|
|
max_latent_length = int(duration * pipe.sample_rate // 1920)
|
|
src_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
|
|
chunk_masks = torch.ones((1, max_latent_length, src_latents.shape[-1]), dtype=torch.bool, device=pipe.device)
|
|
attention_mask = torch.ones((1, max_latent_length), device=src_latents.device, dtype=pipe.torch_dtype)
|
|
context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
|
|
return {"context_latents": context_latents, "attention_mask": attention_mask}
|
|
|
|
|
|
class AceStepUnit_NoiseInitializer(PipelineUnit):
|
|
def __init__(self):
|
|
super().__init__(
|
|
input_params=("context_latents", "seed", "rand_device"),
|
|
output_params=("noise",),
|
|
)
|
|
|
|
def process(self, pipe, context_latents, seed, rand_device):
|
|
src_latents_shape = (context_latents.shape[0], context_latents.shape[1], context_latents.shape[-1] // 2)
|
|
noise = pipe.generate_noise(src_latents_shape, seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
|
|
noise = pipe.scheduler.add_noise(context_latents[:, :, :src_latents_shape[-1]], noise, timestep=pipe.scheduler.timesteps[0])
|
|
return {"noise": noise}
|
|
|
|
|
|
class AceStepUnit_InputAudioEmbedder(PipelineUnit):
|
|
"""Only for training."""
|
|
def __init__(self):
|
|
super().__init__(
|
|
input_params=("noise", "input_audio"),
|
|
output_params=("latents", "input_latents"),
|
|
onload_model_names=("vae",),
|
|
)
|
|
|
|
def process(self, pipe, noise, input_audio):
|
|
if input_audio is None:
|
|
return {"latents": noise}
|
|
if pipe.scheduler.training:
|
|
pipe.load_models_to_device(self.onload_model_names)
|
|
input_audio, sample_rate = input_audio
|
|
input_audio = torch.clamp(input_audio, -1.0, 1.0)
|
|
if input_audio.dim() == 2:
|
|
input_audio = input_audio.unsqueeze(0)
|
|
input_latents = pipe.vae.encode(input_audio.to(dtype=pipe.torch_dtype, device=pipe.device)).transpose(1, 2)
|
|
# prevent potential size mismatch between context_latents and input_latents by cropping input_latents to the same temporal length as noise
|
|
input_latents = input_latents[:, :noise.shape[1]]
|
|
return {"input_latents": input_latents}
|
|
|
|
|
|
|
|
def model_fn_ace_step(
|
|
dit: AceStepDiTModel,
|
|
latents=None,
|
|
timestep=None,
|
|
encoder_hidden_states=None,
|
|
encoder_attention_mask=None,
|
|
context_latents=None,
|
|
attention_mask=None,
|
|
use_gradient_checkpointing=False,
|
|
use_gradient_checkpointing_offload=False,
|
|
**kwargs,
|
|
):
|
|
decoder_outputs = dit(
|
|
hidden_states=latents,
|
|
timestep=timestep,
|
|
timestep_r=timestep,
|
|
attention_mask=attention_mask,
|
|
encoder_hidden_states=encoder_hidden_states,
|
|
encoder_attention_mask=encoder_attention_mask,
|
|
context_latents=context_latents,
|
|
use_gradient_checkpointing=use_gradient_checkpointing,
|
|
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
|
)[0]
|
|
return decoder_outputs
|