mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-24 15:06:17 +00:00
t2m
This commit is contained in:
@@ -962,37 +962,6 @@ ace_step_series = [
|
||||
"model_class": "diffsynth.models.ace_step_conditioner.AceStepConditionEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.ace_step_conditioner_converter",
|
||||
},
|
||||
# === LLM variants ===
|
||||
{
|
||||
# Example: ModelConfig(model_id="ACE-Step/acestep-5Hz-lm-0.6B", origin_file_pattern="model.safetensors")
|
||||
"model_hash": "f3ab4bef9e00745fd0fea7aa8b2a4041",
|
||||
"model_name": "ace_step_lm",
|
||||
"model_class": "diffsynth.models.ace_step_lm.AceStepLM",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_lm.ace_step_lm_converter",
|
||||
"extra_kwargs": {
|
||||
"variant": "acestep-5Hz-lm-0.6B",
|
||||
},
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-5Hz-lm-1.7B/model.safetensors")
|
||||
"model_hash": "a14b6e422b0faa9b41e7efe0fee46766",
|
||||
"model_name": "ace_step_lm",
|
||||
"model_class": "diffsynth.models.ace_step_lm.AceStepLM",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_lm.ace_step_lm_converter",
|
||||
"extra_kwargs": {
|
||||
"variant": "acestep-5Hz-lm-1.7B",
|
||||
},
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="ACE-Step/acestep-5Hz-lm-4B", origin_file_pattern="model-*.safetensors")
|
||||
"model_hash": "046a3934f2e6f2f6d450bad23b1f4933",
|
||||
"model_name": "ace_step_lm",
|
||||
"model_class": "diffsynth.models.ace_step_lm.AceStepLM",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_lm.ace_step_lm_converter",
|
||||
"extra_kwargs": {
|
||||
"variant": "acestep-5Hz-lm-4B",
|
||||
},
|
||||
},
|
||||
# === Qwen3-Embedding (text encoder) ===
|
||||
{
|
||||
# Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors")
|
||||
|
||||
@@ -152,7 +152,7 @@ class BasePipeline(torch.nn.Module):
|
||||
# remove batch dim
|
||||
if audio_output.ndim == 3:
|
||||
audio_output = audio_output.squeeze(0)
|
||||
return audio_output.float()
|
||||
return audio_output.float().cpu()
|
||||
|
||||
def load_models_to_device(self, model_names):
|
||||
if self.vram_management_enabled:
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
import torch
|
||||
|
||||
|
||||
LM_CONFIGS = {
|
||||
"acestep-5Hz-lm-0.6B": {
|
||||
"hidden_size": 1024,
|
||||
"intermediate_size": 3072,
|
||||
"num_hidden_layers": 28,
|
||||
"num_attention_heads": 16,
|
||||
"layer_types": ["full_attention"] * 28,
|
||||
"max_window_layers": 28,
|
||||
},
|
||||
"acestep-5Hz-lm-1.7B": {
|
||||
"hidden_size": 2048,
|
||||
"intermediate_size": 6144,
|
||||
"num_hidden_layers": 28,
|
||||
"num_attention_heads": 16,
|
||||
"layer_types": ["full_attention"] * 28,
|
||||
"max_window_layers": 28,
|
||||
},
|
||||
"acestep-5Hz-lm-4B": {
|
||||
"hidden_size": 2560,
|
||||
"intermediate_size": 9728,
|
||||
"num_hidden_layers": 36,
|
||||
"num_attention_heads": 32,
|
||||
"layer_types": ["full_attention"] * 36,
|
||||
"max_window_layers": 36,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class AceStepLM(torch.nn.Module):
|
||||
"""
|
||||
Language model for ACE-Step.
|
||||
|
||||
Converts natural language prompts into structured parameters
|
||||
(caption, lyrics, bpm, keyscale, duration, timesignature, etc.)
|
||||
for ACE-Step music generation.
|
||||
|
||||
Wraps a Qwen3ForCausalLM transformers model. Config is manually
|
||||
constructed based on variant type, and model weights are loaded
|
||||
via DiffSynth's standard mechanism from safetensors files.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
variant: str = "acestep-5Hz-lm-1.7B",
|
||||
):
|
||||
super().__init__()
|
||||
from transformers import Qwen3Config, Qwen3ForCausalLM
|
||||
|
||||
config_params = LM_CONFIGS[variant]
|
||||
|
||||
config = Qwen3Config(
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
bos_token_id=151643,
|
||||
dtype="bfloat16",
|
||||
eos_token_id=151645,
|
||||
head_dim=128,
|
||||
hidden_act="silu",
|
||||
initializer_range=0.02,
|
||||
max_position_embeddings=40960,
|
||||
model_type="qwen3",
|
||||
num_key_value_heads=8,
|
||||
pad_token_id=151643,
|
||||
rms_norm_eps=1e-06,
|
||||
rope_scaling=None,
|
||||
rope_theta=1000000,
|
||||
sliding_window=None,
|
||||
tie_word_embeddings=True,
|
||||
use_cache=True,
|
||||
use_sliding_window=False,
|
||||
vocab_size=217204,
|
||||
**config_params,
|
||||
)
|
||||
|
||||
self.model = Qwen3ForCausalLM(config)
|
||||
self.config = config
|
||||
@@ -42,10 +42,10 @@ class AceStepPipeline(BasePipeline):
|
||||
AceStepUnit_PromptEmbedder(),
|
||||
AceStepUnit_ReferenceAudioEmbedder(),
|
||||
AceStepUnit_ConditionEmbedder(),
|
||||
AceStepUnit_AudioCodeDecoder(),
|
||||
AceStepUnit_ContextLatentBuilder(),
|
||||
AceStepUnit_NoiseInitializer(),
|
||||
AceStepUnit_InputAudioEmbedder(),
|
||||
AceStepUnit_AudioCodeDecoder(),
|
||||
]
|
||||
self.model_fn = model_fn_ace_step
|
||||
self.compilable_models = ["dit"]
|
||||
@@ -92,27 +92,27 @@ class AceStepPipeline(BasePipeline):
|
||||
cfg_scale: float = 1.0,
|
||||
# Lyrics
|
||||
lyrics: str = "",
|
||||
# Reference audio (optional, for timbre conditioning)
|
||||
# Reference audio
|
||||
reference_audios: List[torch.Tensor] = None,
|
||||
# Src audio
|
||||
src_audio: torch.Tensor = None,
|
||||
denoising_strength: float = 1.0,
|
||||
# Simple Mode: LLM-generated audio codes (optional)
|
||||
audio_codes: str = None,
|
||||
# Audio codes
|
||||
audio_code_string: Optional[str] = None,
|
||||
# Shape
|
||||
duration: int = 60,
|
||||
# Audio Meta
|
||||
bpm: Optional[int] = 100,
|
||||
keyscale: Optional[str] = "B minor",
|
||||
timesignature: Optional[str] = "4",
|
||||
vocal_language: Optional[str] = 'zh',
|
||||
vocal_language: Optional[str] = 'unknown',
|
||||
# Randomness
|
||||
seed: int = None,
|
||||
rand_device: str = "cpu",
|
||||
# Steps
|
||||
num_inference_steps: int = 8,
|
||||
# Scheduler-specific parameters
|
||||
shift: float = 3.0,
|
||||
shift: float = 1.0,
|
||||
# Progress
|
||||
progress_bar_cmd=tqdm,
|
||||
):
|
||||
@@ -120,14 +120,14 @@ class AceStepPipeline(BasePipeline):
|
||||
self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, denoising_strength=1.0, shift=shift)
|
||||
|
||||
# 2. 三字典输入
|
||||
inputs_posi = {"prompt": prompt}
|
||||
inputs_nega = {"negative_prompt": negative_prompt}
|
||||
inputs_posi = {"prompt": prompt, "positive": True}
|
||||
inputs_nega = {"positive": False}
|
||||
inputs_shared = {
|
||||
"cfg_scale": cfg_scale,
|
||||
"lyrics": lyrics,
|
||||
"reference_audios": reference_audios,
|
||||
"src_audio": src_audio,
|
||||
"audio_codes": audio_codes,
|
||||
"audio_code_string": audio_code_string,
|
||||
"duration": duration,
|
||||
"bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "vocal_language": vocal_language,
|
||||
"seed": seed,
|
||||
@@ -145,12 +145,13 @@ class AceStepPipeline(BasePipeline):
|
||||
# 4. Denoise loop
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
self.momentum_buffer = MomentumBuffer()
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.to(dtype=self.torch_dtype, device=self.device)
|
||||
noise_pred = self.cfg_guided_model_fn(
|
||||
self.model_fn, cfg_scale,
|
||||
inputs_shared, inputs_posi, inputs_nega,
|
||||
**models, timestep=timestep, progress_id=progress_id
|
||||
**models, timestep=timestep, progress_id=progress_id,
|
||||
)
|
||||
inputs_shared["latents"] = self.step(
|
||||
self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared
|
||||
@@ -163,39 +164,12 @@ class AceStepPipeline(BasePipeline):
|
||||
vae_output = self.vae.decode(latents)
|
||||
# VAE returns OobleckDecoderOutput with .sample attribute
|
||||
audio_output = vae_output.sample if hasattr(vae_output, 'sample') else vae_output
|
||||
audio_output = self.normalize_audio(audio_output, target_db=-1.0)
|
||||
audio = self.output_audio_format_check(audio_output)
|
||||
|
||||
# Peak normalization to match target library behavior
|
||||
audio = self.normalize_audio(audio, target_db=-1.0)
|
||||
|
||||
self.load_models_to_device([])
|
||||
return audio
|
||||
|
||||
def output_audio_format_check(self, audio_output):
|
||||
"""Convert VAE output to standard audio format [C, T], float32.
|
||||
|
||||
VAE decode outputs [B, C, T] (audio waveform).
|
||||
We squeeze batch dim and return [C, T].
|
||||
"""
|
||||
if audio_output.ndim == 3:
|
||||
audio_output = audio_output.squeeze(0)
|
||||
return audio_output.float()
|
||||
|
||||
def normalize_audio(self, audio: torch.Tensor, target_db: float = -1.0) -> torch.Tensor:
|
||||
"""Apply peak normalization to audio data, matching target library behavior.
|
||||
|
||||
Target library reference: `acestep/audio_utils.py:normalize_audio()`
|
||||
peak = max(abs(audio))
|
||||
gain = 10^(target_db/20) / peak
|
||||
audio = audio * gain
|
||||
|
||||
Args:
|
||||
audio: Audio tensor [C, T]
|
||||
target_db: Target peak level in dB (default: -1.0)
|
||||
|
||||
Returns:
|
||||
Normalized audio tensor
|
||||
"""
|
||||
peak = torch.max(torch.abs(audio))
|
||||
if peak < 1e-6:
|
||||
return audio
|
||||
@@ -203,17 +177,46 @@ class AceStepPipeline(BasePipeline):
|
||||
gain = target_amp / peak
|
||||
return audio * gain
|
||||
|
||||
|
||||
class AceStepUnit_TaskTypeChecker(PipelineUnit):
|
||||
"""Check and compute sequence length from duration."""
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("src_audio", "audio_code_string"),
|
||||
output_params=("task_type",),
|
||||
)
|
||||
|
||||
def process(self, pipe, src_audio, audio_code_string):
|
||||
if audio_code_string is not None:
|
||||
print("audio_code_string detected, setting task_type to 'cover'")
|
||||
task_type = "cover"
|
||||
else:
|
||||
task_type = "text2music"
|
||||
return {"task_type": task_type}
|
||||
|
||||
|
||||
class AceStepUnit_PromptEmbedder(PipelineUnit):
|
||||
SFT_GEN_PROMPT = "# Instruction\n{}\n\n# Caption\n{}\n\n# Metas\n{}<|endoftext|>\n"
|
||||
INSTRUCTION = "Fill the audio semantic mask based on the given conditions:"
|
||||
INSTRUCTION_MAP = {
|
||||
"text2music": "Fill the audio semantic mask based on the given conditions:",
|
||||
"cover": "Generate audio semantic tokens based on the given conditions:",
|
||||
|
||||
"repaint": "Repaint the mask area based on the given conditions:",
|
||||
"extract": "Extract the {TRACK_NAME} track from the audio:",
|
||||
"extract_default": "Extract the track from the audio:",
|
||||
"lego": "Generate the {TRACK_NAME} track based on the audio context:",
|
||||
"lego_default": "Generate the track based on the audio context:",
|
||||
"complete": "Complete the input track with {TRACK_CLASSES}:",
|
||||
"complete_default": "Complete the input track:",
|
||||
}
|
||||
LYRIC_PROMPT = "# Languages\n{}\n\n# Lyric\n{}<|endoftext|>"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params_posi={"prompt": "prompt"},
|
||||
input_params_nega={"prompt": "prompt"},
|
||||
input_params=("lyrics", "duration", "bpm", "keyscale", "timesignature", "vocal_language"),
|
||||
input_params_posi={"prompt": "prompt", "positive": "positive"},
|
||||
input_params_nega={"prompt": "prompt", "positive": "positive"},
|
||||
input_params=("lyrics", "duration", "bpm", "keyscale", "timesignature", "vocal_language", "task_type"),
|
||||
output_params=("text_hidden_states", "text_attention_mask", "lyric_hidden_states", "lyric_attention_mask"),
|
||||
onload_model_names=("text_encoder",)
|
||||
)
|
||||
@@ -256,10 +259,13 @@ class AceStepUnit_PromptEmbedder(PipelineUnit):
|
||||
f"- duration: {duration}\n"
|
||||
)
|
||||
|
||||
def process(self, pipe, prompt, lyrics, duration, bpm, keyscale, timesignature, vocal_language):
|
||||
def process(self, pipe, prompt, positive, lyrics, duration, bpm, keyscale, timesignature, vocal_language, task_type):
|
||||
if not positive:
|
||||
return {}
|
||||
pipe.load_models_to_device(['text_encoder'])
|
||||
meta_dict = {"bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "duration": duration}
|
||||
prompt = self.SFT_GEN_PROMPT.format(self.INSTRUCTION, prompt, self._dict_to_meta_string(meta_dict))
|
||||
INSTRUCTION = self.INSTRUCTION_MAP.get(task_type, self.INSTRUCTION_MAP["text2music"])
|
||||
prompt = self.SFT_GEN_PROMPT.format(INSTRUCTION, prompt, self._dict_to_meta_string(meta_dict))
|
||||
text_hidden_states, text_attention_mask = self._encode_text(pipe, prompt, max_length=256)
|
||||
|
||||
lyric_text = self.LYRIC_PROMPT.format(vocal_language, lyrics)
|
||||
@@ -350,31 +356,32 @@ class AceStepUnit_ConditionEmbedder(PipelineUnit):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params_posi={"text_hidden_states": "text_hidden_states", "text_attention_mask": "text_attention_mask", "lyric_hidden_states": "lyric_hidden_states", "lyric_attention_mask": "lyric_attention_mask"},
|
||||
input_params_nega={"text_hidden_states": "text_hidden_states", "text_attention_mask": "text_attention_mask", "lyric_hidden_states": "lyric_hidden_states", "lyric_attention_mask": "lyric_attention_mask"},
|
||||
input_params=("reference_latents", "refer_audio_order_mask"),
|
||||
take_over=True,
|
||||
output_params=("encoder_hidden_states", "encoder_attention_mask"),
|
||||
onload_model_names=("conditioner",),
|
||||
)
|
||||
|
||||
def process(self, pipe, text_hidden_states, text_attention_mask, lyric_hidden_states, lyric_attention_mask, reference_latents, refer_audio_order_mask):
|
||||
def process(self, pipe, inputs_shared, inputs_posi, inputs_nega):
|
||||
pipe.load_models_to_device(['conditioner'])
|
||||
encoder_hidden_states, encoder_attention_mask = pipe.conditioner(
|
||||
text_hidden_states=text_hidden_states,
|
||||
text_attention_mask=text_attention_mask,
|
||||
lyric_hidden_states=lyric_hidden_states,
|
||||
lyric_attention_mask=lyric_attention_mask,
|
||||
reference_latents=reference_latents,
|
||||
refer_audio_order_mask=refer_audio_order_mask,
|
||||
text_hidden_states=inputs_posi.get("text_hidden_states", None),
|
||||
text_attention_mask=inputs_posi.get("text_attention_mask", None),
|
||||
lyric_hidden_states=inputs_posi.get("lyric_hidden_states", None),
|
||||
lyric_attention_mask=inputs_posi.get("lyric_attention_mask", None),
|
||||
reference_latents=inputs_shared.get("reference_latents", None),
|
||||
refer_audio_order_mask=inputs_shared.get("refer_audio_order_mask", None),
|
||||
)
|
||||
return {"encoder_hidden_states": encoder_hidden_states, "encoder_attention_mask": encoder_attention_mask}
|
||||
inputs_posi["encoder_hidden_states"] = encoder_hidden_states
|
||||
inputs_posi["encoder_attention_mask"] = encoder_attention_mask
|
||||
inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states)
|
||||
inputs_nega["encoder_attention_mask"] = encoder_attention_mask
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
class AceStepUnit_ContextLatentBuilder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("duration", "src_audio"),
|
||||
input_params=("duration", "src_audio", "lm_hints"),
|
||||
output_params=("context_latents", "src_latents", "chunk_masks", "attention_mask"),
|
||||
)
|
||||
|
||||
@@ -386,9 +393,15 @@ class AceStepUnit_ContextLatentBuilder(PipelineUnit):
|
||||
tiled = pipe.silence_latent[0].repeat(repeats, 1)
|
||||
return tiled[:length, :]
|
||||
|
||||
def process(self, pipe, duration, src_audio):
|
||||
if src_audio is not None:
|
||||
raise NotImplementedError("Src audio conditioning is not implemented yet. Please set src_audio to None.")
|
||||
def process(self, pipe, duration, src_audio, lm_hints):
|
||||
if lm_hints is not None:
|
||||
max_latent_length = lm_hints.shape[1]
|
||||
src_latents = lm_hints.clone()
|
||||
chunk_masks = torch.ones((1, max_latent_length, src_latents.shape[-1]), dtype=torch.bool, device=pipe.device)
|
||||
attention_mask = torch.ones((1, max_latent_length), device=src_latents.device, dtype=pipe.torch_dtype)
|
||||
context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
|
||||
elif src_audio is not None:
|
||||
raise NotImplementedError("src_audio conditioning is not implemented yet. Please set lm_hints to None.")
|
||||
else:
|
||||
max_latent_length = duration * pipe.sample_rate // 1920
|
||||
src_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
|
||||
@@ -410,6 +423,7 @@ class AceStepUnit_NoiseInitializer(PipelineUnit):
|
||||
noise = pipe.generate_noise(src_latents_shape, seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
|
||||
return {"noise": noise}
|
||||
|
||||
|
||||
class AceStepUnit_InputAudioEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -427,8 +441,8 @@ class AceStepUnit_InputAudioEmbedder(PipelineUnit):
|
||||
class AceStepUnit_AudioCodeDecoder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("audio_codes", "seq_len", "silence_latent"),
|
||||
output_params=("lm_hints_25Hz",),
|
||||
input_params=("audio_code_string",),
|
||||
output_params=("lm_hints",),
|
||||
onload_model_names=("tokenizer_model",),
|
||||
)
|
||||
|
||||
@@ -437,45 +451,29 @@ class AceStepUnit_AudioCodeDecoder(PipelineUnit):
|
||||
"""Extract integer audio codes from tokens like <|audio_code_123|>."""
|
||||
if not code_str:
|
||||
return []
|
||||
try:
|
||||
codes = []
|
||||
max_audio_code = 63999
|
||||
for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str):
|
||||
code_value = int(x)
|
||||
codes.append(max(0, min(code_value, max_audio_code)))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid audio_code_string format: {e}")
|
||||
return codes
|
||||
|
||||
def process(self, pipe, audio_codes, seq_len, silence_latent):
|
||||
if audio_codes is None or not audio_codes.strip():
|
||||
return {"lm_hints_25Hz": None}
|
||||
|
||||
code_ids = self._parse_audio_code_string(audio_codes)
|
||||
def process(self, pipe, audio_code_string):
|
||||
if audio_code_string is None or not audio_code_string.strip():
|
||||
return {"lm_hints": None}
|
||||
code_ids = self._parse_audio_code_string(audio_code_string)
|
||||
if len(code_ids) == 0:
|
||||
return {"lm_hints_25Hz": None}
|
||||
return {"lm_hints": None}
|
||||
|
||||
pipe.load_models_to_device(["tokenizer_model"])
|
||||
|
||||
quantizer = pipe.tokenizer_model.tokenizer.quantizer
|
||||
detokenizer = pipe.tokenizer_model.detokenizer
|
||||
|
||||
indices = torch.tensor(code_ids, device=pipe.device, dtype=torch.long)
|
||||
indices = indices.unsqueeze(0).unsqueeze(-1) # [1, N, 1]
|
||||
|
||||
quantized = quantizer.get_output_from_indices(indices) # [1, N, 2048]
|
||||
if quantized.dtype != pipe.torch_dtype:
|
||||
quantized = quantized.to(pipe.torch_dtype)
|
||||
|
||||
lm_hints = detokenizer(quantized) # [1, N*5, 64]
|
||||
|
||||
# Pad or truncate to seq_len
|
||||
current_len = lm_hints.shape[1]
|
||||
if current_len < seq_len:
|
||||
pad_len = seq_len - current_len
|
||||
pad = silence_latent[:, :pad_len, :]
|
||||
lm_hints = torch.cat([lm_hints, pad], dim=1)
|
||||
elif current_len > seq_len:
|
||||
lm_hints = lm_hints[:, :seq_len, :]
|
||||
|
||||
return {"lm_hints_25Hz": lm_hints}
|
||||
quantized = pipe.tokenizer_model.tokenizer.quantizer.get_output_from_indices(indices).to(pipe.torch_dtype) # [1, N, 2048]
|
||||
lm_hints = pipe.tokenizer_model.detokenizer(quantized) # [1, N*5, 64]
|
||||
return {"lm_hints": lm_hints}
|
||||
|
||||
|
||||
def model_fn_ace_step(
|
||||
@@ -499,5 +497,7 @@ def model_fn_ace_step(
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
context_latents=context_latents,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)[0]
|
||||
return decoder_outputs
|
||||
|
||||
@@ -1,33 +1,20 @@
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
import torch
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
import torch
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(
|
||||
model_id="ACE-Step/Ace-Step1.5",
|
||||
origin_file_pattern="acestep-v15-turbo/model.safetensors"
|
||||
),
|
||||
ModelConfig(
|
||||
model_id="ACE-Step/Ace-Step1.5",
|
||||
origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"
|
||||
),
|
||||
ModelConfig(
|
||||
model_id="ACE-Step/Ace-Step1.5",
|
||||
origin_file_pattern="vae/diffusion_pytorch_model.safetensors"
|
||||
),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(
|
||||
model_id="ACE-Step/Ace-Step1.5",
|
||||
origin_file_pattern="Qwen3-Embedding-0.6B/"
|
||||
),
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
@@ -41,5 +28,23 @@ audio = pipe(
|
||||
cfg_scale=1.0,
|
||||
)
|
||||
|
||||
save_audio(audio.cpu(), pipe.vae.sampling_rate, "Ace-Step1.5.wav")
|
||||
print(f"Saved to Ace-Step1.5.wav, shape: {audio.shape}, duration: {audio.shape[-1] / pipe.sample_rate:.1f}s")
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo.wav")
|
||||
|
||||
# input audio codes as reference
|
||||
with open("data/diffsynth_example_dataset/ace_step/Ace-Step1.5/audio_codes_input.txt", "r") as f:
|
||||
audio_code_string = f.read().strip()
|
||||
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
audio_code_string=audio_code_string,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=42,
|
||||
num_inference_steps=8,
|
||||
cfg_scale=1.0,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo5-with-audio-codes.wav")
|
||||
|
||||
@@ -1,52 +1,31 @@
|
||||
"""
|
||||
Ace-Step 1.5 Base (non-turbo, 24 layers) — Text-to-Music inference example.
|
||||
|
||||
Uses cfg_scale=7.0 (standard CFG guidance) and more steps for higher quality.
|
||||
"""
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
import torch
|
||||
import soundfile as sf
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(
|
||||
model_id="ACE-Step/Ace-Step1.5",
|
||||
origin_file_pattern="acestep-v15-base/model.safetensors"
|
||||
),
|
||||
ModelConfig(
|
||||
model_id="ACE-Step/Ace-Step1.5",
|
||||
origin_file_pattern="acestep-v15-base/model.safetensors"
|
||||
),
|
||||
ModelConfig(
|
||||
model_id="ACE-Step/Ace-Step1.5",
|
||||
origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"
|
||||
),
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(
|
||||
model_id="ACE-Step/Ace-Step1.5",
|
||||
origin_file_pattern="Qwen3-Embedding-0.6B/"
|
||||
),
|
||||
vae_config=ModelConfig(
|
||||
model_id="ACE-Step/Ace-Step1.5",
|
||||
origin_file_pattern="vae/"
|
||||
),
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
)
|
||||
|
||||
prompt = "A cinematic orchestral piece with soaring strings and heroic brass"
|
||||
lyrics = "[Intro - Orchestra]\n\n[Verse 1]\nAcross the mountains, through the valley\nA journey of a thousand miles\n\n[Chorus]\nRise above the stormy skies\nLet the music carry you"
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=30.0,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=42,
|
||||
num_inference_steps=20,
|
||||
cfg_scale=7.0, # Base model uses CFG
|
||||
shift=3.0,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
|
||||
sf.write("acestep-v15-base.wav", audio.cpu().numpy(), pipe.sample_rate)
|
||||
print(f"Saved, shape: {audio.shape}")
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base.wav")
|
||||
|
||||
Reference in New Issue
Block a user