mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-24 15:06:17 +00:00
noncover
This commit is contained in:
@@ -9,6 +9,8 @@ from typing import Optional, Dict, Any, List, Tuple
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
import math
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from ..core.device.npu_compatible_device import get_device_type
|
||||
from ..diffusion import FlowMatchScheduler
|
||||
@@ -41,11 +43,11 @@ class AceStepPipeline(BasePipeline):
|
||||
|
||||
self.in_iteration_models = ("dit",)
|
||||
self.units = [
|
||||
AceStepUnit_TaskTypeChecker(),
|
||||
AceStepUnit_PromptEmbedder(),
|
||||
AceStepUnit_ReferenceAudioEmbedder(),
|
||||
AceStepUnit_ConditionEmbedder(),
|
||||
AceStepUnit_AudioCodeDecoder(),
|
||||
AceStepUnit_ContextLatentBuilder(),
|
||||
AceStepUnit_ConditionEmbedder(),
|
||||
AceStepUnit_NoiseInitializer(),
|
||||
AceStepUnit_InputAudioEmbedder(),
|
||||
]
|
||||
@@ -100,7 +102,8 @@ class AceStepPipeline(BasePipeline):
|
||||
reference_audios: List[torch.Tensor] = None,
|
||||
# Source audio
|
||||
src_audio: torch.Tensor = None,
|
||||
denoising_strength: float = 1.0,
|
||||
denoising_strength: float = 1.0, # denoising_strength = 1 - cover_noise_strength
|
||||
audio_cover_strength: float = 1.0,
|
||||
# Audio codes
|
||||
audio_code_string: Optional[str] = None,
|
||||
# Shape
|
||||
@@ -121,7 +124,7 @@ class AceStepPipeline(BasePipeline):
|
||||
progress_bar_cmd=tqdm,
|
||||
):
|
||||
# 1. Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, denoising_strength=1.0, shift=shift)
|
||||
self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, denoising_strength=denoising_strength, shift=shift)
|
||||
|
||||
# 2. 三字典输入
|
||||
inputs_posi = {"prompt": prompt, "positive": True}
|
||||
@@ -132,6 +135,7 @@ class AceStepPipeline(BasePipeline):
|
||||
"task_type": task_type,
|
||||
"reference_audios": reference_audios,
|
||||
"src_audio": src_audio,
|
||||
"audio_cover_strength": audio_cover_strength,
|
||||
"audio_code_string": audio_code_string,
|
||||
"duration": duration,
|
||||
"bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "vocal_language": vocal_language,
|
||||
@@ -152,6 +156,7 @@ class AceStepPipeline(BasePipeline):
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
self.switch_noncover_condition(inputs_shared, inputs_posi, inputs_nega, progress_id)
|
||||
noise_pred = self.cfg_guided_model_fn(
|
||||
self.model_fn, cfg_scale,
|
||||
inputs_shared, inputs_posi, inputs_nega,
|
||||
@@ -181,23 +186,28 @@ class AceStepPipeline(BasePipeline):
|
||||
gain = target_amp / peak
|
||||
return audio * gain
|
||||
|
||||
def switch_noncover_condition(self, inputs_shared, inputs_posi, inputs_nega, progress_id):
|
||||
if inputs_shared["task_type"] != "cover" or inputs_shared["audio_cover_strength"] >= 1.0 or inputs_shared.get("shared_noncover", None) is None:
|
||||
return
|
||||
cover_steps = int(len(self.scheduler.timesteps) * inputs_shared["audio_cover_strength"])
|
||||
if progress_id >= cover_steps:
|
||||
inputs_shared.update(inputs_shared.pop("shared_noncover", {}))
|
||||
inputs_posi.update(inputs_shared.pop("posi_noncover", {}))
|
||||
if inputs_shared["cfg_scale"] != 1.0:
|
||||
inputs_nega.update(inputs_shared.pop("nega_noncover", {}))
|
||||
|
||||
|
||||
class AceStepUnit_TaskTypeChecker(PipelineUnit):
|
||||
"""Check and compute sequence length from duration."""
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("audio_code_string"),
|
||||
input_params=("task_type",),
|
||||
output_params=("task_type",),
|
||||
)
|
||||
|
||||
def process(self, pipe, audio_code_string):
|
||||
if pipe.scheduler.training:
|
||||
return {"task_type": "text2music"}
|
||||
if audio_code_string is not None:
|
||||
task_type = "cover"
|
||||
else:
|
||||
task_type = "text2music"
|
||||
return {"task_type": task_type}
|
||||
def process(self, pipe, task_type):
|
||||
assert task_type in ["text2music", "cover", "repaint"], f"Unsupported task_type: {task_type}"
|
||||
return {}
|
||||
|
||||
|
||||
class AceStepUnit_PromptEmbedder(PipelineUnit):
|
||||
@@ -364,14 +374,34 @@ class AceStepUnit_ConditionEmbedder(PipelineUnit):
|
||||
if inputs_shared["cfg_scale"] != 1.0:
|
||||
inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states).to(dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device)
|
||||
inputs_nega["encoder_attention_mask"] = encoder_attention_mask
|
||||
if inputs_shared["task_type"] == "cover" and inputs_shared["audio_cover_strength"] < 1.0:
|
||||
hidden_states_noncover = AceStepUnit_PromptEmbedder().process(
|
||||
pipe, inputs_posi["prompt"], True, inputs_shared["lyrics"], inputs_shared["duration"],
|
||||
inputs_shared["bpm"], inputs_shared["keyscale"], inputs_shared["timesignature"],
|
||||
inputs_shared["vocal_language"], "text2music")
|
||||
encoder_hidden_states_noncover, encoder_attention_mask_noncover = pipe.conditioner(
|
||||
**hidden_states_noncover,
|
||||
reference_latents=inputs_shared.get("reference_latents", None),
|
||||
refer_audio_order_mask=inputs_shared.get("refer_audio_order_mask", None),
|
||||
)
|
||||
duration = inputs_shared["context_latents"].shape[1] * 1920 / pipe.vae.sampling_rate
|
||||
context_latents_noncover = AceStepUnit_ContextLatentBuilder().process(pipe, duration, None, None)["context_latents"]
|
||||
inputs_shared["shared_noncover"] = {"context_latents": context_latents_noncover}
|
||||
inputs_shared["posi_noncover"] = {"encoder_hidden_states": encoder_hidden_states_noncover, "encoder_attention_mask": encoder_attention_mask_noncover}
|
||||
if inputs_shared["cfg_scale"] != 1.0:
|
||||
inputs_shared["nega_noncover"] = {
|
||||
"encoder_hidden_states": pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states_noncover).to(dtype=encoder_hidden_states_noncover.dtype, device=encoder_hidden_states_noncover.device),
|
||||
"encoder_attention_mask": encoder_attention_mask_noncover,
|
||||
}
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
class AceStepUnit_ContextLatentBuilder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("duration", "src_audio", "lm_hints"),
|
||||
input_params=("duration", "src_audio", "audio_code_string"),
|
||||
output_params=("context_latents", "src_latents", "chunk_masks", "attention_mask"),
|
||||
onload_model_names=("vae", "tokenizer_model",),
|
||||
)
|
||||
|
||||
def _get_silence_latent_slice(self, pipe, length: int) -> torch.Tensor:
|
||||
@@ -382,66 +412,13 @@ class AceStepUnit_ContextLatentBuilder(PipelineUnit):
|
||||
tiled = pipe.silence_latent[0].repeat(repeats, 1)
|
||||
return tiled[:length, :]
|
||||
|
||||
def process(self, pipe, duration, src_audio, lm_hints):
|
||||
if lm_hints is not None:
|
||||
max_latent_length = lm_hints.shape[1]
|
||||
src_latents = lm_hints.clone()
|
||||
chunk_masks = torch.ones((1, max_latent_length, src_latents.shape[-1]), dtype=torch.bool, device=pipe.device)
|
||||
attention_mask = torch.ones((1, max_latent_length), device=src_latents.device, dtype=pipe.torch_dtype)
|
||||
context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
|
||||
# elif src_audio is not None:
|
||||
# raise NotImplementedError("src_audio conditioning is not implemented yet. Please set lm_hints to None.")
|
||||
else:
|
||||
max_latent_length = duration * pipe.sample_rate // 1920
|
||||
src_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
|
||||
chunk_masks = torch.ones((1, max_latent_length, src_latents.shape[-1]), dtype=torch.bool, device=pipe.device)
|
||||
attention_mask = torch.ones((1, max_latent_length), device=src_latents.device, dtype=pipe.torch_dtype)
|
||||
context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
|
||||
return {"context_latents": context_latents, "attention_mask": attention_mask}
|
||||
|
||||
|
||||
class AceStepUnit_NoiseInitializer(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("context_latents", "seed", "rand_device"),
|
||||
output_params=("noise",),
|
||||
)
|
||||
|
||||
def process(self, pipe, context_latents, seed, rand_device):
|
||||
src_latents_shape = (context_latents.shape[0], context_latents.shape[1], context_latents.shape[-1] // 2)
|
||||
noise = pipe.generate_noise(src_latents_shape, seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
|
||||
return {"noise": noise}
|
||||
|
||||
|
||||
class AceStepUnit_InputAudioEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("noise", "input_audio"),
|
||||
output_params=("latents", "input_latents"),
|
||||
)
|
||||
|
||||
def process(self, pipe, noise, input_audio):
|
||||
if input_audio is None:
|
||||
return {"latents": noise}
|
||||
if pipe.scheduler.training:
|
||||
pipe.load_models_to_device(['vae'])
|
||||
input_audio, sample_rate = input_audio
|
||||
input_audio = torch.clamp(input_audio, -1.0, 1.0)
|
||||
if input_audio.dim() == 2:
|
||||
input_audio = input_audio.unsqueeze(0)
|
||||
input_latents = pipe.vae.encode(input_audio.to(dtype=pipe.torch_dtype, device=pipe.device)).transpose(1, 2)
|
||||
# prevent potential size mismatch between context_latents and input_latents by cropping input_latents to the same temporal length as noise
|
||||
input_latents = input_latents[:, :noise.shape[1]]
|
||||
return {"input_latents": input_latents}
|
||||
|
||||
|
||||
class AceStepUnit_AudioCodeDecoder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("audio_code_string",),
|
||||
output_params=("lm_hints",),
|
||||
onload_model_names=("tokenizer_model",),
|
||||
)
|
||||
def tokenize(self, tokenizer, x, silence_latent, pool_window_size):
|
||||
if x.shape[1] % pool_window_size != 0:
|
||||
pad_len = pool_window_size - (x.shape[1] % pool_window_size)
|
||||
x = torch.cat([x, silence_latent[:1,:pad_len].repeat(x.shape[0],1,1)], dim=1)
|
||||
x = rearrange(x, 'n (t_patch p) d -> n t_patch p d', p=pool_window_size)
|
||||
quantized, indices = tokenizer(x)
|
||||
return quantized
|
||||
|
||||
@staticmethod
|
||||
def _parse_audio_code_string(code_str: str) -> list:
|
||||
@@ -458,24 +435,72 @@ class AceStepUnit_AudioCodeDecoder(PipelineUnit):
|
||||
raise ValueError(f"Invalid audio_code_string format: {e}")
|
||||
return codes
|
||||
|
||||
def process(self, pipe, audio_code_string):
|
||||
if audio_code_string is None or not audio_code_string.strip():
|
||||
return {"lm_hints": None}
|
||||
code_ids = self._parse_audio_code_string(audio_code_string)
|
||||
if len(code_ids) == 0:
|
||||
return {"lm_hints": None}
|
||||
def process(self, pipe, duration, src_audio, audio_code_string):
|
||||
# get src_latents from audio_code_string > src_audio > silence
|
||||
if audio_code_string is not None:
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
code_ids = self._parse_audio_code_string(audio_code_string)
|
||||
quantizer = pipe.tokenizer_model.tokenizer.quantizer
|
||||
indices = torch.tensor(code_ids, device=quantizer.codebooks.device, dtype=torch.long).unsqueeze(0).unsqueeze(-1)
|
||||
codes = quantizer.get_codes_from_indices(indices)
|
||||
quantized = codes.sum(dim=0).to(pipe.torch_dtype).to(pipe.device)
|
||||
quantized = quantizer.project_out(quantized)
|
||||
src_latents = pipe.tokenizer_model.detokenizer(quantized).to(pipe.device)
|
||||
max_latent_length = src_latents.shape[1]
|
||||
elif src_audio is not None:
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
src_audio = src_audio.unsqueeze(0) if src_audio.dim() == 2 else src_audio
|
||||
src_audio = torch.clamp(src_audio, -1.0, 1.0)
|
||||
src_latents = pipe.vae.encode(src_audio.to(dtype=pipe.torch_dtype, device=pipe.device)).transpose(1, 2)
|
||||
lm_hints_5Hz = self.tokenize(pipe.tokenizer_model.tokenizer, src_latents, pipe.silence_latent, pipe.tokenizer_model.tokenizer.pool_window_size)
|
||||
src_latents = pipe.tokenizer_model.detokenizer(lm_hints_5Hz)
|
||||
max_latent_length = src_latents.shape[1]
|
||||
else:
|
||||
max_latent_length = int(duration * pipe.sample_rate // 1920)
|
||||
src_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
|
||||
chunk_masks = torch.ones((1, max_latent_length, src_latents.shape[-1]), dtype=torch.bool, device=pipe.device)
|
||||
attention_mask = torch.ones((1, max_latent_length), device=src_latents.device, dtype=pipe.torch_dtype)
|
||||
context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
|
||||
return {"context_latents": context_latents, "attention_mask": attention_mask}
|
||||
|
||||
pipe.load_models_to_device(["tokenizer_model"])
|
||||
quantizer = pipe.tokenizer_model.tokenizer.quantizer
|
||||
detokenizer = pipe.tokenizer_model.detokenizer
|
||||
|
||||
indices = torch.tensor(code_ids, device=quantizer.codebooks.device, dtype=torch.long).unsqueeze(0).unsqueeze(-1)
|
||||
codes = quantizer.get_codes_from_indices(indices)
|
||||
quantized = codes.sum(dim=0).to(pipe.torch_dtype).to(pipe.device)
|
||||
quantized = quantizer.project_out(quantized)
|
||||
class AceStepUnit_NoiseInitializer(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("context_latents", "seed", "rand_device"),
|
||||
output_params=("noise",),
|
||||
)
|
||||
|
||||
def process(self, pipe, context_latents, seed, rand_device):
|
||||
src_latents_shape = (context_latents.shape[0], context_latents.shape[1], context_latents.shape[-1] // 2)
|
||||
noise = pipe.generate_noise(src_latents_shape, seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
|
||||
noise = pipe.scheduler.add_noise(context_latents[:, :, :src_latents_shape[-1]], noise, timestep=pipe.scheduler.timesteps[0])
|
||||
return {"noise": noise}
|
||||
|
||||
|
||||
class AceStepUnit_InputAudioEmbedder(PipelineUnit):
|
||||
"""Only for training."""
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("noise", "input_audio"),
|
||||
output_params=("latents", "input_latents"),
|
||||
onload_model_names=("vae",),
|
||||
)
|
||||
|
||||
def process(self, pipe, noise, input_audio):
|
||||
if input_audio is None:
|
||||
return {"latents": noise}
|
||||
if pipe.scheduler.training:
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
input_audio, sample_rate = input_audio
|
||||
input_audio = torch.clamp(input_audio, -1.0, 1.0)
|
||||
if input_audio.dim() == 2:
|
||||
input_audio = input_audio.unsqueeze(0)
|
||||
input_latents = pipe.vae.encode(input_audio.to(dtype=pipe.torch_dtype, device=pipe.device)).transpose(1, 2)
|
||||
# prevent potential size mismatch between context_latents and input_latents by cropping input_latents to the same temporal length as noise
|
||||
input_latents = input_latents[:, :noise.shape[1]]
|
||||
return {"input_latents": input_latents}
|
||||
|
||||
lm_hints = detokenizer(quantized).to(pipe.device)
|
||||
return {"lm_hints": lm_hints}
|
||||
|
||||
|
||||
def model_fn_ace_step(
|
||||
|
||||
Reference in New Issue
Block a user