Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-04-24 15:06:17 +00:00)

Commit: codes
@@ -106,6 +106,9 @@ class AceStepPipeline(BasePipeline):
         audio_cover_strength: float = 1.0,
         # Audio codes
         audio_code_string: Optional[str] = None,
+        # Inpainting
+        repainting_ranges: Optional[List[Tuple[float, float]]] = None,
+        repainting_strength: float = 1.0,
         # Shape
         duration: int = 60,
         # Audio Meta
@@ -134,9 +137,8 @@ class AceStepPipeline(BasePipeline):
             "lyrics": lyrics,
             "task_type": task_type,
             "reference_audios": reference_audios,
-            "src_audio": src_audio,
-            "audio_cover_strength": audio_cover_strength,
-            "audio_code_string": audio_code_string,
+            "src_audio": src_audio, "audio_cover_strength": audio_cover_strength, "audio_code_string": audio_code_string,
+            "repainting_ranges": repainting_ranges, "repainting_strength": repainting_strength,
             "duration": duration,
             "bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "vocal_language": vocal_language,
             "seed": seed,
@@ -162,9 +164,8 @@ class AceStepPipeline(BasePipeline):
                 inputs_shared, inputs_posi, inputs_nega,
                 **models, timestep=timestep, progress_id=progress_id,
             )
-            inputs_shared["latents"] = self.step(
-                self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared
-            )
+            inputs_shared["latents"] = self.step(self.scheduler, inpaint_mask=inputs_shared.get("denoise_mask", None), input_latents=inputs_shared.get("src_latents", None),
+                                                 progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)

         # 5. VAE decode
         self.load_models_to_device(['vae'])
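The reworked step threads two optional tensors through the scheduler: denoise_mask marks which latent frames may be regenerated, and src_latents carries the encoded source audio to be preserved elsewhere. A minimal, self-contained sketch of the blending idea behind such an inpaint-masked update (an assumption for illustration, not the repo's actual scheduler code):

import torch

def masked_update(denoised: torch.Tensor,
                  src_latents: torch.Tensor,
                  denoise_mask: torch.Tensor) -> torch.Tensor:
    # Repainted regions (mask == 1) come from the denoised latents;
    # preserved regions (mask == 0) are copied back from the source.
    return denoise_mask * denoised + (1.0 - denoise_mask) * src_latents

denoised = torch.randn(1, 1500, 64)   # (batch, latent_frames, channels)
src = torch.randn(1, 1500, 64)
mask = torch.zeros(1, 1500, 1)
mask[:, 250:500, :] = 1.0             # repaint 10 s - 20 s at 25 frames/s
print(masked_update(denoised, src, mask).shape)  # torch.Size([1, 1500, 64])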
@@ -201,12 +202,17 @@ class AceStepUnit_TaskTypeChecker(PipelineUnit):
     """Check and compute sequence length from duration."""
     def __init__(self):
         super().__init__(
-            input_params=("task_type",),
+            input_params=("task_type", "src_audio", "repainting_ranges", "audio_code_string"),
             output_params=("task_type",),
         )

-    def process(self, pipe, task_type):
+    def process(self, pipe, task_type, src_audio, repainting_ranges, audio_code_string):
         assert task_type in ["text2music", "cover", "repaint"], f"Unsupported task_type: {task_type}"
+        if task_type == "cover":
+            assert (src_audio is not None) or (audio_code_string is not None), "For cover task, either src_audio or audio_code_string must be provided."
+        elif task_type == "repaint":
+            assert src_audio is not None, "For repaint task, src_audio must be provided."
+            assert repainting_ranges is not None and len(repainting_ranges) > 0, "For repaint task, repainting_ranges must be provided and non-empty."
         return {}

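The checker only enforces input preconditions; it returns an empty dict and mutates nothing. A self-contained sketch of that contract as a plain function (check_task_inputs is a hypothetical name, not part of the repo):

from typing import List, Optional, Tuple

def check_task_inputs(task_type: str,
                      src_audio: Optional[object] = None,
                      repainting_ranges: Optional[List[Tuple[float, float]]] = None,
                      audio_code_string: Optional[str] = None) -> None:
    assert task_type in ("text2music", "cover", "repaint"), f"Unsupported task_type: {task_type}"
    if task_type == "cover":
        # Cover accepts either a waveform or a pre-tokenized code string.
        assert src_audio is not None or audio_code_string is not None
    elif task_type == "repaint":
        # Repaint needs the source audio plus at least one (start_s, end_s) range.
        assert src_audio is not None
        assert repainting_ranges is not None and len(repainting_ranges) > 0

check_task_inputs("text2music")                        # ok: needs nothing extra
check_task_inputs("cover", audio_code_string="1,2,3")  # ok: codes stand in for audio
check_task_inputs("repaint", src_audio=object(), repainting_ranges=[(10.0, 20.0)])  # ok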
@@ -399,7 +405,7 @@ class AceStepUnit_ConditionEmbedder(PipelineUnit):
 class AceStepUnit_ContextLatentBuilder(PipelineUnit):
     def __init__(self):
         super().__init__(
-            input_params=("duration", "src_audio", "audio_code_string"),
+            input_params=("duration", "src_audio", "audio_code_string", "task_type", "repainting_ranges", "repainting_strength"),
             output_params=("context_latents", "src_latents", "chunk_masks", "attention_mask"),
             onload_model_names=("vae", "tokenizer_model",),
         )
@@ -435,9 +441,46 @@ class AceStepUnit_ContextLatentBuilder(PipelineUnit):
             raise ValueError(f"Invalid audio_code_string format: {e}")
         return codes

-    def process(self, pipe, duration, src_audio, audio_code_string):
+    def pad_src_audio(self, pipe, src_audio, task_type, repainting_ranges):
+        if task_type != "repaint" or repainting_ranges is None:
+            return src_audio, repainting_ranges, None, None
+        min_left = min([start for start, end in repainting_ranges])
+        max_right = max([end for start, end in repainting_ranges])
+        total_length = src_audio.shape[-1] // pipe.vae.sampling_rate
+        pad_left = max(0, -min_left)
+        pad_right = max(0, max_right - total_length)
+        if pad_left > 0 or pad_right > 0:
+            padding_frames_left, padding_frames_right = pad_left * pipe.vae.sampling_rate, pad_right * pipe.vae.sampling_rate
+            src_audio = F.pad(src_audio, (padding_frames_left, padding_frames_right), value=0.0)
+            repainting_ranges = [(start + pad_left, end + pad_left) for start, end in repainting_ranges]
+        return src_audio, repainting_ranges, pad_left, pad_right
+
+    def parse_repaint_masks(self, pipe, src_latents, task_type, repainting_ranges, repainting_strength, pad_left, pad_right):
+        if task_type != "repaint" or repainting_ranges is None:
+            return None, src_latents
+        # Let the repainting area be repainting_strength, the non-repainting area be 0.0, and blend at the boundary with cf_frames.
+        max_latent_length = src_latents.shape[1]
+        denoise_mask = torch.zeros((1, max_latent_length, 1), dtype=pipe.torch_dtype, device=pipe.device)
+        for start, end in repainting_ranges:
+            start_frame = start * pipe.vae.sampling_rate // 1920
+            end_frame = end * pipe.vae.sampling_rate // 1920
+            denoise_mask[:, start_frame:end_frame, :] = repainting_strength
+        # Set padding areas to 1.0 (full repaint) to avoid artifacts at the boundaries caused by padding.
+        pad_left_frames = pad_left * pipe.vae.sampling_rate // 1920
+        pad_right_frames = pad_right * pipe.vae.sampling_rate // 1920
+        denoise_mask[:, :pad_left_frames, :] = 1
+        denoise_mask[:, max_latent_length - pad_right_frames:, :] = 1
+
+        silent_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
+        src_latents = src_latents * (1 - denoise_mask) + silent_latents * denoise_mask
+        return denoise_mask, src_latents
+
+    def process(self, pipe, duration, src_audio, audio_code_string, task_type=None, repainting_ranges=None, repainting_strength=None):
+        # get src_latents from audio_code_string > src_audio > silence
+        source_latents = None
+        denoise_mask = None
         if audio_code_string is not None:
             # use audio_code_string to get src_latents.
             pipe.load_models_to_device(self.onload_model_names)
             code_ids = self._parse_audio_code_string(audio_code_string)
             quantizer = pipe.tokenizer_model.tokenizer.quantizer
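Both helpers rely on the same unit conversion: times in seconds become latent frames via sampling_rate // 1920, i.e. one latent frame per 1920 audio samples. A runnable sketch of just the mask math, assuming a 48 kHz VAE (25 latent frames per second; the real rate comes from pipe.vae.sampling_rate):

import torch

SAMPLING_RATE = 48_000  # assumption; the pipeline reads this from the VAE
HOP = 1920              # audio samples per latent frame, as in the code above

def build_denoise_mask(total_seconds, ranges_s, strength):
    n_frames = int(total_seconds * SAMPLING_RATE // HOP)
    mask = torch.zeros(1, n_frames, 1)
    for start_s, end_s in ranges_s:
        start_f = int(start_s * SAMPLING_RATE // HOP)
        end_f = int(end_s * SAMPLING_RATE // HOP)
        mask[:, start_f:end_f, :] = strength  # repaint strength inside the range
    return mask

mask = build_denoise_mask(60.0, [(10.0, 20.0)], strength=1.0)
print(mask.shape, int(mask.sum()))  # torch.Size([1, 1500, 1]) 250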
@@ -448,33 +491,42 @@ class AceStepUnit_ContextLatentBuilder(PipelineUnit):
             src_latents = pipe.tokenizer_model.detokenizer(quantized).to(pipe.device)
             max_latent_length = src_latents.shape[1]
         elif src_audio is not None:
             # use src_audio to get src_latents.
             pipe.load_models_to_device(self.onload_model_names)
             src_audio = src_audio.unsqueeze(0) if src_audio.dim() == 2 else src_audio
             src_audio = torch.clamp(src_audio, -1.0, 1.0)

+            src_audio, repainting_ranges, pad_left, pad_right = self.pad_src_audio(pipe, src_audio, task_type, repainting_ranges)
+
             src_latents = pipe.vae.encode(src_audio.to(dtype=pipe.torch_dtype, device=pipe.device)).transpose(1, 2)
-            lm_hints_5Hz = self.tokenize(pipe.tokenizer_model.tokenizer, src_latents, pipe.silence_latent, pipe.tokenizer_model.tokenizer.pool_window_size)
-            src_latents = pipe.tokenizer_model.detokenizer(lm_hints_5Hz)
+            source_latents = src_latents  # cache for potential use in audio inpainting tasks
+            denoise_mask, src_latents = self.parse_repaint_masks(pipe, src_latents, task_type, repainting_ranges, repainting_strength, pad_left, pad_right)
+            if task_type == "cover":
+                lm_hints_5Hz = self.tokenize(pipe.tokenizer_model.tokenizer, src_latents, pipe.silence_latent, pipe.tokenizer_model.tokenizer.pool_window_size)
+                src_latents = pipe.tokenizer_model.detokenizer(lm_hints_5Hz)
             max_latent_length = src_latents.shape[1]
         else:
             # use silence latents.
             max_latent_length = int(duration * pipe.sample_rate // 1920)
             src_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
         chunk_masks = torch.ones((1, max_latent_length, src_latents.shape[-1]), dtype=torch.bool, device=pipe.device)
         attention_mask = torch.ones((1, max_latent_length), device=src_latents.device, dtype=pipe.torch_dtype)
         context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
-        return {"context_latents": context_latents, "attention_mask": attention_mask}
+        return {"context_latents": context_latents, "attention_mask": attention_mask, "src_latents": source_latents, "denoise_mask": denoise_mask}


 class AceStepUnit_NoiseInitializer(PipelineUnit):
     def __init__(self):
         super().__init__(
-            input_params=("context_latents", "seed", "rand_device"),
+            input_params=("context_latents", "seed", "rand_device", "src_latents"),
             output_params=("noise",),
         )

-    def process(self, pipe, context_latents, seed, rand_device):
+    def process(self, pipe, context_latents, seed, rand_device, src_latents):
         src_latents_shape = (context_latents.shape[0], context_latents.shape[1], context_latents.shape[-1] // 2)
         noise = pipe.generate_noise(src_latents_shape, seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
-        noise = pipe.scheduler.add_noise(context_latents[:, :, :src_latents_shape[-1]], noise, timestep=pipe.scheduler.timesteps[0])
+        if src_latents is not None:
+            noise = pipe.scheduler.add_noise(src_latents, noise, timestep=pipe.scheduler.timesteps[0])
         return {"noise": noise}

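With this change, text2music starts from pure noise, while cover/repaint start from source latents noised to the first timestep. A sketch of the idea under the assumption of a rectified-flow style add_noise, x_t = (1 - t) * x_0 + t * noise (the repo's scheduler may implement the schedule differently):

import torch

def add_noise(x0: torch.Tensor, noise: torch.Tensor, t: float) -> torch.Tensor:
    # t = 1.0 yields pure noise; t < 1.0 lets the source latents leak
    # through, which is what partial-strength cover/repaint rely on.
    return (1.0 - t) * x0 + t * noise

src_latents = torch.randn(1, 1500, 64)
noise = torch.randn_like(src_latents)
x_start_full = add_noise(src_latents, noise, t=1.0)   # full regeneration
x_start_part = add_noise(src_latents, noise, t=0.7)   # keeps some source signal
print(x_start_full.shape, x_start_part.shape)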
@@ -502,7 +554,6 @@ class AceStepUnit_InputAudioEmbedder(PipelineUnit):
         return {"input_latents": input_latents}

-


 def model_fn_ace_step(
     dit: AceStepDiTModel,
     latents=None,