mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-24 15:06:17 +00:00
t2m
This commit is contained in:
@@ -42,10 +42,10 @@ class AceStepPipeline(BasePipeline):
|
||||
AceStepUnit_PromptEmbedder(),
|
||||
AceStepUnit_ReferenceAudioEmbedder(),
|
||||
AceStepUnit_ConditionEmbedder(),
|
||||
AceStepUnit_AudioCodeDecoder(),
|
||||
AceStepUnit_ContextLatentBuilder(),
|
||||
AceStepUnit_NoiseInitializer(),
|
||||
AceStepUnit_InputAudioEmbedder(),
|
||||
AceStepUnit_AudioCodeDecoder(),
|
||||
]
|
||||
self.model_fn = model_fn_ace_step
|
||||
self.compilable_models = ["dit"]
|
||||
@@ -92,27 +92,27 @@ class AceStepPipeline(BasePipeline):
|
||||
cfg_scale: float = 1.0,
|
||||
# Lyrics
|
||||
lyrics: str = "",
|
||||
# Reference audio (optional, for timbre conditioning)
|
||||
# Reference audio
|
||||
reference_audios: List[torch.Tensor] = None,
|
||||
# Src audio
|
||||
src_audio: torch.Tensor = None,
|
||||
denoising_strength: float = 1.0,
|
||||
# Simple Mode: LLM-generated audio codes (optional)
|
||||
audio_codes: str = None,
|
||||
# Audio codes
|
||||
audio_code_string: Optional[str] = None,
|
||||
# Shape
|
||||
duration: int = 60,
|
||||
# Audio Meta
|
||||
bpm: Optional[int] = 100,
|
||||
keyscale: Optional[str] = "B minor",
|
||||
timesignature: Optional[str] = "4",
|
||||
vocal_language: Optional[str] = 'zh',
|
||||
vocal_language: Optional[str] = 'unknown',
|
||||
# Randomness
|
||||
seed: int = None,
|
||||
rand_device: str = "cpu",
|
||||
# Steps
|
||||
num_inference_steps: int = 8,
|
||||
# Scheduler-specific parameters
|
||||
shift: float = 3.0,
|
||||
shift: float = 1.0,
|
||||
# Progress
|
||||
progress_bar_cmd=tqdm,
|
||||
):
|
||||
@@ -120,14 +120,14 @@ class AceStepPipeline(BasePipeline):
|
||||
self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, denoising_strength=1.0, shift=shift)
|
||||
|
||||
# 2. 三字典输入
|
||||
inputs_posi = {"prompt": prompt}
|
||||
inputs_nega = {"negative_prompt": negative_prompt}
|
||||
inputs_posi = {"prompt": prompt, "positive": True}
|
||||
inputs_nega = {"positive": False}
|
||||
inputs_shared = {
|
||||
"cfg_scale": cfg_scale,
|
||||
"lyrics": lyrics,
|
||||
"reference_audios": reference_audios,
|
||||
"src_audio": src_audio,
|
||||
"audio_codes": audio_codes,
|
||||
"audio_code_string": audio_code_string,
|
||||
"duration": duration,
|
||||
"bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "vocal_language": vocal_language,
|
||||
"seed": seed,
|
||||
@@ -145,12 +145,13 @@ class AceStepPipeline(BasePipeline):
|
||||
# 4. Denoise loop
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
self.momentum_buffer = MomentumBuffer()
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.to(dtype=self.torch_dtype, device=self.device)
|
||||
noise_pred = self.cfg_guided_model_fn(
|
||||
self.model_fn, cfg_scale,
|
||||
inputs_shared, inputs_posi, inputs_nega,
|
||||
**models, timestep=timestep, progress_id=progress_id
|
||||
**models, timestep=timestep, progress_id=progress_id,
|
||||
)
|
||||
inputs_shared["latents"] = self.step(
|
||||
self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared
|
||||
@@ -163,39 +164,12 @@ class AceStepPipeline(BasePipeline):
|
||||
vae_output = self.vae.decode(latents)
|
||||
# VAE returns OobleckDecoderOutput with .sample attribute
|
||||
audio_output = vae_output.sample if hasattr(vae_output, 'sample') else vae_output
|
||||
audio_output = self.normalize_audio(audio_output, target_db=-1.0)
|
||||
audio = self.output_audio_format_check(audio_output)
|
||||
|
||||
# Peak normalization to match target library behavior
|
||||
audio = self.normalize_audio(audio, target_db=-1.0)
|
||||
|
||||
self.load_models_to_device([])
|
||||
return audio
|
||||
|
||||
def output_audio_format_check(self, audio_output):
|
||||
"""Convert VAE output to standard audio format [C, T], float32.
|
||||
|
||||
VAE decode outputs [B, C, T] (audio waveform).
|
||||
We squeeze batch dim and return [C, T].
|
||||
"""
|
||||
if audio_output.ndim == 3:
|
||||
audio_output = audio_output.squeeze(0)
|
||||
return audio_output.float()
|
||||
|
||||
def normalize_audio(self, audio: torch.Tensor, target_db: float = -1.0) -> torch.Tensor:
|
||||
"""Apply peak normalization to audio data, matching target library behavior.
|
||||
|
||||
Target library reference: `acestep/audio_utils.py:normalize_audio()`
|
||||
peak = max(abs(audio))
|
||||
gain = 10^(target_db/20) / peak
|
||||
audio = audio * gain
|
||||
|
||||
Args:
|
||||
audio: Audio tensor [C, T]
|
||||
target_db: Target peak level in dB (default: -1.0)
|
||||
|
||||
Returns:
|
||||
Normalized audio tensor
|
||||
"""
|
||||
peak = torch.max(torch.abs(audio))
|
||||
if peak < 1e-6:
|
||||
return audio
|
||||
@@ -203,17 +177,46 @@ class AceStepPipeline(BasePipeline):
|
||||
gain = target_amp / peak
|
||||
return audio * gain
|
||||
|
||||
|
||||
class AceStepUnit_TaskTypeChecker(PipelineUnit):
|
||||
"""Check and compute sequence length from duration."""
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("src_audio", "audio_code_string"),
|
||||
output_params=("task_type",),
|
||||
)
|
||||
|
||||
def process(self, pipe, src_audio, audio_code_string):
|
||||
if audio_code_string is not None:
|
||||
print("audio_code_string detected, setting task_type to 'cover'")
|
||||
task_type = "cover"
|
||||
else:
|
||||
task_type = "text2music"
|
||||
return {"task_type": task_type}
|
||||
|
||||
|
||||
class AceStepUnit_PromptEmbedder(PipelineUnit):
|
||||
SFT_GEN_PROMPT = "# Instruction\n{}\n\n# Caption\n{}\n\n# Metas\n{}<|endoftext|>\n"
|
||||
INSTRUCTION = "Fill the audio semantic mask based on the given conditions:"
|
||||
INSTRUCTION_MAP = {
|
||||
"text2music": "Fill the audio semantic mask based on the given conditions:",
|
||||
"cover": "Generate audio semantic tokens based on the given conditions:",
|
||||
|
||||
"repaint": "Repaint the mask area based on the given conditions:",
|
||||
"extract": "Extract the {TRACK_NAME} track from the audio:",
|
||||
"extract_default": "Extract the track from the audio:",
|
||||
"lego": "Generate the {TRACK_NAME} track based on the audio context:",
|
||||
"lego_default": "Generate the track based on the audio context:",
|
||||
"complete": "Complete the input track with {TRACK_CLASSES}:",
|
||||
"complete_default": "Complete the input track:",
|
||||
}
|
||||
LYRIC_PROMPT = "# Languages\n{}\n\n# Lyric\n{}<|endoftext|>"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params_posi={"prompt": "prompt"},
|
||||
input_params_nega={"prompt": "prompt"},
|
||||
input_params=("lyrics", "duration", "bpm", "keyscale", "timesignature", "vocal_language"),
|
||||
input_params_posi={"prompt": "prompt", "positive": "positive"},
|
||||
input_params_nega={"prompt": "prompt", "positive": "positive"},
|
||||
input_params=("lyrics", "duration", "bpm", "keyscale", "timesignature", "vocal_language", "task_type"),
|
||||
output_params=("text_hidden_states", "text_attention_mask", "lyric_hidden_states", "lyric_attention_mask"),
|
||||
onload_model_names=("text_encoder",)
|
||||
)
|
||||
@@ -256,10 +259,13 @@ class AceStepUnit_PromptEmbedder(PipelineUnit):
|
||||
f"- duration: {duration}\n"
|
||||
)
|
||||
|
||||
def process(self, pipe, prompt, lyrics, duration, bpm, keyscale, timesignature, vocal_language):
|
||||
def process(self, pipe, prompt, positive, lyrics, duration, bpm, keyscale, timesignature, vocal_language, task_type):
|
||||
if not positive:
|
||||
return {}
|
||||
pipe.load_models_to_device(['text_encoder'])
|
||||
meta_dict = {"bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "duration": duration}
|
||||
prompt = self.SFT_GEN_PROMPT.format(self.INSTRUCTION, prompt, self._dict_to_meta_string(meta_dict))
|
||||
INSTRUCTION = self.INSTRUCTION_MAP.get(task_type, self.INSTRUCTION_MAP["text2music"])
|
||||
prompt = self.SFT_GEN_PROMPT.format(INSTRUCTION, prompt, self._dict_to_meta_string(meta_dict))
|
||||
text_hidden_states, text_attention_mask = self._encode_text(pipe, prompt, max_length=256)
|
||||
|
||||
lyric_text = self.LYRIC_PROMPT.format(vocal_language, lyrics)
|
||||
@@ -350,31 +356,32 @@ class AceStepUnit_ConditionEmbedder(PipelineUnit):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params_posi={"text_hidden_states": "text_hidden_states", "text_attention_mask": "text_attention_mask", "lyric_hidden_states": "lyric_hidden_states", "lyric_attention_mask": "lyric_attention_mask"},
|
||||
input_params_nega={"text_hidden_states": "text_hidden_states", "text_attention_mask": "text_attention_mask", "lyric_hidden_states": "lyric_hidden_states", "lyric_attention_mask": "lyric_attention_mask"},
|
||||
input_params=("reference_latents", "refer_audio_order_mask"),
|
||||
take_over=True,
|
||||
output_params=("encoder_hidden_states", "encoder_attention_mask"),
|
||||
onload_model_names=("conditioner",),
|
||||
)
|
||||
|
||||
def process(self, pipe, text_hidden_states, text_attention_mask, lyric_hidden_states, lyric_attention_mask, reference_latents, refer_audio_order_mask):
|
||||
def process(self, pipe, inputs_shared, inputs_posi, inputs_nega):
|
||||
pipe.load_models_to_device(['conditioner'])
|
||||
encoder_hidden_states, encoder_attention_mask = pipe.conditioner(
|
||||
text_hidden_states=text_hidden_states,
|
||||
text_attention_mask=text_attention_mask,
|
||||
lyric_hidden_states=lyric_hidden_states,
|
||||
lyric_attention_mask=lyric_attention_mask,
|
||||
reference_latents=reference_latents,
|
||||
refer_audio_order_mask=refer_audio_order_mask,
|
||||
text_hidden_states=inputs_posi.get("text_hidden_states", None),
|
||||
text_attention_mask=inputs_posi.get("text_attention_mask", None),
|
||||
lyric_hidden_states=inputs_posi.get("lyric_hidden_states", None),
|
||||
lyric_attention_mask=inputs_posi.get("lyric_attention_mask", None),
|
||||
reference_latents=inputs_shared.get("reference_latents", None),
|
||||
refer_audio_order_mask=inputs_shared.get("refer_audio_order_mask", None),
|
||||
)
|
||||
return {"encoder_hidden_states": encoder_hidden_states, "encoder_attention_mask": encoder_attention_mask}
|
||||
inputs_posi["encoder_hidden_states"] = encoder_hidden_states
|
||||
inputs_posi["encoder_attention_mask"] = encoder_attention_mask
|
||||
inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states)
|
||||
inputs_nega["encoder_attention_mask"] = encoder_attention_mask
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
class AceStepUnit_ContextLatentBuilder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("duration", "src_audio"),
|
||||
input_params=("duration", "src_audio", "lm_hints"),
|
||||
output_params=("context_latents", "src_latents", "chunk_masks", "attention_mask"),
|
||||
)
|
||||
|
||||
@@ -386,9 +393,15 @@ class AceStepUnit_ContextLatentBuilder(PipelineUnit):
|
||||
tiled = pipe.silence_latent[0].repeat(repeats, 1)
|
||||
return tiled[:length, :]
|
||||
|
||||
def process(self, pipe, duration, src_audio):
|
||||
if src_audio is not None:
|
||||
raise NotImplementedError("Src audio conditioning is not implemented yet. Please set src_audio to None.")
|
||||
def process(self, pipe, duration, src_audio, lm_hints):
|
||||
if lm_hints is not None:
|
||||
max_latent_length = lm_hints.shape[1]
|
||||
src_latents = lm_hints.clone()
|
||||
chunk_masks = torch.ones((1, max_latent_length, src_latents.shape[-1]), dtype=torch.bool, device=pipe.device)
|
||||
attention_mask = torch.ones((1, max_latent_length), device=src_latents.device, dtype=pipe.torch_dtype)
|
||||
context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
|
||||
elif src_audio is not None:
|
||||
raise NotImplementedError("src_audio conditioning is not implemented yet. Please set lm_hints to None.")
|
||||
else:
|
||||
max_latent_length = duration * pipe.sample_rate // 1920
|
||||
src_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
|
||||
@@ -410,6 +423,7 @@ class AceStepUnit_NoiseInitializer(PipelineUnit):
|
||||
noise = pipe.generate_noise(src_latents_shape, seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
|
||||
return {"noise": noise}
|
||||
|
||||
|
||||
class AceStepUnit_InputAudioEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -427,8 +441,8 @@ class AceStepUnit_InputAudioEmbedder(PipelineUnit):
|
||||
class AceStepUnit_AudioCodeDecoder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("audio_codes", "seq_len", "silence_latent"),
|
||||
output_params=("lm_hints_25Hz",),
|
||||
input_params=("audio_code_string",),
|
||||
output_params=("lm_hints",),
|
||||
onload_model_names=("tokenizer_model",),
|
||||
)
|
||||
|
||||
@@ -437,45 +451,29 @@ class AceStepUnit_AudioCodeDecoder(PipelineUnit):
|
||||
"""Extract integer audio codes from tokens like <|audio_code_123|>."""
|
||||
if not code_str:
|
||||
return []
|
||||
codes = []
|
||||
max_audio_code = 63999
|
||||
for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str):
|
||||
code_value = int(x)
|
||||
codes.append(max(0, min(code_value, max_audio_code)))
|
||||
try:
|
||||
codes = []
|
||||
max_audio_code = 63999
|
||||
for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str):
|
||||
code_value = int(x)
|
||||
codes.append(max(0, min(code_value, max_audio_code)))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid audio_code_string format: {e}")
|
||||
return codes
|
||||
|
||||
def process(self, pipe, audio_codes, seq_len, silence_latent):
|
||||
if audio_codes is None or not audio_codes.strip():
|
||||
return {"lm_hints_25Hz": None}
|
||||
|
||||
code_ids = self._parse_audio_code_string(audio_codes)
|
||||
def process(self, pipe, audio_code_string):
|
||||
if audio_code_string is None or not audio_code_string.strip():
|
||||
return {"lm_hints": None}
|
||||
code_ids = self._parse_audio_code_string(audio_code_string)
|
||||
if len(code_ids) == 0:
|
||||
return {"lm_hints_25Hz": None}
|
||||
return {"lm_hints": None}
|
||||
|
||||
pipe.load_models_to_device(["tokenizer_model"])
|
||||
|
||||
quantizer = pipe.tokenizer_model.tokenizer.quantizer
|
||||
detokenizer = pipe.tokenizer_model.detokenizer
|
||||
|
||||
indices = torch.tensor(code_ids, device=pipe.device, dtype=torch.long)
|
||||
indices = indices.unsqueeze(0).unsqueeze(-1) # [1, N, 1]
|
||||
|
||||
quantized = quantizer.get_output_from_indices(indices) # [1, N, 2048]
|
||||
if quantized.dtype != pipe.torch_dtype:
|
||||
quantized = quantized.to(pipe.torch_dtype)
|
||||
|
||||
lm_hints = detokenizer(quantized) # [1, N*5, 64]
|
||||
|
||||
# Pad or truncate to seq_len
|
||||
current_len = lm_hints.shape[1]
|
||||
if current_len < seq_len:
|
||||
pad_len = seq_len - current_len
|
||||
pad = silence_latent[:, :pad_len, :]
|
||||
lm_hints = torch.cat([lm_hints, pad], dim=1)
|
||||
elif current_len > seq_len:
|
||||
lm_hints = lm_hints[:, :seq_len, :]
|
||||
|
||||
return {"lm_hints_25Hz": lm_hints}
|
||||
quantized = pipe.tokenizer_model.tokenizer.quantizer.get_output_from_indices(indices).to(pipe.torch_dtype) # [1, N, 2048]
|
||||
lm_hints = pipe.tokenizer_model.detokenizer(quantized) # [1, N*5, 64]
|
||||
return {"lm_hints": lm_hints}
|
||||
|
||||
|
||||
def model_fn_ace_step(
|
||||
@@ -499,5 +497,7 @@ def model_fn_ace_step(
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
context_latents=context_latents,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)[0]
|
||||
return decoder_outputs
|
||||
|
||||
Reference in New Issue
Block a user