mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-24 15:06:17 +00:00
ace-step train
This commit is contained in:
@@ -7,6 +7,7 @@ import re
|
||||
import torch
|
||||
from typing import Optional, Dict, Any, List, Tuple
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
|
||||
from ..core.device.npu_compatible_device import get_device_type
|
||||
from ..diffusion import FlowMatchScheduler
|
||||
@@ -89,13 +90,14 @@ class AceStepPipeline(BasePipeline):
|
||||
self,
|
||||
# Prompt
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
cfg_scale: float = 1.0,
|
||||
# Lyrics
|
||||
lyrics: str = "",
|
||||
# Task type
|
||||
task_type: Optional[str] = "text2music",
|
||||
# Reference audio
|
||||
reference_audios: List[torch.Tensor] = None,
|
||||
# Src audio
|
||||
# Source audio
|
||||
src_audio: torch.Tensor = None,
|
||||
denoising_strength: float = 1.0,
|
||||
# Audio codes
|
||||
@@ -126,6 +128,7 @@ class AceStepPipeline(BasePipeline):
|
||||
inputs_shared = {
|
||||
"cfg_scale": cfg_scale,
|
||||
"lyrics": lyrics,
|
||||
"task_type": task_type,
|
||||
"reference_audios": reference_audios,
|
||||
"src_audio": src_audio,
|
||||
"audio_code_string": audio_code_string,
|
||||
@@ -147,7 +150,7 @@ class AceStepPipeline(BasePipeline):
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.to(dtype=self.torch_dtype, device=self.device)
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
noise_pred = self.cfg_guided_model_fn(
|
||||
self.model_fn, cfg_scale,
|
||||
inputs_shared, inputs_posi, inputs_nega,
|
||||
@@ -182,13 +185,14 @@ class AceStepUnit_TaskTypeChecker(PipelineUnit):
|
||||
"""Check and compute sequence length from duration."""
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("src_audio", "audio_code_string"),
|
||||
input_params=("audio_code_string"),
|
||||
output_params=("task_type",),
|
||||
)
|
||||
|
||||
def process(self, pipe, src_audio, audio_code_string):
|
||||
def process(self, pipe, audio_code_string):
|
||||
if pipe.scheduler.training:
|
||||
return {"task_type": "text2music"}
|
||||
if audio_code_string is not None:
|
||||
print("audio_code_string detected, setting task_type to 'cover'")
|
||||
task_type = "cover"
|
||||
else:
|
||||
task_type = "text2music"
|
||||
@@ -200,7 +204,6 @@ class AceStepUnit_PromptEmbedder(PipelineUnit):
|
||||
INSTRUCTION_MAP = {
|
||||
"text2music": "Fill the audio semantic mask based on the given conditions:",
|
||||
"cover": "Generate audio semantic tokens based on the given conditions:",
|
||||
|
||||
"repaint": "Repaint the mask area based on the given conditions:",
|
||||
"extract": "Extract the {TRACK_NAME} track from the audio:",
|
||||
"extract_default": "Extract the track from the audio:",
|
||||
@@ -292,6 +295,7 @@ class AceStepUnit_ReferenceAudioEmbedder(PipelineUnit):
|
||||
def process(self, pipe, reference_audios):
|
||||
pipe.load_models_to_device(['vae'])
|
||||
if reference_audios is not None and len(reference_audios) > 0:
|
||||
raise NotImplementedError("Reference audio embedding is not implemented yet.")
|
||||
# TODO: implement reference audio embedding using VAE encode, and generate refer_audio_order_mask
|
||||
pass
|
||||
else:
|
||||
@@ -299,6 +303,49 @@ class AceStepUnit_ReferenceAudioEmbedder(PipelineUnit):
|
||||
reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, reference_audios)
|
||||
return {"reference_latents": reference_latents, "refer_audio_order_mask": refer_audio_order_mask}
|
||||
|
||||
# def process_reference_audio(self, reference_audios) -> Optional[torch.Tensor]:
|
||||
|
||||
# try:
|
||||
# audio_np, sr = _read_audio_file(audio_file)
|
||||
# audio = self._numpy_to_channels_first(audio_np)
|
||||
|
||||
# logger.debug(
|
||||
# f"[process_reference_audio] Reference audio shape: {audio.shape}"
|
||||
# )
|
||||
# logger.debug(f"[process_reference_audio] Reference audio sample rate: {sr}")
|
||||
# logger.debug(
|
||||
# f"[process_reference_audio] Reference audio duration: {audio.shape[-1] / sr:.6f} seconds"
|
||||
# )
|
||||
|
||||
# audio = self._normalize_audio_to_stereo_48k(audio, sr)
|
||||
# if self.is_silence(audio):
|
||||
# return None
|
||||
|
||||
# target_frames = 30 * 48000
|
||||
# segment_frames = 10 * 48000
|
||||
|
||||
# if audio.shape[-1] < target_frames:
|
||||
# repeat_times = math.ceil(target_frames / audio.shape[-1])
|
||||
# audio = audio.repeat(1, repeat_times)
|
||||
|
||||
# total_frames = audio.shape[-1]
|
||||
# segment_size = total_frames // 3
|
||||
|
||||
# front_start = random.randint(0, max(0, segment_size - segment_frames))
|
||||
# front_audio = audio[:, front_start : front_start + segment_frames]
|
||||
|
||||
# middle_start = segment_size + random.randint(
|
||||
# 0, max(0, segment_size - segment_frames)
|
||||
# )
|
||||
# middle_audio = audio[:, middle_start : middle_start + segment_frames]
|
||||
|
||||
# back_start = 2 * segment_size + random.randint(
|
||||
# 0, max(0, (total_frames - 2 * segment_size) - segment_frames)
|
||||
# )
|
||||
# back_audio = audio[:, back_start : back_start + segment_frames]
|
||||
|
||||
# return torch.cat([front_audio, middle_audio, back_audio], dim=-1)
|
||||
|
||||
def infer_refer_latent(self, pipe, refer_audioss: List[List[torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Infer packed reference-audio latents and order mask."""
|
||||
refer_audio_order_mask = []
|
||||
@@ -401,8 +448,8 @@ class AceStepUnit_ContextLatentBuilder(PipelineUnit):
|
||||
chunk_masks = torch.ones((1, max_latent_length, src_latents.shape[-1]), dtype=torch.bool, device=pipe.device)
|
||||
attention_mask = torch.ones((1, max_latent_length), device=src_latents.device, dtype=pipe.torch_dtype)
|
||||
context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
|
||||
elif src_audio is not None:
|
||||
raise NotImplementedError("src_audio conditioning is not implemented yet. Please set lm_hints to None.")
|
||||
# elif src_audio is not None:
|
||||
# raise NotImplementedError("src_audio conditioning is not implemented yet. Please set lm_hints to None.")
|
||||
else:
|
||||
max_latent_length = duration * pipe.sample_rate // 1920
|
||||
src_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
|
||||
@@ -435,8 +482,16 @@ class AceStepUnit_InputAudioEmbedder(PipelineUnit):
|
||||
def process(self, pipe, noise, input_audio):
|
||||
if input_audio is None:
|
||||
return {"latents": noise}
|
||||
# TODO: support for train
|
||||
return {"latents": noise, "input_latents": None}
|
||||
if pipe.scheduler.training:
|
||||
pipe.load_models_to_device(['vae'])
|
||||
input_audio, sample_rate = input_audio
|
||||
input_audio = torch.clamp(input_audio, -1.0, 1.0)
|
||||
if input_audio.dim() == 2:
|
||||
input_audio = input_audio.unsqueeze(0)
|
||||
input_latents = pipe.vae.encode(input_audio.to(dtype=pipe.torch_dtype, device=pipe.device)).transpose(1, 2)
|
||||
# prevent potential size mismatch between context_latents and input_latents by cropping input_latents to the same temporal length as noise
|
||||
input_latents = input_latents[:, :noise.shape[1]]
|
||||
return {"input_latents": input_latents}
|
||||
|
||||
|
||||
class AceStepUnit_AudioCodeDecoder(PipelineUnit):
|
||||
@@ -494,7 +549,6 @@ def model_fn_ace_step(
|
||||
use_gradient_checkpointing_offload=False,
|
||||
**kwargs,
|
||||
):
|
||||
timestep = timestep.unsqueeze(0)
|
||||
decoder_outputs = dit(
|
||||
hidden_states=latents,
|
||||
timestep=timestep,
|
||||
|
||||
Reference in New Issue
Block a user