# ACE-Step model registry entries.
# NOTE: several entries intentionally share the same `model_hash` — one checkpoint
# file (e.g. the turbo DiT safetensors) contains the weights for multiple logical
# models (DiT + conditioner + tokenizer), so each logical model registers its own
# loader against the same file hash.
ace_step_series = [
    # === Standard DiT variants (24 layers, hidden_size=2048) ===
    # Covers: turbo, turbo-shift1, turbo-shift3, turbo-continuous, base, sft
    # All share identical state_dict structure → same hash
    {
        # Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors")
        "model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
        "model_name": "ace_step_dit",
        "model_class": "diffsynth.models.ace_step_dit.AceStepDiTModel",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.ace_step_dit_converter",
    },
    # === XL DiT variants (32 layers, hidden_size=2560) ===
    # Covers: xl-base, xl-sft, xl-turbo
    {
        # Example: ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors")
        "model_hash": "3a28a410c2246f125153ef792d8bc828",
        "model_name": "ace_step_dit",
        "model_class": "diffsynth.models.ace_step_dit.AceStepDiTModel",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.ace_step_dit_converter",
        "extra_kwargs": {
            "hidden_size": 2560,
            "intermediate_size": 9728,
            "num_hidden_layers": 32,
            "num_attention_heads": 32,
            "num_key_value_heads": 8,
            "head_dim": 128,
            "encoder_hidden_size": 2048,
            # 32 layers alternating local (sliding-window) and global attention.
            "layer_types": ["sliding_attention", "full_attention"] * 16,
        },
    },
    # === Conditioner (shared by all DiT variants, same architecture) ===
    {
        # Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors")
        "model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
        "model_name": "ace_step_conditioner",
        "model_class": "diffsynth.models.ace_step_conditioner.AceStepConditionEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.ace_step_conditioner_converter",
    },
    # === XL Conditioner (same architecture, but checkpoint includes XL decoder → different file hash) ===
    {
        # Example: ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors")
        "model_hash": "3a28a410c2246f125153ef792d8bc828",
        "model_name": "ace_step_conditioner",
        "model_class": "diffsynth.models.ace_step_conditioner.AceStepConditionEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.ace_step_conditioner_converter",
    },
    # === LLM variants ===
    {
        # Example: ModelConfig(model_id="ACE-Step/acestep-5Hz-lm-0.6B", origin_file_pattern="model.safetensors")
        "model_hash": "f3ab4bef9e00745fd0fea7aa8b2a4041",
        "model_name": "ace_step_lm",
        "model_class": "diffsynth.models.ace_step_lm.AceStepLM",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_lm.ace_step_lm_converter",
        "extra_kwargs": {
            "variant": "acestep-5Hz-lm-0.6B",
        },
    },
    {
        # Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-5Hz-lm-1.7B/model.safetensors")
        "model_hash": "a14b6e422b0faa9b41e7efe0fee46766",
        "model_name": "ace_step_lm",
        "model_class": "diffsynth.models.ace_step_lm.AceStepLM",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_lm.ace_step_lm_converter",
        "extra_kwargs": {
            "variant": "acestep-5Hz-lm-1.7B",
        },
    },
    {
        # Example: ModelConfig(model_id="ACE-Step/acestep-5Hz-lm-4B", origin_file_pattern="model-*.safetensors")
        "model_hash": "046a3934f2e6f2f6d450bad23b1f4933",
        "model_name": "ace_step_lm",
        "model_class": "diffsynth.models.ace_step_lm.AceStepLM",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_lm.ace_step_lm_converter",
        "extra_kwargs": {
            "variant": "acestep-5Hz-lm-4B",
        },
    },
    # === Qwen3-Embedding (text encoder) ===
    {
        # Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors")
        "model_hash": "3509bea17b0e8cffc3dd4a15cc7899d0",
        "model_name": "ace_step_text_encoder",
        "model_class": "diffsynth.models.ace_step_text_encoder.AceStepTextEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_text_encoder.ace_step_text_encoder_converter",
    },
    # === VAE (AutoencoderOobleck CNN) ===
    # No state_dict_converter: checkpoint keys load directly.
    {
        # Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
        "model_hash": "51420834e54474986a7f4be0e4d6f687",
        "model_name": "ace_step_vae",
        "model_class": "diffsynth.models.ace_step_vae.AceStepVAE",
    },
    # === Tokenizer (VAE latent discretization: tokenizer + detokenizer) ===
    {
        # Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors")
        "model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
        "model_name": "ace_step_tokenizer",
        "model_class": "diffsynth.models.ace_step_tokenizer.AceStepTokenizer",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_tokenizer.ace_step_tokenizer_converter",
    },
    # === XL Tokenizer (XL models share same tokenizer architecture) ===
    {
        # Example: ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors")
        "model_hash": "3a28a410c2246f125153ef792d8bc828",
        "model_name": "ace_step_tokenizer",
        "model_class": "diffsynth.models.ace_step_tokenizer.AceStepTokenizer",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_tokenizer.ace_step_tokenizer_converter",
    },
]

MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series + z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series + ace_step_series
+ anima_series + mova_series + joyai_image_series + ace_step_series diff --git a/diffsynth/diffusion/flow_match.py b/diffsynth/diffusion/flow_match.py index 6c1b846..f241683 100644 --- a/diffsynth/diffusion/flow_match.py +++ b/diffsynth/diffusion/flow_match.py @@ -4,7 +4,7 @@ from typing_extensions import Literal class FlowMatchScheduler(): - def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning", "ERNIE-Image"] = "FLUX.1"): + def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning", "ERNIE-Image", "ACE-Step"] = "FLUX.1"): self.set_timesteps_fn = { "FLUX.1": FlowMatchScheduler.set_timesteps_flux, "Wan": FlowMatchScheduler.set_timesteps_wan, @@ -14,6 +14,7 @@ class FlowMatchScheduler(): "LTX-2": FlowMatchScheduler.set_timesteps_ltx2, "Qwen-Image-Lightning": FlowMatchScheduler.set_timesteps_qwen_image_lightning, "ERNIE-Image": FlowMatchScheduler.set_timesteps_ernie_image, + "ACE-Step": FlowMatchScheduler.set_timesteps_ace_step, }.get(template, FlowMatchScheduler.set_timesteps_flux) self.num_train_timesteps = 1000 @@ -142,6 +143,26 @@ class FlowMatchScheduler(): timesteps = sigmas * num_train_timesteps return sigmas, timesteps + @staticmethod + def set_timesteps_ace_step(num_inference_steps=8, denoising_strength=1.0, shift=3.0): + """ACE-Step Flow Matching scheduler. + + Timesteps range from 1.0 to 0.0 (not multiplied by 1000). + Shift transformation: t = shift * t / (1 + (shift - 1) * t) + + Args: + num_inference_steps: Number of diffusion steps. + denoising_strength: Denoising strength (1.0 = full denoising). + shift: Timestep shift parameter (default 3.0 for turbo). 
+ """ + num_train_timesteps = 1000 + sigma_start = denoising_strength + sigmas = torch.linspace(sigma_start, 0.0, num_inference_steps) + if shift is not None and shift != 1.0: + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + timesteps = sigmas # ACE-Step uses [0, 1] range directly + return sigmas, timesteps + @staticmethod def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None): sigma_min = 0.0 diff --git a/diffsynth/models/ace_step_conditioner.py b/diffsynth/models/ace_step_conditioner.py new file mode 100644 index 0000000..93fe0d3 --- /dev/null +++ b/diffsynth/models/ace_step_conditioner.py @@ -0,0 +1,709 @@ +# Copyright 2025 The ACESTEO Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# Copyright 2025 The ACE-Step Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ACE-Step condition encoder.

Fuses three conditioning streams into a single packed encoder sequence for
the ACE-Step DiT: projected text embeddings, a lyric encoder (bidirectional
Qwen3-style transformer over lyric token embeddings), and a timbre encoder
(bidirectional transformer over reference-audio acoustic features).
"""
import math
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange

from ..core.attention import attention_forward
from ..core.gradient import gradient_checkpoint_forward

from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutput
from transformers.processing_utils import Unpack
from transformers.utils import can_return_tuple, logging
from transformers.models.qwen3.modeling_qwen3 import (
    Qwen3MLP,
    Qwen3RMSNorm,
    Qwen3RotaryEmbedding,
    apply_rotary_pos_emb,
)

logger = logging.get_logger(__name__)


def create_4d_mask(
    seq_len: int,
    dtype: torch.dtype,
    device: torch.device,
    attention_mask: Optional[torch.Tensor] = None,
    sliding_window: Optional[int] = None,
    is_sliding_window: bool = False,
    is_causal: bool = True,
) -> torch.Tensor:
    """Build an additive [1-or-B, 1, S, S] attention mask (0 keep, -inf drop).

    Supports causal/bidirectional × full/sliding-window combinations, optionally
    intersected with a [B, S] padding mask.
    NOTE(review): a near-identical copy of this helper exists in
    ace_step_dit.py — consider sharing one implementation.
    """
    indices = torch.arange(seq_len, device=device)
    # diff[i, j] = i - j: >= 0 keeps the lower triangle (causal).
    diff = indices.unsqueeze(1) - indices.unsqueeze(0)
    valid_mask = torch.ones((seq_len, seq_len), device=device, dtype=torch.bool)
    if is_causal:
        valid_mask = valid_mask & (diff >= 0)
    if is_sliding_window and sliding_window is not None:
        if is_causal:
            # Causal + local: only the last `sliding_window` past positions.
            valid_mask = valid_mask & (diff <= sliding_window)
        else:
            # Bidirectional local window around each position.
            valid_mask = valid_mask & (torch.abs(diff) <= sliding_window)
    valid_mask = valid_mask.unsqueeze(0).unsqueeze(0)
    if attention_mask is not None:
        # Broadcast the [B, S] key-padding mask against the [1, 1, S, S] pattern.
        padding_mask_4d = attention_mask.view(attention_mask.shape[0], 1, 1, seq_len).to(torch.bool)
        valid_mask = valid_mask & padding_mask_4d
    # Additive mask: finfo-min where masked so softmax zeroes those positions.
    min_dtype = torch.finfo(dtype).min
    mask_tensor = torch.full(valid_mask.shape, min_dtype, dtype=dtype, device=device)
    mask_tensor.masked_fill_(valid_mask, 0.0)
    return mask_tensor


def pack_sequences(hidden1: torch.Tensor, hidden2: torch.Tensor, mask1: torch.Tensor, mask2: torch.Tensor):
    """Concatenate two padded sequences and left-compact the valid tokens.

    Valid (mask==1) tokens from both inputs are moved to the front of each row
    (stable order preserved), producing a new left-aligned mask.
    Returns (packed_hidden [B, L1+L2, D], packed_mask [B, L1+L2] bool).
    """
    hidden_cat = torch.cat([hidden1, hidden2], dim=1)
    mask_cat = torch.cat([mask1, mask2], dim=1)
    B, L, D = hidden_cat.shape
    # Stable descending sort on the mask pushes valid positions left.
    sort_idx = mask_cat.argsort(dim=1, descending=True, stable=True)
    hidden_left = torch.gather(hidden_cat, 1, sort_idx.unsqueeze(-1).expand(B, L, D))
    lengths = mask_cat.sum(dim=1)
    new_mask = (torch.arange(L, dtype=torch.long, device=hidden_cat.device).unsqueeze(0) < lengths.unsqueeze(1))
    return hidden_left, new_mask


class Lambda(nn.Module):
    """Wrap an arbitrary callable as an nn.Module."""

    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        return self.func(x)


class AceStepAttention(nn.Module):
    """Qwen3-style attention (QK-norm, GQA) supporting self- and cross-attention.

    Self-attention applies rotary embeddings; cross-attention skips RoPE and can
    cache encoder K/V in an EncoderDecoderCache-style `past_key_value`.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        rms_norm_eps: float,
        attention_bias: bool,
        attention_dropout: float,
        layer_types: list,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = None,
        layer_idx: int = 0,
        is_cross_attention: bool = False,
        is_causal: bool = False,
    ):
        super().__init__()
        self.layer_idx = layer_idx
        self.head_dim = head_dim or hidden_size // num_attention_heads
        self.num_key_value_groups = num_attention_heads // num_key_value_heads
        self.scaling = self.head_dim ** -0.5
        self.attention_dropout = attention_dropout
        # Cross-attention is never causal regardless of the flag passed in.
        if is_cross_attention:
            is_causal = False
        self.is_causal = is_causal
        self.is_cross_attention = is_cross_attention

        self.q_proj = nn.Linear(hidden_size, num_attention_heads * self.head_dim, bias=attention_bias)
        self.k_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
        self.v_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
        self.o_proj = nn.Linear(num_attention_heads * self.head_dim, hidden_size, bias=attention_bias)
        # Per-head RMS norm on Q and K (Qwen3 convention).
        self.q_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.attention_type = layer_types[layer_idx]
        self.sliding_window = sliding_window if layer_types[layer_idx] == "sliding_attention" else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)

        # Only treat as cross-attention when encoder states are actually given.
        is_cross_attention = self.is_cross_attention and encoder_hidden_states is not None

        if is_cross_attention:
            encoder_hidden_shape = (*encoder_hidden_states.shape[:-1], -1, self.head_dim)
            if past_key_value is not None:
                # Encoder K/V are static: compute once, then reuse from the
                # cross-attention cache on subsequent decoding steps.
                is_updated = past_key_value.is_updated.get(self.layer_idx)
                curr_past_key_value = past_key_value.cross_attention_cache
                if not is_updated:
                    key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
                    value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
                    key_states, value_states = curr_past_key_value.update(key_states, value_states, self.layer_idx)
                    past_key_value.is_updated[self.layer_idx] = True
                else:
                    key_states = curr_past_key_value.layers[self.layer_idx].keys
                    value_states = curr_past_key_value.layers[self.layer_idx].values
            else:
                key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
                value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)

        else:
            key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
            value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            if position_embeddings is not None:
                cos, sin = position_embeddings
                query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

                # NOTE(review): cache update uses `sin`/`cos`, so it is kept
                # inside the position_embeddings branch; the original collapsed
                # source is ambiguous about this nesting — confirm against
                # upstream. With it outside, a call having a cache but no RoPE
                # would raise NameError.
                if past_key_value is not None:
                    cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
                    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # GQA: replicate each KV head across its query-head group.
        if self.num_key_value_groups > 1:
            key_states = key_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)
            value_states = value_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)

        attn_output = attention_forward(
            query_states, key_states, value_states,
            q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d",
            attn_mask=attention_mask,
        )
        # Attention probabilities are never materialized by attention_forward.
        attn_weights = None

        attn_output = attn_output.transpose(1, 2).flatten(2, 3).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class AceStepEncoderLayer(nn.Module):
    """Pre-norm transformer encoder layer: self-attention + Qwen3 SwiGLU MLP."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        rms_norm_eps: float,
        attention_bias: bool,
        attention_dropout: float,
        layer_types: list,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = None,
        layer_idx: int = 0,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.layer_idx = layer_idx

        self.self_attn = AceStepAttention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            rms_norm_eps=rms_norm_eps,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            layer_types=layer_types,
            head_dim=head_dim,
            sliding_window=sliding_window,
            layer_idx=layer_idx,
            is_cross_attention=False,
            is_causal=False,
        )
        self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
        self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)

        # Qwen3MLP expects a config object; build a minimal ad-hoc one rather
        # than pulling in the full Qwen3Config.
        mlp_config = type('Config', (), {
            'hidden_size': hidden_size,
            'intermediate_size': intermediate_size,
            'hidden_act': 'silu',
        })()
        self.mlp = Qwen3MLP(mlp_config)
        self.attention_type = layer_types[layer_idx]

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        # Self-attention sub-block (pre-norm, residual add).
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            use_cache=False,
            past_key_value=None,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # MLP sub-block (pre-norm, residual add).
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        return outputs


class AceStepLyricEncoder(nn.Module):
    """Bidirectional transformer over pre-embedded lyric tokens.

    Projects `text_hidden_dim`-dim lyric embeddings to `hidden_size` and runs
    `num_lyric_encoder_hidden_layers` encoder layers with alternating
    sliding-window / full attention.
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        intermediate_size: int = 6144,
        num_hidden_layers: int = 24,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 8,
        rms_norm_eps: float = 1e-6,
        attention_bias: bool = False,
        attention_dropout: float = 0.0,
        layer_types: Optional[list] = None,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = 128,
        use_sliding_window: bool = True,
        use_cache: bool = True,
        rope_theta: float = 1000000,
        max_position_embeddings: int = 32768,
        initializer_range: float = 0.02,
        text_hidden_dim: int = 1024,
        num_lyric_encoder_hidden_layers: int = 8,
        **kwargs,
    ):
        super().__init__()
        self.num_lyric_encoder_hidden_layers = num_lyric_encoder_hidden_layers
        self.text_hidden_dim = text_hidden_dim
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.rms_norm_eps = rms_norm_eps
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        # Default: alternate local/global attention per layer.
        self.layer_types = layer_types or (["sliding_attention", "full_attention"] * (num_hidden_layers // 2))
        self.head_dim = head_dim or hidden_size // num_attention_heads
        self.sliding_window = sliding_window
        self.use_sliding_window = use_sliding_window
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")

        # Linear "embedding": input is already dense features, not token ids.
        self.embed_tokens = nn.Linear(text_hidden_dim, hidden_size)
        self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
        # Minimal ad-hoc config object for Qwen3RotaryEmbedding.
        # NOTE(review): 'head_dim' forwards the raw parameter (possibly None),
        # not the resolved self.head_dim — confirm Qwen3RotaryEmbedding handles None.
        rope_config = type('RopeConfig', (), {
            'hidden_size': hidden_size,
            'num_attention_heads': num_attention_heads,
            'num_key_value_heads': num_key_value_heads,
            'head_dim': head_dim,
            'max_position_embeddings': max_position_embeddings,
            'rope_theta': rope_theta,
            'rope_parameters': {'rope_type': 'default', 'rope_theta': rope_theta},
            'rms_norm_eps': rms_norm_eps,
            'attention_bias': attention_bias,
            'attention_dropout': attention_dropout,
            'hidden_act': 'silu',
            'intermediate_size': intermediate_size,
            'layer_types': self.layer_types,
            'sliding_window': sliding_window,
            '_attn_implementation': self._attn_implementation,
        })()
        self.rotary_emb = Qwen3RotaryEmbedding(rope_config)
        self.gradient_checkpointing = False

        self.layers = nn.ModuleList([
            AceStepEncoderLayer(
                hidden_size=hidden_size,
                intermediate_size=intermediate_size,
                num_attention_heads=num_attention_heads,
                num_key_value_heads=num_key_value_heads,
                rms_norm_eps=rms_norm_eps,
                attention_bias=attention_bias,
                attention_dropout=attention_dropout,
                layer_types=self.layer_types,
                head_dim=head_dim,
                sliding_window=sliding_window,
                layer_idx=layer_idx,
            )
            for layer_idx in range(num_lyric_encoder_hidden_layers)
        ])


    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutput:
        """Encode lyric embeddings bidirectionally.

        Requires `inputs_embeds` [B, S, text_hidden_dim] and `attention_mask`
        [B, S]; `input_ids` is explicitly unsupported.
        """
        output_attentions = output_attentions if output_attentions is not None else False
        output_hidden_states = output_hidden_states if output_hidden_states is not None else False

        assert input_ids is None, "Only `inputs_embeds` is supported for the lyric encoder."
        assert attention_mask is not None, "Attention mask must be provided for the lyric encoder."
        assert inputs_embeds is not None, "Inputs embeddings must be provided for the lyric encoder."

        inputs_embeds = self.embed_tokens(inputs_embeds)
        cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        seq_len = inputs_embeds.shape[1]
        dtype = inputs_embeds.dtype
        device = inputs_embeds.device

        # Precompute both mask flavors; each layer picks by its attention_type.
        full_attn_mask = create_4d_mask(
            seq_len=seq_len, dtype=dtype, device=device,
            attention_mask=attention_mask, sliding_window=None,
            is_sliding_window=False, is_causal=False
        )
        sliding_attn_mask = None
        if self.use_sliding_window:
            sliding_attn_mask = create_4d_mask(
                seq_len=seq_len, dtype=dtype, device=device,
                attention_mask=attention_mask, sliding_window=self.sliding_window,
                is_sliding_window=True, is_causal=False
            )

        self_attn_mask_mapping = {
            "full_attention": full_attn_mask,
            "sliding_attention": sliding_attn_mask,
        }

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for layer_module in self.layers[: self.num_lyric_encoder_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = layer_module(
                hidden_states, position_embeddings,
                self_attn_mask_mapping[layer_module.attention_type],
                position_ids, output_attentions,
                **flash_attn_kwargs,
            )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class AceStepTimbreEncoder(nn.Module):
    """Bidirectional transformer over reference-audio acoustic features.

    Accepts features packed across the batch ([N, timbre_hidden_dim]) with a
    per-row batch index (`refer_audio_order_mask`), or already-batched 3-D
    input, and returns per-batch padded timbre embeddings plus a mask.
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        intermediate_size: int = 6144,
        num_hidden_layers: int = 24,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 8,
        rms_norm_eps: float = 1e-6,
        attention_bias: bool = False,
        attention_dropout: float = 0.0,
        layer_types: Optional[list] = None,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = 128,
        use_sliding_window: bool = True,
        use_cache: bool = True,
        rope_theta: float = 1000000,
        max_position_embeddings: int = 32768,
        initializer_range: float = 0.02,
        timbre_hidden_dim: int = 64,
        num_timbre_encoder_hidden_layers: int = 4,
        **kwargs,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.rms_norm_eps = rms_norm_eps
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.layer_types = layer_types or (["sliding_attention", "full_attention"] * (num_hidden_layers // 2))
        self.head_dim = head_dim or hidden_size // num_attention_heads
        self.sliding_window = sliding_window
        self.use_sliding_window = use_sliding_window
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.timbre_hidden_dim = timbre_hidden_dim
        self.num_timbre_encoder_hidden_layers = num_timbre_encoder_hidden_layers
        self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")

        # Linear "embedding": acoustic features in, hidden_size out.
        self.embed_tokens = nn.Linear(timbre_hidden_dim, hidden_size)
        self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
        # Same ad-hoc rope config trick as the lyric encoder (see note there
        # about forwarding a possibly-None head_dim).
        rope_config = type('RopeConfig', (), {
            'hidden_size': hidden_size,
            'num_attention_heads': num_attention_heads,
            'num_key_value_heads': num_key_value_heads,
            'head_dim': head_dim,
            'max_position_embeddings': max_position_embeddings,
            'rope_theta': rope_theta,
            'rope_parameters': {'rope_type': 'default', 'rope_theta': rope_theta},
            'rms_norm_eps': rms_norm_eps,
            'attention_bias': attention_bias,
            'attention_dropout': attention_dropout,
            'hidden_act': 'silu',
            'intermediate_size': intermediate_size,
            'layer_types': self.layer_types,
            'sliding_window': sliding_window,
            '_attn_implementation': self._attn_implementation,
        })()
        self.rotary_emb = Qwen3RotaryEmbedding(rope_config)
        self.gradient_checkpointing = False
        # Learned token; presumably prepended/used as a summary slot — not
        # referenced in this forward, TODO confirm usage at call sites.
        self.special_token = nn.Parameter(torch.randn(1, 1, hidden_size))
        self.layers = nn.ModuleList([
            AceStepEncoderLayer(
                hidden_size=hidden_size,
                intermediate_size=intermediate_size,
                num_attention_heads=num_attention_heads,
                num_key_value_heads=num_key_value_heads,
                rms_norm_eps=rms_norm_eps,
                attention_bias=attention_bias,
                attention_dropout=attention_dropout,
                layer_types=self.layer_types,
                head_dim=head_dim,
                sliding_window=sliding_window,
                layer_idx=layer_idx,
            )
            for layer_idx in range(num_timbre_encoder_hidden_layers)
        ])


    def unpack_timbre_embeddings(self, timbre_embs_packed, refer_audio_order_mask):
        """Scatter packed [N, d] rows into padded [B, max_count, d] + mask.

        `refer_audio_order_mask[i]` is the batch index of packed row i; rows
        keep their original relative order within each batch element. Built
        from a one-hot matmul so it stays differentiable w.r.t. the features.
        """
        N, d = timbre_embs_packed.shape
        device = timbre_embs_packed.device
        dtype = timbre_embs_packed.dtype
        B = int(refer_audio_order_mask.max().item() + 1)
        counts = torch.bincount(refer_audio_order_mask, minlength=B)
        max_count = counts.max().item()
        # Stable sort key groups rows by batch id while preserving intra-batch order.
        sorted_indices = torch.argsort(refer_audio_order_mask * N + torch.arange(N, device=device), stable=True)
        sorted_batch_ids = refer_audio_order_mask[sorted_indices]
        positions = torch.arange(N, device=device)
        batch_starts = torch.cat([torch.tensor([0], device=device), torch.cumsum(counts, dim=0)[:-1]])
        positions_in_sorted = positions - batch_starts[sorted_batch_ids]
        inverse_indices = torch.empty_like(sorted_indices)
        inverse_indices[sorted_indices] = torch.arange(N, device=device)
        positions_in_batch = positions_in_sorted[inverse_indices]
        # Flat destination slot per packed row: batch_id * max_count + position.
        indices_2d = refer_audio_order_mask * max_count + positions_in_batch
        one_hot = F.one_hot(indices_2d, num_classes=B * max_count).to(dtype)
        timbre_embs_flat = one_hot.t() @ timbre_embs_packed
        timbre_embs_unpack = timbre_embs_flat.reshape(B, max_count, d)
        mask_flat = (one_hot.sum(dim=0) > 0).long()
        new_mask = mask_flat.reshape(B, max_count)
        return timbre_embs_unpack, new_mask

    @can_return_tuple
    def forward(
        self,
        refer_audio_acoustic_hidden_states_packed: Optional[torch.FloatTensor] = None,
        refer_audio_order_mask: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutput:
        inputs_embeds = refer_audio_acoustic_hidden_states_packed
        inputs_embeds = self.embed_tokens(inputs_embeds)
        # Handle 2D (packed) or 3D (batched) input
        is_packed = inputs_embeds.dim() == 2
        if is_packed:
            seq_len = inputs_embeds.shape[0]
            cache_position = torch.arange(0, seq_len, device=inputs_embeds.device)
            position_ids = cache_position.unsqueeze(0)
            inputs_embeds = inputs_embeds.unsqueeze(0)
        else:
            seq_len = inputs_embeds.shape[1]
            cache_position = torch.arange(0, seq_len, device=inputs_embeds.device)
            position_ids = cache_position.unsqueeze(0)

        dtype = inputs_embeds.dtype
        device = inputs_embeds.device

        full_attn_mask = create_4d_mask(
            seq_len=seq_len, dtype=dtype, device=device,
            attention_mask=attention_mask, sliding_window=None,
            is_sliding_window=False, is_causal=False
        )
        sliding_attn_mask = None
        if self.use_sliding_window:
            sliding_attn_mask = create_4d_mask(
                seq_len=seq_len, dtype=dtype, device=device,
                attention_mask=attention_mask, sliding_window=self.sliding_window,
                is_sliding_window=True, is_causal=False
            )

        self_attn_mask_mapping = {
            "full_attention": full_attn_mask,
            "sliding_attention": sliding_attn_mask,
        }

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for layer_module in self.layers[: self.num_timbre_encoder_hidden_layers]:
            layer_outputs = layer_module(
                hidden_states, position_embeddings,
                self_attn_mask_mapping[layer_module.attention_type],
                position_ids,
                **flash_attn_kwargs,
            )
            hidden_states = layer_outputs[0]

        hidden_states = self.norm(hidden_states)
        # For packed input: reshape [1, T, D] -> [T, D] for unpacking
        if is_packed:
            hidden_states = hidden_states.squeeze(0)
        # NOTE(review): returns a plain tuple, not BaseModelOutput, despite the
        # annotation and @can_return_tuple — callers unpack two values.
        timbre_embs_unpack, timbre_embs_mask = self.unpack_timbre_embeddings(hidden_states, refer_audio_order_mask)
        return timbre_embs_unpack, timbre_embs_mask


class AceStepConditionEncoder(nn.Module):
    """Top-level ACE-Step conditioner.

    Runs the lyric and timbre encoders, projects text embeddings, and packs
    all three streams (lyrics → timbre → text order) into one left-compacted
    encoder sequence + attention mask for the DiT.
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        intermediate_size: int = 6144,
        num_hidden_layers: int = 24,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 8,
        rms_norm_eps: float = 1e-6,
        attention_bias: bool = False,
        attention_dropout: float = 0.0,
        layer_types: Optional[list] = None,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = 128,
        use_sliding_window: bool = True,
        use_cache: bool = True,
        rope_theta: float = 1000000,
        max_position_embeddings: int = 32768,
        initializer_range: float = 0.02,
        text_hidden_dim: int = 1024,
        timbre_hidden_dim: int = 64,
        num_lyric_encoder_hidden_layers: int = 8,
        num_timbre_encoder_hidden_layers: int = 4,
        **kwargs,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.rms_norm_eps = rms_norm_eps
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.layer_types = layer_types or (["sliding_attention", "full_attention"] * (num_hidden_layers // 2))
        self.head_dim = head_dim or hidden_size // num_attention_heads
        self.sliding_window = sliding_window
        self.use_sliding_window = use_sliding_window
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.text_hidden_dim = text_hidden_dim
        self.timbre_hidden_dim = timbre_hidden_dim
        self.num_lyric_encoder_hidden_layers = num_lyric_encoder_hidden_layers
        self.num_timbre_encoder_hidden_layers = num_timbre_encoder_hidden_layers
        self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")

        self.text_projector = nn.Linear(text_hidden_dim, hidden_size, bias=False)
        # Learned unconditional embedding; not used in this forward —
        # presumably consumed by CFG logic elsewhere, TODO confirm.
        self.null_condition_emb = nn.Parameter(torch.randn(1, 1, hidden_size))
        self.lyric_encoder = AceStepLyricEncoder(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            rms_norm_eps=rms_norm_eps,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            layer_types=layer_types,
            head_dim=head_dim,
            sliding_window=sliding_window,
            use_sliding_window=use_sliding_window,
            rope_theta=rope_theta,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            text_hidden_dim=text_hidden_dim,
            num_lyric_encoder_hidden_layers=num_lyric_encoder_hidden_layers,
        )
        self.timbre_encoder = AceStepTimbreEncoder(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            rms_norm_eps=rms_norm_eps,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            layer_types=layer_types,
            head_dim=head_dim,
            sliding_window=sliding_window,
            use_sliding_window=use_sliding_window,
            rope_theta=rope_theta,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            timbre_hidden_dim=timbre_hidden_dim,
            num_timbre_encoder_hidden_layers=num_timbre_encoder_hidden_layers,
        )

    def forward(
        self,
        text_hidden_states: Optional[torch.FloatTensor] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
        lyric_hidden_states: Optional[torch.LongTensor] = None,
        lyric_attention_mask: Optional[torch.Tensor] = None,
        refer_audio_acoustic_hidden_states_packed: Optional[torch.Tensor] = None,
        refer_audio_order_mask: Optional[torch.LongTensor] = None,
    ):
        """Fuse text / lyric / timbre conditions.

        Returns (encoder_hidden_states [B, L_total, hidden_size],
        encoder_attention_mask [B, L_total]) with valid tokens left-compacted
        in the order lyrics, timbre, text.
        """
        text_hidden_states = self.text_projector(text_hidden_states)
        lyric_encoder_outputs = self.lyric_encoder(
            inputs_embeds=lyric_hidden_states,
            attention_mask=lyric_attention_mask,
        )
        lyric_hidden_states = lyric_encoder_outputs.last_hidden_state
        timbre_embs_unpack, timbre_embs_mask = self.timbre_encoder(
            refer_audio_acoustic_hidden_states_packed,
            refer_audio_order_mask
        )

        # Pack lyrics + timbre first, then append text, left-compacting each time.
        encoder_hidden_states, encoder_attention_mask = pack_sequences(
            lyric_hidden_states, timbre_embs_unpack, lyric_attention_mask, timbre_embs_mask
        )
        encoder_hidden_states, encoder_attention_mask = pack_sequences(
            encoder_hidden_states, text_hidden_states, encoder_attention_mask, text_attention_mask
        )
        return encoder_hidden_states, encoder_attention_mask
def create_4d_mask(
    seq_len: int,
    dtype: torch.dtype,
    device: torch.device,
    attention_mask: Optional[torch.Tensor] = None,  # [Batch, Seq_Len]
    sliding_window: Optional[int] = None,
    is_sliding_window: bool = False,
    is_causal: bool = True,
) -> torch.Tensor:
    """Build an additive 4D attention mask (0.0 = keep, dtype-min = drop).

    Works in eager/SDPA mode on any device and covers four geometries:
    causal full (GPT-style), causal sliding window (Mistral/Qwen-style),
    bidirectional full (BERT-style), and bidirectional sliding window
    (Longformer-style local attention).

    Args:
        seq_len: Sequence length L.
        dtype: Output dtype; masked entries are set to its most negative value.
        device: Device on which to build the mask.
        attention_mask: Optional [B, L] key-padding mask (1 = valid token).
        sliding_window: Window size; only used when ``is_sliding_window``.
        is_sliding_window: Restrict attention to a local window.
        is_causal: Restrict attention to past (and current) positions.

    Returns:
        [B, 1, L, L] additive mask (batch dim 1 when no padding mask given).
    """
    positions = torch.arange(seq_len, device=device)
    # rel[i, j] = i - j: positive when key j lies in query i's past.
    rel = positions.unsqueeze(1) - positions.unsqueeze(0)

    # Start fully visible, then carve out forbidden positions.
    keep = torch.ones((seq_len, seq_len), device=device, dtype=torch.bool)
    if is_causal:
        # Queries may only look at themselves and earlier keys.
        keep = keep & (rel >= 0)
    if is_sliding_window and sliding_window is not None:
        if is_causal:
            # Past-only window: 0 <= i - j <= window (lower bound set above).
            keep = keep & (rel <= sliding_window)
        else:
            # Symmetric window around each query: |i - j| <= window.
            keep = keep & (rel.abs() <= sliding_window)

    # [L, L] -> [1, 1, L, L] so the padding mask can broadcast over batch/heads.
    keep = keep.unsqueeze(0).unsqueeze(0)

    if attention_mask is not None:
        # [B, L] -> [B, 1, 1, L]: drop padded keys for every query row.
        key_valid = attention_mask.view(attention_mask.shape[0], 1, 1, seq_len).to(torch.bool)
        keep = keep & key_valid

    # Additive form: 0 where visible, the dtype's minimum where masked.
    additive = torch.full(keep.shape, torch.finfo(dtype).min, dtype=dtype, device=device)
    additive.masked_fill_(keep, 0.0)
    return additive
Convert to additive mask + # ------------------------------------------------------ + # Get the minimal value for current dtype + min_dtype = torch.finfo(dtype).min + + # Create result tensor filled with -inf by default + mask_tensor = torch.full(valid_mask.shape, min_dtype, dtype=dtype, device=device) + + # Set valid positions to 0.0 + mask_tensor.masked_fill_(valid_mask, 0.0) + + return mask_tensor + + +def pack_sequences(hidden1: torch.Tensor, hidden2: torch.Tensor, mask1: torch.Tensor, mask2: torch.Tensor): + """ + Pack two sequences by concatenating and sorting them based on mask values. + + Args: + hidden1: First hidden states tensor of shape [B, L1, D] + hidden2: Second hidden states tensor of shape [B, L2, D] + mask1: First mask tensor of shape [B, L1] + mask2: Second mask tensor of shape [B, L2] + + Returns: + Tuple of (packed_hidden_states, new_mask) where: + - packed_hidden_states: Packed hidden states with valid tokens (mask=1) first, shape [B, L1+L2, D] + - new_mask: New mask tensor indicating valid positions, shape [B, L1+L2] + """ + # Step 1: Concatenate hidden states and masks along sequence dimension + hidden_cat = torch.cat([hidden1, hidden2], dim=1) # [B, L, D] + mask_cat = torch.cat([mask1, mask2], dim=1) # [B, L] + + B, L, D = hidden_cat.shape + + # Step 2: Sort indices so that mask values of 1 come before 0 + sort_idx = mask_cat.argsort(dim=1, descending=True, stable=True) # [B, L] + + # Step 3: Reorder hidden states using sorted indices + hidden_left = torch.gather(hidden_cat, 1, sort_idx.unsqueeze(-1).expand(B, L, D)) + + # Step 4: Create new mask based on valid sequence lengths + lengths = mask_cat.sum(dim=1) # [B] + new_mask = (torch.arange(L, dtype=torch.long, device=hidden_cat.device).unsqueeze(0) < lengths.unsqueeze(1)) + + return hidden_left, new_mask + + +class TimestepEmbedding(nn.Module): + """ + Timestep embedding module for diffusion models. 
+ + Converts timestep values into high-dimensional embeddings using sinusoidal + positional encoding, followed by MLP layers. Used for conditioning diffusion + models on timestep information. + """ + def __init__( + self, + in_channels: int, + time_embed_dim: int, + scale: float = 1000, + ): + super().__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=True) + self.act1 = nn.SiLU() + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, bias=True) + self.in_channels = in_channels + + self.act2 = nn.SiLU() + self.time_proj = nn.Linear(time_embed_dim, time_embed_dim * 6) + self.scale = scale + + def timestep_embedding(self, t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + Args: + t: A 1-D tensor of N indices, one per batch element. These may be fractional. + dim: The dimension of the output embeddings. + max_period: Controls the minimum frequency of the embeddings. + + Returns: + An (N, D) tensor of positional embeddings. + """ + t = t * self.scale + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.in_channels) + temb = self.linear_1(t_freq.to(t.dtype)) + temb = self.act1(temb) + temb = self.linear_2(temb) + timestep_proj = self.time_proj(self.act2(temb)).unflatten(1, (6, -1)) + return temb, timestep_proj + + +class AceStepAttention(nn.Module): + """ + Multi-headed attention module for AceStep model. + + Implements the attention mechanism from 'Attention Is All You Need' paper, + with support for both self-attention and cross-attention modes. 
Uses RMSNorm + for query and key normalization, and supports sliding window attention for + efficient long-sequence processing. + """ + + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + num_key_value_heads: int, + rms_norm_eps: float, + attention_bias: bool, + attention_dropout: float, + layer_types: list, + head_dim: Optional[int] = None, + sliding_window: Optional[int] = None, + layer_idx: int = 0, + is_cross_attention: bool = False, + is_causal: bool = False, + ): + super().__init__() + self.layer_idx = layer_idx + self.head_dim = head_dim or hidden_size // num_attention_heads + self.num_key_value_groups = num_attention_heads // num_key_value_heads + self.scaling = self.head_dim ** -0.5 + self.attention_dropout = attention_dropout + if is_cross_attention: + is_causal = False + self.is_causal = is_causal + self.is_cross_attention = is_cross_attention + + self.q_proj = nn.Linear(hidden_size, num_attention_heads * self.head_dim, bias=attention_bias) + self.k_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias) + self.v_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias) + self.o_proj = nn.Linear(num_attention_heads * self.head_dim, hidden_size, bias=attention_bias) + self.q_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps) + self.k_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps) + self.attention_type = layer_types[layer_idx] + self.sliding_window = sliding_window if layer_types[layer_idx] == "sliding_attention" else None + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] = None, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], 
Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + # Project and normalize query states + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + + # Determine if this is cross-attention (requires encoder_hidden_states) + is_cross_attention = self.is_cross_attention and encoder_hidden_states is not None + + # Cross-attention path: attend to encoder hidden states + if is_cross_attention: + encoder_hidden_shape = (*encoder_hidden_states.shape[:-1], -1, self.head_dim) + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + # After the first generated token, we can reuse all key/value states from cache + curr_past_key_value = past_key_value.cross_attention_cache + + # Conditions for calculating key and value states + if not is_updated: + # Compute and cache K/V for the first time + key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2) + value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2) + # Update cache: save all key/value states to cache for fast auto-regressive generation + key_states, value_states = curr_past_key_value.update(key_states, value_states, self.layer_idx) + # Set flag that this layer's cross-attention cache is updated + past_key_value.is_updated[self.layer_idx] = True + else: + # Reuse cached key/value states for subsequent tokens + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values + else: + # No cache used, compute K/V directly + key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2) + value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2) + + # Self-attention path: attend to the same sequence + else: + # Project and normalize key/value states for 
class AceStepEncoderLayer(nn.Module):
    """
    Encoder layer for the AceStep model.

    Pre-norm transformer encoder block: bidirectional self-attention followed
    by an MLP, each wrapped in a residual connection. Used by the condition
    encoders; never uses a KV cache or cross-attention.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        intermediate_size: int = 6144,
        rms_norm_eps: float = 1e-6,
        attention_bias: bool = False,
        attention_dropout: float = 0.0,
        layer_types: list = None,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = None,
        layer_idx: int = 0,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.layer_idx = layer_idx

        self.self_attn = AceStepAttention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            rms_norm_eps=rms_norm_eps,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            layer_types=layer_types,
            head_dim=head_dim,
            sliding_window=sliding_window,
            layer_idx=layer_idx,
            is_cross_attention=False,
            is_causal=False,
        )
        self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
        self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)

        # MLP (feed-forward) sub-layer. A throwaway namespace object carries
        # just the config fields passed here.
        self.mlp = Qwen3MLP(
            config=type('Config', (), {
                'hidden_size': hidden_size,
                'intermediate_size': intermediate_size,
                'hidden_act': 'silu',
            })()
        )
        # "sliding_attention" or "full_attention"; the model selects the
        # matching 4D mask per layer based on this attribute.
        self.attention_type = layer_types[layer_idx]

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> tuple[
        torch.FloatTensor,
        Optional[tuple[torch.FloatTensor, torch.FloatTensor]],
    ]:
        """Run one encoder block; returns ``(hidden_states,)`` plus optional weights."""
        # Self-attention with residual connection
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            # Encoders don't use cache
            use_cache=False,
            past_key_value=None,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # MLP with residual connection
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            # NOTE(review): AceStepAttention currently always returns None for
            # the weights, so this entry is None even when requested.
            outputs += (self_attn_weights,)

        return outputs
class AceStepDiTLayer(nn.Module):
    """
    DiT (Diffusion Transformer) layer for the AceStep model.

    Implements a transformer layer with three main components:
    1. Self-attention with adaptive layer norm (AdaLN)
    2. Cross-attention (optional) for conditioning on encoder outputs
    3. Feed-forward MLP with adaptive layer norm

    The six AdaLN scale/shift/gate vectors come from the timestep embedding
    added to a learned per-layer ``scale_shift_table``.
    """
    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        intermediate_size: int,
        rms_norm_eps: float,
        attention_bias: bool,
        attention_dropout: float,
        layer_types: list,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = None,
        layer_idx: int = 0,
        use_cross_attention: bool = True,
    ):
        super().__init__()

        self.self_attn_norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
        self.self_attn = AceStepAttention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            rms_norm_eps=rms_norm_eps,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            layer_types=layer_types,
            head_dim=head_dim,
            sliding_window=sliding_window,
            layer_idx=layer_idx,
        )

        self.use_cross_attention = use_cross_attention
        if self.use_cross_attention:
            self.cross_attn_norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
            self.cross_attn = AceStepAttention(
                hidden_size=hidden_size,
                num_attention_heads=num_attention_heads,
                num_key_value_heads=num_key_value_heads,
                rms_norm_eps=rms_norm_eps,
                attention_bias=attention_bias,
                attention_dropout=attention_dropout,
                layer_types=layer_types,
                head_dim=head_dim,
                sliding_window=sliding_window,
                layer_idx=layer_idx,
                is_cross_attention=True,
            )

        self.mlp_norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
        self.mlp = Qwen3MLP(
            config=type('Config', (), {
                'hidden_size': hidden_size,
                'intermediate_size': intermediate_size,
                'hidden_act': 'silu',
            })()
        )

        # Learned base AdaLN parameters, added to the 6-way timestep projection.
        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, hidden_size) / hidden_size**0.5)
        self.attention_type = layer_types[layer_idx]

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        temb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[EncoderDecoderCache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """One DiT block. ``temb`` is the [B, 6, D] timestep projection."""

        # Extract scale-shift parameters for adaptive layer norm from timestep embeddings
        # 6 values: (shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa)
        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
            self.scale_shift_table + temb
        ).chunk(6, dim=1)

        # Step 1: Self-attention with adaptive layer norm (AdaLN)
        # Apply adaptive normalization: norm(x) * (1 + scale) + shift
        norm_hidden_states = (self.self_attn_norm(hidden_states) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
        attn_output, self_attn_weights = self.self_attn(
            hidden_states=norm_hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            use_cache=False,
            past_key_value=None,
            **kwargs,
        )
        # Apply gated residual connection: x = x + attn_output * gate
        hidden_states = (hidden_states + attn_output * gate_msa).type_as(hidden_states)

        # Step 2: Cross-attention (if enabled) for conditioning on encoder outputs
        if self.use_cross_attention:
            norm_hidden_states = self.cross_attn_norm(hidden_states).type_as(hidden_states)
            attn_output, cross_attn_weights = self.cross_attn(
                hidden_states=norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                **kwargs,
            )
            # Standard residual connection for cross-attention
            hidden_states = hidden_states + attn_output

        # Step 3: Feed-forward (MLP) with adaptive layer norm
        # Apply adaptive normalization for MLP: norm(x) * (1 + scale) + shift
        norm_hidden_states = (self.mlp_norm(hidden_states) * (1 + c_scale_msa) + c_shift_msa).type_as(hidden_states)
        ff_output = self.mlp(norm_hidden_states)
        # Apply gated residual connection: x = x + mlp_output * gate
        hidden_states = (hidden_states + ff_output * c_gate_msa).type_as(hidden_states)

        outputs = (hidden_states,)
        if output_attentions:
            # NOTE(review): if output_attentions is True while
            # use_cross_attention is False, `cross_attn_weights` is unbound
            # here (NameError) — confirm callers never hit that combination.
            outputs += (self_attn_weights, cross_attn_weights)

        return outputs
connection: x = x + attn_output * gate + hidden_states = (hidden_states + attn_output * gate_msa).type_as(hidden_states) + + # Step 2: Cross-attention (if enabled) for conditioning on encoder outputs + if self.use_cross_attention: + norm_hidden_states = self.cross_attn_norm(hidden_states).type_as(hidden_states) + attn_output, cross_attn_weights = self.cross_attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + **kwargs, + ) + # Standard residual connection for cross-attention + hidden_states = hidden_states + attn_output + + # Step 3: Feed-forward (MLP) with adaptive layer norm + # Apply adaptive normalization for MLP: norm(x) * (1 + scale) + shift + norm_hidden_states = (self.mlp_norm(hidden_states) * (1 + c_scale_msa) + c_shift_msa).type_as(hidden_states) + ff_output = self.mlp(norm_hidden_states) + # Apply gated residual connection: x = x + mlp_output * gate + hidden_states = (hidden_states + ff_output * c_gate_msa).type_as(hidden_states) + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + + +class Lambda(nn.Module): + """ + Wrapper module for arbitrary lambda functions. + + Allows using lambda functions in nn.Sequential by wrapping them in a Module. + Useful for simple transformations like transpose operations. + """ + def __init__(self, func): + super().__init__() + self.func = func + + def forward(self, x): + return self.func(x) + + +class AceStepDiTModel(nn.Module): + """ + DiT (Diffusion Transformer) model for AceStep. + + Main diffusion model that generates audio latents conditioned on text, lyrics, + and timbre. Uses patch-based processing with transformer layers, timestep + conditioning, and cross-attention to encoder outputs. 
class AceStepDiTModel(nn.Module):
    """
    DiT (Diffusion Transformer) model for AceStep.

    Main diffusion model that generates audio latents conditioned on text,
    lyrics, and timbre. Patchifies the input with a strided Conv1d, runs a
    stack of AdaLN/cross-attention transformer layers conditioned on two
    timestep embeddings, then de-patchifies with a ConvTranspose1d.
    """
    def __init__(
        self,
        hidden_size: int = 2048,
        intermediate_size: int = 6144,
        num_hidden_layers: int = 24,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 8,
        rms_norm_eps: float = 1e-6,
        attention_bias: bool = False,
        attention_dropout: float = 0.0,
        layer_types: Optional[list] = None,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = 128,
        use_sliding_window: bool = True,
        use_cache: bool = True,
        rope_theta: float = 1000000,
        max_position_embeddings: int = 32768,
        initializer_range: float = 0.02,
        patch_size: int = 2,
        in_channels: int = 192,
        audio_acoustic_hidden_dim: int = 64,
        encoder_hidden_size: Optional[int] = None,
        **kwargs,
    ):
        super().__init__()

        # Default layout alternates sliding-window and full attention layers.
        self.layer_types = layer_types or (["sliding_attention", "full_attention"] * (num_hidden_layers // 2))
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window
        self.use_cache = use_cache
        encoder_hidden_size = encoder_hidden_size or hidden_size

        # Rotary position embeddings for transformer layers; a throwaway
        # namespace object stands in for a transformers config.
        rope_config = type('RopeConfig', (), {
            'hidden_size': hidden_size,
            'num_attention_heads': num_attention_heads,
            'num_key_value_heads': num_key_value_heads,
            'head_dim': head_dim,
            'max_position_embeddings': max_position_embeddings,
            'rope_theta': rope_theta,
            'rope_parameters': {'rope_type': 'default', 'rope_theta': rope_theta},
            'rms_norm_eps': rms_norm_eps,
            'attention_bias': attention_bias,
            'attention_dropout': attention_dropout,
            'hidden_act': 'silu',
            'intermediate_size': intermediate_size,
            'layer_types': self.layer_types,
            'sliding_window': sliding_window,
        })()
        self.rotary_emb = Qwen3RotaryEmbedding(rope_config)

        # Stack of DiT transformer layers
        self.layers = nn.ModuleList([
            AceStepDiTLayer(
                hidden_size=hidden_size,
                num_attention_heads=num_attention_heads,
                num_key_value_heads=num_key_value_heads,
                intermediate_size=intermediate_size,
                rms_norm_eps=rms_norm_eps,
                attention_bias=attention_bias,
                attention_dropout=attention_dropout,
                layer_types=self.layer_types,
                head_dim=head_dim,
                sliding_window=sliding_window,
                layer_idx=layer_idx,
            )
            for layer_idx in range(num_hidden_layers)
        ])

        self.patch_size = patch_size

        # Input projection: patch embedding using 1D convolution
        self.proj_in = nn.Sequential(
            Lambda(lambda x: x.transpose(1, 2)),
            nn.Conv1d(
                in_channels=in_channels,
                out_channels=hidden_size,
                kernel_size=patch_size,
                stride=patch_size,
                padding=0,
            ),
            Lambda(lambda x: x.transpose(1, 2)),
        )

        # Timestep embeddings for diffusion conditioning (t and t - r)
        self.time_embed = TimestepEmbedding(in_channels=256, time_embed_dim=hidden_size)
        self.time_embed_r = TimestepEmbedding(in_channels=256, time_embed_dim=hidden_size)

        # Project encoder hidden states to model dimension
        self.condition_embedder = nn.Linear(encoder_hidden_size, hidden_size, bias=True)

        # Output normalization and projection
        self.norm_out = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
        self.proj_out = nn.Sequential(
            Lambda(lambda x: x.transpose(1, 2)),
            nn.ConvTranspose1d(
                in_channels=hidden_size,
                out_channels=audio_acoustic_hidden_dim,
                kernel_size=patch_size,
                stride=patch_size,
                padding=0,
            ),
            Lambda(lambda x: x.transpose(1, 2)),
        )
        # Base AdaLN shift/scale for the output norm.
        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, hidden_size) / hidden_size**0.5)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        timestep_r: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        context_latents: torch.Tensor,
        use_cache: Optional[bool] = None,
        past_key_values: Optional[EncoderDecoderCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = False,
        return_hidden_states: Optional[bool] = None,
        custom_layers_config: Optional[dict] = None,
        enable_early_exit: bool = False,
        use_gradient_checkpointing: bool = False,
        use_gradient_checkpointing_offload: bool = False,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ):
        """Denoise ``hidden_states`` conditioned on the packed encoder sequence.

        Returns ``(hidden_states, past_key_values[, cross_attentions])``, or
        the raw transformer hidden states when ``return_hidden_states`` is set.
        """

        use_cache = use_cache if use_cache is not None else self.use_cache

        # Disable cache during training or when gradient checkpointing is enabled
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False
        if self.training:
            use_cache = False

        # Initialize cache if needed (only during inference for auto-regressive generation)
        if not self.training and use_cache and past_key_values is None:
            past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())

        # Compute timestep embeddings for diffusion conditioning
        # Two embeddings: one for timestep t, one for timestep difference (t - r)
        temb_t, timestep_proj_t = self.time_embed(timestep)
        temb_r, timestep_proj_r = self.time_embed_r(timestep - timestep_r)
        # Combine embeddings
        temb = temb_t + temb_r
        timestep_proj = timestep_proj_t + timestep_proj_r

        # Concatenate context latents (source latents + chunk masks) with hidden states
        hidden_states = torch.cat([context_latents, hidden_states], dim=-1)
        # Record original sequence length for later restoration after padding
        original_seq_len = hidden_states.shape[1]
        # Apply padding if sequence length is not divisible by patch_size
        # This ensures proper patch extraction
        pad_length = 0
        if hidden_states.shape[1] % self.patch_size != 0:
            pad_length = self.patch_size - (hidden_states.shape[1] % self.patch_size)
            hidden_states = F.pad(hidden_states, (0, 0, 0, pad_length), mode='constant', value=0)

        # Project input to patches and project encoder states
        hidden_states = self.proj_in(hidden_states)
        encoder_hidden_states = self.condition_embedder(encoder_hidden_states)

        # Cache positions
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
            )

        # Position IDs
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        seq_len = hidden_states.shape[1]
        encoder_seq_len = encoder_hidden_states.shape[1]
        dtype = hidden_states.dtype
        device = hidden_states.device

        # Initialize Mask variables
        full_attn_mask = None
        sliding_attn_mask = None
        encoder_attn_mask = None
        decoder_attn_mask = None
        # Target library discards the passed-in attention_mask for 4D mask
        # construction (line 1384: attention_mask = None)
        attention_mask = None

        # 1. Full Attention (Bidirectional, Global)
        full_attn_mask = create_4d_mask(
            seq_len=seq_len,
            dtype=dtype,
            device=device,
            attention_mask=attention_mask,
            sliding_window=None,
            is_sliding_window=False,
            is_causal=False
        )
        max_len = max(seq_len, encoder_seq_len)

        # Cross-attention mask: build a square mask over the longer of the two
        # lengths, then crop to [queries = seq_len, keys = encoder_seq_len].
        encoder_attn_mask = create_4d_mask(
            seq_len=max_len,
            dtype=dtype,
            device=device,
            attention_mask=attention_mask,
            sliding_window=None,
            is_sliding_window=False,
            is_causal=False
        )
        encoder_attn_mask = encoder_attn_mask[:, :, :seq_len, :encoder_seq_len]

        # 2. Sliding Attention (Bidirectional, Local)
        if self.use_sliding_window:
            sliding_attn_mask = create_4d_mask(
                seq_len=seq_len,
                dtype=dtype,
                device=device,
                attention_mask=attention_mask,
                sliding_window=self.sliding_window,
                is_sliding_window=True,
                is_causal=False
            )

        # Build mask mapping: each layer picks by its attention_type.
        self_attn_mask_mapping = {
            "full_attention": full_attn_mask,
            "sliding_attention": sliding_attn_mask,
            "encoder_attention_mask": encoder_attn_mask,
        }

        # Create position embeddings to be shared across all decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)
        all_cross_attentions = () if output_attentions else None

        # Handle early exit for custom layer configurations
        max_needed_layer = float('inf')
        if custom_layers_config is not None and enable_early_exit:
            max_needed_layer = max(custom_layers_config.keys())
            output_attentions = True
            if all_cross_attentions is None:
                all_cross_attentions = ()

        # Process through transformer layers
        for index_block, layer_module in enumerate(self.layers):
            # Early exit optimization
            if index_block > max_needed_layer:
                break

            # Prepare layer arguments (positional order must match
            # AceStepDiTLayer.forward's signature).
            layer_args = (
                hidden_states,
                position_embeddings,
                timestep_proj,
                self_attn_mask_mapping[layer_module.attention_type],
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
                cache_position,
                encoder_hidden_states,
                self_attn_mask_mapping["encoder_attention_mask"],
            )
            layer_kwargs = flash_attn_kwargs

            # Use gradient checkpointing if enabled
            if use_gradient_checkpointing or use_gradient_checkpointing_offload:
                layer_outputs = gradient_checkpoint_forward(
                    layer_module,
                    use_gradient_checkpointing,
                    use_gradient_checkpointing_offload,
                    *layer_args,
                    **layer_kwargs,
                )
            else:
                layer_outputs = layer_module(
                    *layer_args,
                    **layer_kwargs,
                )

            hidden_states = layer_outputs[0]

            if output_attentions and self.layers[index_block].use_cross_attention:
                # layer_outputs structure: (hidden_states, self_attn_weights, cross_attn_weights)
                if len(layer_outputs) >= 3:
                    all_cross_attentions += (layer_outputs[2],)

        if return_hidden_states:
            return hidden_states

        # Extract scale-shift parameters for adaptive output normalization
        shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
        shift = shift.to(hidden_states.device)
        scale = scale.to(hidden_states.device)

        # Apply adaptive layer norm: norm(x) * (1 + scale) + shift
        hidden_states = (self.norm_out(hidden_states) * (1 + scale) + shift).type_as(hidden_states)
        # Project output: de-patchify back to original sequence format
        hidden_states = self.proj_out(hidden_states)

        # Crop back to original sequence length to ensure exact length match (remove padding)
        hidden_states = hidden_states[:, :original_seq_len, :]

        outputs = (hidden_states, past_key_values)

        if output_attentions:
            outputs += (all_cross_attentions,)
        return outputs
class AceStepLM(torch.nn.Module):
    """Prompt-to-parameters language model for ACE-Step.

    Turns a natural-language prompt into structured generation parameters
    (caption, lyrics, bpm, keyscale, duration, timesignature, ...) for
    ACE-Step music generation.

    Wraps a transformers ``Qwen3ForCausalLM``; architecture hyper-parameters
    are chosen from ``LM_CONFIGS`` by ``variant``, and weights are loaded
    afterwards through DiffSynth's standard safetensors mechanism.
    """

    def __init__(self, variant: str = "acestep-5Hz-lm-1.7B"):
        super().__init__()
        from transformers import Qwen3Config, Qwen3ForCausalLM

        # Variant-specific sizes (hidden size, depth, head count, ...).
        variant_params = LM_CONFIGS[variant]

        # Settings shared by every LM variant; variant_params extends them.
        # NOTE(review): `dtype="bfloat16"` as a config kwarg requires a
        # transformers version that accepts `dtype` (older ones used
        # `torch_dtype`) — confirm against the pinned dependency.
        config = Qwen3Config(
            attention_bias=False,
            attention_dropout=0.0,
            bos_token_id=151643,
            dtype="bfloat16",
            eos_token_id=151645,
            head_dim=128,
            hidden_act="silu",
            initializer_range=0.02,
            max_position_embeddings=40960,
            model_type="qwen3",
            num_key_value_heads=8,
            pad_token_id=151643,
            rms_norm_eps=1e-06,
            rope_scaling=None,
            rope_theta=1000000,
            sliding_window=None,
            tie_word_embeddings=True,
            use_cache=True,
            use_sliding_window=False,
            vocab_size=217204,
            **variant_params,
        )

        self.model = Qwen3ForCausalLM(config)
        self.config = config


class AceStepTextEncoder(torch.nn.Module):
    """Text encoder for ACE-Step using Qwen3-Embedding-0.6B.

    Converts text/lyric tokens to hidden-state embeddings that are further
    processed by the ACE-Step ConditionEncoder.

    Wraps a transformers ``Qwen3Model`` built from a hand-written config;
    weights are loaded via DiffSynth's standard safetensors mechanism.
    """

    def __init__(self):
        super().__init__()
        from transformers import Qwen3Config, Qwen3Model

        # Fixed Qwen3-Embedding-0.6B configuration.  Note eos_token_id here
        # is 151643 (unlike the LM wrapper, which uses 151645).
        config = Qwen3Config(
            attention_bias=False,
            attention_dropout=0.0,
            bos_token_id=151643,
            dtype="bfloat16",
            eos_token_id=151643,
            head_dim=128,
            hidden_act="silu",
            hidden_size=1024,
            initializer_range=0.02,
            intermediate_size=3072,
            layer_types=["full_attention"] * 28,
            max_position_embeddings=32768,
            max_window_layers=28,
            model_type="qwen3",
            num_attention_heads=16,
            num_hidden_layers=28,
            num_key_value_heads=8,
            pad_token_id=151643,
            rms_norm_eps=1e-06,
            rope_scaling=None,
            rope_theta=1000000,
            sliding_window=None,
            tie_word_embeddings=True,
            use_cache=True,
            use_sliding_window=False,
            vocab_size=151669,
        )

        self.model = Qwen3Model(config)
        self.config = config
        self.hidden_size = config.hidden_size

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.Tensor,
    ):
        """Encode text/lyric tokens to hidden states.

        Args:
            input_ids: [B, T] token IDs.
            attention_mask: [B, T] attention mask.

        Returns:
            last_hidden_state: [B, T, hidden_size].
        """
        encoded = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        return encoded.last_hidden_state

    def to(self, *args, **kwargs):
        # Delegate device/dtype moves to the wrapped model, return self for chaining.
        self.model.to(*args, **kwargs)
        return self

# === new file: diffsynth/models/ace_step_tokenizer.py ===
# Copyright 2025 The ACE-Step Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ACE-Step Audio Tokenizer — VAE latent discretization pathway.

Contains:
- AceStepAudioTokenizer: continuous VAE latent -> discrete FSQ tokens
- AudioTokenDetokenizer: discrete tokens -> continuous VAE-latent-shaped features

Only used in cover song mode (is_covers=True). Bypassed in text-to-music.
"""
from typing import Optional

import torch
import torch.nn as nn
from einops import rearrange

from ..core.attention import attention_forward
from ..core.gradient import gradient_checkpoint_forward

from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutput
from transformers.processing_utils import Unpack
from transformers.utils import can_return_tuple, logging
from transformers.models.qwen3.modeling_qwen3 import (
    Qwen3MLP,
    Qwen3RMSNorm,
    Qwen3RotaryEmbedding,
    apply_rotary_pos_emb,
)
from vector_quantize_pytorch import ResidualFSQ

logger = logging.get_logger(__name__)


def create_4d_mask(
    seq_len: int,
    dtype: torch.dtype,
    device: torch.device,
    attention_mask: Optional[torch.Tensor] = None,
    sliding_window: Optional[int] = None,
    is_sliding_window: bool = False,
    is_causal: bool = True,
) -> torch.Tensor:
    """Build an additive 4D attention mask [1 or B, 1, seq_len, seq_len].

    Allowed positions get 0.0; disallowed positions get the dtype's minimum
    value. Causal masking keeps only positions at or before the query; a
    sliding window keeps |i - j| <= sliding_window (only j <= i when causal).
    An optional [B, seq_len] padding mask is ANDed in.
    """
    indices = torch.arange(seq_len, device=device)
    diff = indices.unsqueeze(1) - indices.unsqueeze(0)
    valid_mask = torch.ones((seq_len, seq_len), device=device, dtype=torch.bool)
    if is_causal:
        valid_mask = valid_mask & (diff >= 0)
    if is_sliding_window and sliding_window is not None:
        if is_causal:
            valid_mask = valid_mask & (diff <= sliding_window)
        else:
            valid_mask = valid_mask & (torch.abs(diff) <= sliding_window)
    valid_mask = valid_mask.unsqueeze(0).unsqueeze(0)
    if attention_mask is not None:
        padding_mask_4d = attention_mask.view(attention_mask.shape[0], 1, 1, seq_len).to(torch.bool)
        valid_mask = valid_mask & padding_mask_4d
    min_dtype = torch.finfo(dtype).min
    mask_tensor = torch.full(valid_mask.shape, min_dtype, dtype=dtype, device=device)
    mask_tensor.masked_fill_(valid_mask, 0.0)
    return mask_tensor


class Lambda(nn.Module):
    """Wrap an arbitrary callable as an nn.Module."""

    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        return self.func(x)


class AceStepAttention(nn.Module):
    """Qwen3-style multi-head attention with QK-RMSNorm and GQA.

    Supports self-attention (with optional RoPE and KV cache) and
    cross-attention over ``encoder_hidden_states`` (with a cross-attention
    cache that is filled once and reused).
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        rms_norm_eps: float,
        attention_bias: bool,
        attention_dropout: float,
        layer_types: list,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = None,
        layer_idx: int = 0,
        is_cross_attention: bool = False,
        is_causal: bool = False,
    ):
        super().__init__()
        self.layer_idx = layer_idx
        self.head_dim = head_dim or hidden_size // num_attention_heads
        self.num_key_value_groups = num_attention_heads // num_key_value_heads
        self.scaling = self.head_dim ** -0.5
        self.attention_dropout = attention_dropout
        if is_cross_attention:
            # Cross-attention is never causal: queries may attend anywhere.
            is_causal = False
        self.is_causal = is_causal
        self.is_cross_attention = is_cross_attention

        self.q_proj = nn.Linear(hidden_size, num_attention_heads * self.head_dim, bias=attention_bias)
        self.k_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
        self.v_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
        self.o_proj = nn.Linear(num_attention_heads * self.head_dim, hidden_size, bias=attention_bias)
        self.q_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.attention_type = layer_types[layer_idx]
        self.sliding_window = sliding_window if layer_types[layer_idx] == "sliding_attention" else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run attention; returns (attn_output, attn_weights).

        NOTE: attn_weights is always None here — the fused attention kernel
        does not expose per-head weights even when output_attentions is True.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)

        is_cross_attention = self.is_cross_attention and encoder_hidden_states is not None

        if is_cross_attention:
            encoder_hidden_shape = (*encoder_hidden_states.shape[:-1], -1, self.head_dim)
            if past_key_value is not None:
                # Cross-attention K/V are computed once per generation and cached.
                is_updated = past_key_value.is_updated.get(self.layer_idx)
                curr_past_key_value = past_key_value.cross_attention_cache
                if not is_updated:
                    key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
                    value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
                    key_states, value_states = curr_past_key_value.update(key_states, value_states, self.layer_idx)
                    past_key_value.is_updated[self.layer_idx] = True
                else:
                    key_states = curr_past_key_value.layers[self.layer_idx].keys
                    value_states = curr_past_key_value.layers[self.layer_idx].values
            else:
                key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
                value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
        else:
            key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
            value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            if position_embeddings is not None:
                cos, sin = position_embeddings
                query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

            if past_key_value is not None:
                # BUGFIX: sin/cos are only bound when position_embeddings was
                # provided; referencing them unconditionally raised NameError
                # for cached no-RoPE calls.
                cache_kwargs = {"cache_position": cache_position}
                if position_embeddings is not None:
                    cache_kwargs["sin"] = sin
                    cache_kwargs["cos"] = cos
                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # GQA: replicate each KV head across its query-head group.
        if self.num_key_value_groups > 1:
            key_states = key_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)
            value_states = value_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)

        attn_output = attention_forward(
            query_states, key_states, value_states,
            q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d",
            attn_mask=attention_mask,
        )
        attn_weights = None

        attn_output = attn_output.transpose(1, 2).flatten(2, 3).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
class AceStepEncoderLayer(nn.Module):
    """Pre-norm transformer encoder layer: self-attention + SwiGLU MLP."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        rms_norm_eps: float,
        attention_bias: bool,
        attention_dropout: float,
        layer_types: list,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = None,
        layer_idx: int = 0,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.layer_idx = layer_idx

        self.self_attn = AceStepAttention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            rms_norm_eps=rms_norm_eps,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            layer_types=layer_types,
            head_dim=head_dim,
            sliding_window=sliding_window,
            layer_idx=layer_idx,
            is_cross_attention=False,
            is_causal=False,
        )
        self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
        self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)

        # Minimal config shim so we can reuse transformers' Qwen3MLP.
        mlp_config = type('Config', (), {
            'hidden_size': hidden_size,
            'intermediate_size': intermediate_size,
            'hidden_act': 'silu',
        })()
        self.mlp = Qwen3MLP(mlp_config)
        self.attention_type = layer_types[layer_idx]

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        # Self-attention sub-block with residual connection.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            use_cache=False,
            past_key_value=None,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Feed-forward sub-block with residual connection.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        return outputs


class AttentionPooler(nn.Module):
    """Pools every pool_window_size frames into 1 representation via transformer + CLS token."""

    def __init__(
        self,
        hidden_size: int = 2048,
        intermediate_size: int = 6144,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 8,
        rms_norm_eps: float = 1e-6,
        attention_bias: bool = False,
        attention_dropout: float = 0.0,
        layer_types: Optional[list] = None,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = 128,
        use_sliding_window: bool = True,
        rope_theta: float = 1000000,
        max_position_embeddings: int = 32768,
        initializer_range: float = 0.02,
        num_attention_pooler_hidden_layers: int = 2,
        **kwargs,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.rms_norm_eps = rms_norm_eps
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        # Default matches target library config (24 alternating entries).
        self.layer_types = layer_types or (["sliding_attention", "full_attention"] * 12)
        self.head_dim = head_dim or hidden_size // num_attention_heads
        self.sliding_window = sliding_window
        self.use_sliding_window = use_sliding_window
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.num_attention_pooler_hidden_layers = num_attention_pooler_hidden_layers
        self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")

        self.embed_tokens = nn.Linear(hidden_size, hidden_size)
        self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
        # Slice layer_types to our own layer count
        pooler_layer_types = self.layer_types[:num_attention_pooler_hidden_layers]
        # Minimal config shim for Qwen3RotaryEmbedding.
        # BUGFIX: pass the resolved self.head_dim (not the raw head_dim
        # argument, which may be None and break the rope dim derivation).
        rope_config = type('RopeConfig', (), {
            'hidden_size': hidden_size,
            'num_attention_heads': num_attention_heads,
            'num_key_value_heads': num_key_value_heads,
            'head_dim': self.head_dim,
            'max_position_embeddings': max_position_embeddings,
            'rope_theta': rope_theta,
            'rope_parameters': {'rope_type': 'default', 'rope_theta': rope_theta},
            'rms_norm_eps': rms_norm_eps,
            'attention_bias': attention_bias,
            'attention_dropout': attention_dropout,
            'hidden_act': 'silu',
            'intermediate_size': intermediate_size,
            'layer_types': pooler_layer_types,
            'sliding_window': sliding_window,
            '_attn_implementation': self._attn_implementation,
        })()
        self.rotary_emb = Qwen3RotaryEmbedding(rope_config)
        self.gradient_checkpointing = False
        # Learnable CLS token prepended to every pooling window.
        self.special_token = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
        self.layers = nn.ModuleList([
            AceStepEncoderLayer(
                hidden_size=hidden_size,
                intermediate_size=intermediate_size,
                num_attention_heads=num_attention_heads,
                num_key_value_heads=num_key_value_heads,
                rms_norm_eps=rms_norm_eps,
                attention_bias=attention_bias,
                attention_dropout=attention_dropout,
                layer_types=pooler_layer_types,
                head_dim=head_dim,
                sliding_window=sliding_window,
                layer_idx=layer_idx,
            )
            for layer_idx in range(num_attention_pooler_hidden_layers)
        ])

    @can_return_tuple
    def forward(
        self,
        x,
        attention_mask: Optional[torch.Tensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.Tensor:
        """Pool [B, T, P, D] windows down to one vector each -> [B, T, D]."""
        B, T, P, D = x.shape
        x = self.embed_tokens(x)
        # Prepend CLS token to each window, then flatten windows into the batch dim.
        special_tokens = self.special_token.expand(B, T, 1, -1)
        x = torch.cat([special_tokens, x], dim=2)
        x = rearrange(x, "b t p c -> (b t) p c")

        cache_position = torch.arange(0, x.shape[1], device=x.device)
        position_ids = cache_position.unsqueeze(0)
        hidden_states = x
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        seq_len = x.shape[1]
        dtype = x.dtype
        device = x.device

        # Bidirectional masks (windows are short; attention is non-causal).
        full_attn_mask = create_4d_mask(
            seq_len=seq_len, dtype=dtype, device=device,
            attention_mask=attention_mask, sliding_window=None,
            is_sliding_window=False, is_causal=False
        )
        sliding_attn_mask = None
        if self.use_sliding_window:
            sliding_attn_mask = create_4d_mask(
                seq_len=seq_len, dtype=dtype, device=device,
                attention_mask=attention_mask, sliding_window=self.sliding_window,
                is_sliding_window=True, is_causal=False
            )

        self_attn_mask_mapping = {
            "full_attention": full_attn_mask,
            "sliding_attention": sliding_attn_mask,
        }

        for layer_module in self.layers:
            layer_outputs = layer_module(
                hidden_states, position_embeddings,
                attention_mask=self_attn_mask_mapping[layer_module.attention_type],
                **flash_attn_kwargs,
            )
            hidden_states = layer_outputs[0]

        hidden_states = self.norm(hidden_states)
        # CLS position (index 0) carries the pooled representation.
        cls_output = hidden_states[:, 0, :]
        return rearrange(cls_output, "(b t) c -> b t c", b=B)
class AceStepAudioTokenizer(nn.Module):
    """Converts continuous acoustic features (VAE latents) into discrete quantized tokens.

    Input: [B, T, 64] (VAE latent dim)
    Output: quantized [B, T/5, 2048], indices [B, T/5, 1]
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        intermediate_size: int = 6144,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 8,
        rms_norm_eps: float = 1e-6,
        attention_bias: bool = False,
        attention_dropout: float = 0.0,
        layer_types: Optional[list] = None,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = 128,
        use_sliding_window: bool = True,
        rope_theta: float = 1000000,
        max_position_embeddings: int = 32768,
        initializer_range: float = 0.02,
        audio_acoustic_hidden_dim: int = 64,
        pool_window_size: int = 5,
        fsq_dim: int = 2048,
        fsq_input_levels: Optional[list] = None,  # FIX: was `list = None`
        fsq_input_num_quantizers: int = 1,
        num_attention_pooler_hidden_layers: int = 2,
        **kwargs,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.rms_norm_eps = rms_norm_eps
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        # Default matches target library config (24 alternating entries).
        self.layer_types = layer_types or (["sliding_attention", "full_attention"] * 12)
        self.head_dim = head_dim or hidden_size // num_attention_heads
        self.sliding_window = sliding_window
        self.use_sliding_window = use_sliding_window
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.audio_acoustic_hidden_dim = audio_acoustic_hidden_dim
        self.pool_window_size = pool_window_size
        self.fsq_dim = fsq_dim
        self.fsq_input_levels = fsq_input_levels or [8, 8, 8, 5, 5, 5]
        self.fsq_input_num_quantizers = fsq_input_num_quantizers
        self.num_attention_pooler_hidden_layers = num_attention_pooler_hidden_layers
        self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")

        # Project 64-dim VAE latents up to the transformer width.
        self.audio_acoustic_proj = nn.Linear(audio_acoustic_hidden_dim, hidden_size)
        # Slice layer_types for the attention pooler
        pooler_layer_types = self.layer_types[:num_attention_pooler_hidden_layers]
        self.attention_pooler = AttentionPooler(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            rms_norm_eps=rms_norm_eps,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            layer_types=pooler_layer_types,
            head_dim=head_dim,
            sliding_window=sliding_window,
            use_sliding_window=use_sliding_window,
            rope_theta=rope_theta,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            num_attention_pooler_hidden_layers=num_attention_pooler_hidden_layers,
        )
        self.quantizer = ResidualFSQ(
            dim=self.fsq_dim,
            levels=self.fsq_input_levels,
            num_quantizers=self.fsq_input_num_quantizers,
            force_quantization_f32=False,  # avoid autocast bug in vector_quantize_pytorch
        )

    @can_return_tuple
    def forward(
        self,
        hidden_states: Optional[torch.FloatTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Project, pool, and quantize pre-patched latents [B, T', P, 64]."""
        hidden_states = self.audio_acoustic_proj(hidden_states)
        hidden_states = self.attention_pooler(hidden_states)
        quantized, indices = self.quantizer(hidden_states)
        return quantized, indices

    def tokenize(self, x):
        """Convenience: takes [B, T, 64], rearranges to patches, runs forward."""
        x = rearrange(x, 'n (t_patch p) d -> n t_patch p d', p=self.pool_window_size)
        return self.forward(x)
class AudioTokenDetokenizer(nn.Module):
    """Converts quantized audio tokens back to continuous acoustic representations.

    Input: [B, T/5, hidden_size] (quantized vectors)
    Output: [B, T, 64] (VAE-latent-shaped continuous features)
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        intermediate_size: int = 6144,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 8,
        rms_norm_eps: float = 1e-6,
        attention_bias: bool = False,
        attention_dropout: float = 0.0,
        layer_types: Optional[list] = None,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = 128,
        use_sliding_window: bool = True,
        rope_theta: float = 1000000,
        max_position_embeddings: int = 32768,
        initializer_range: float = 0.02,
        pool_window_size: int = 5,
        audio_acoustic_hidden_dim: int = 64,
        num_attention_pooler_hidden_layers: int = 2,
        **kwargs,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.rms_norm_eps = rms_norm_eps
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        # Default matches target library config (24 alternating entries).
        self.layer_types = layer_types or (["sliding_attention", "full_attention"] * 12)
        self.head_dim = head_dim or hidden_size // num_attention_heads
        self.sliding_window = sliding_window
        self.use_sliding_window = use_sliding_window
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.pool_window_size = pool_window_size
        self.audio_acoustic_hidden_dim = audio_acoustic_hidden_dim
        self.num_attention_pooler_hidden_layers = num_attention_pooler_hidden_layers
        self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")

        self.embed_tokens = nn.Linear(hidden_size, hidden_size)
        self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
        # Slice layer_types to our own layer count.
        # NOTE(review): the depth is governed by num_attention_pooler_hidden_layers;
        # a separate num_audio_decoder_hidden_layers exists upstream but is not
        # threaded through here — confirm against the released checkpoints.
        detok_layer_types = self.layer_types[:num_attention_pooler_hidden_layers]
        # Minimal config shim for Qwen3RotaryEmbedding.
        # BUGFIX: pass the resolved self.head_dim (not the raw head_dim
        # argument, which may be None and break the rope dim derivation).
        rope_config = type('RopeConfig', (), {
            'hidden_size': hidden_size,
            'num_attention_heads': num_attention_heads,
            'num_key_value_heads': num_key_value_heads,
            'head_dim': self.head_dim,
            'max_position_embeddings': max_position_embeddings,
            'rope_theta': rope_theta,
            'rope_parameters': {'rope_type': 'default', 'rope_theta': rope_theta},
            'rms_norm_eps': rms_norm_eps,
            'attention_bias': attention_bias,
            'attention_dropout': attention_dropout,
            'hidden_act': 'silu',
            'intermediate_size': intermediate_size,
            'layer_types': detok_layer_types,
            'sliding_window': sliding_window,
            '_attn_implementation': self._attn_implementation,
        })()
        self.rotary_emb = Qwen3RotaryEmbedding(rope_config)
        self.gradient_checkpointing = False
        # One learnable positional token per slot of the un-pooled window.
        self.special_tokens = nn.Parameter(torch.randn(1, pool_window_size, hidden_size) * 0.02)
        self.layers = nn.ModuleList([
            AceStepEncoderLayer(
                hidden_size=hidden_size,
                intermediate_size=intermediate_size,
                num_attention_heads=num_attention_heads,
                num_key_value_heads=num_key_value_heads,
                rms_norm_eps=rms_norm_eps,
                attention_bias=attention_bias,
                attention_dropout=attention_dropout,
                layer_types=detok_layer_types,
                head_dim=head_dim,
                sliding_window=sliding_window,
                layer_idx=layer_idx,
            )
            for layer_idx in range(num_attention_pooler_hidden_layers)
        ])
        self.proj_out = nn.Linear(hidden_size, audio_acoustic_hidden_dim)

    @can_return_tuple
    def forward(
        self,
        x,
        attention_mask: Optional[torch.Tensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.Tensor:
        """Expand each token into pool_window_size frames -> [B, T*P, 64]."""
        B, T, D = x.shape
        x = self.embed_tokens(x)
        # Replicate each pooled vector across the window, offset by learned slot tokens.
        x = x.unsqueeze(2).repeat(1, 1, self.pool_window_size, 1)
        special_tokens = self.special_tokens.expand(B, T, -1, -1)
        x = x + special_tokens
        x = rearrange(x, "b t p c -> (b t) p c")

        cache_position = torch.arange(0, x.shape[1], device=x.device)
        position_ids = cache_position.unsqueeze(0)
        hidden_states = x
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        seq_len = x.shape[1]
        dtype = x.dtype
        device = x.device

        full_attn_mask = create_4d_mask(
            seq_len=seq_len, dtype=dtype, device=device,
            attention_mask=attention_mask, sliding_window=None,
            is_sliding_window=False, is_causal=False
        )
        sliding_attn_mask = None
        if self.use_sliding_window:
            sliding_attn_mask = create_4d_mask(
                seq_len=seq_len, dtype=dtype, device=device,
                attention_mask=attention_mask, sliding_window=self.sliding_window,
                is_sliding_window=True, is_causal=False
            )

        self_attn_mask_mapping = {
            "full_attention": full_attn_mask,
            "sliding_attention": sliding_attn_mask,
        }

        for layer_module in self.layers:
            layer_outputs = layer_module(
                hidden_states, position_embeddings,
                attention_mask=self_attn_mask_mapping[layer_module.attention_type],
                **flash_attn_kwargs,
            )
            hidden_states = layer_outputs[0]

        hidden_states = self.norm(hidden_states)
        hidden_states = self.proj_out(hidden_states)
        return rearrange(hidden_states, "(b t) p c -> b (t p) c", b=B, p=self.pool_window_size)
class AceStepTokenizer(nn.Module):
    """Container for AceStepAudioTokenizer + AudioTokenDetokenizer.

    Provides encode/decode convenience methods for VAE latent discretization.
    Used in cover song mode to convert source audio latents to discrete tokens
    and back to continuous conditioning hints.
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        intermediate_size: int = 6144,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 8,
        rms_norm_eps: float = 1e-6,
        attention_bias: bool = False,
        attention_dropout: float = 0.0,
        layer_types: Optional[list] = None,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = 128,
        use_sliding_window: bool = True,
        rope_theta: float = 1000000,
        max_position_embeddings: int = 32768,
        initializer_range: float = 0.02,
        audio_acoustic_hidden_dim: int = 64,
        pool_window_size: int = 5,
        fsq_dim: int = 2048,
        fsq_input_levels: list = None,
        fsq_input_num_quantizers: int = 1,
        num_attention_pooler_hidden_layers: int = 2,
        num_audio_decoder_hidden_layers: int = 24,
        **kwargs,
    ):
        super().__init__()
        # Default layer_types matches target library config (24 alternating entries).
        # Sub-modules (pooler/detokenizer) slice first N entries for their own layer count.
        if layer_types is None:
            layer_types = ["sliding_attention", "full_attention"] * 12
        # NOTE(review): num_audio_decoder_hidden_layers is accepted but never
        # forwarded — both sub-modules size themselves from
        # num_attention_pooler_hidden_layers. Confirm against the checkpoint
        # before wiring it through.
        common = dict(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            rms_norm_eps=rms_norm_eps,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            layer_types=layer_types,
            head_dim=head_dim,
            sliding_window=sliding_window,
            use_sliding_window=use_sliding_window,
            rope_theta=rope_theta,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            audio_acoustic_hidden_dim=audio_acoustic_hidden_dim,
            pool_window_size=pool_window_size,
            num_attention_pooler_hidden_layers=num_attention_pooler_hidden_layers,
        )
        self.tokenizer = AceStepAudioTokenizer(
            fsq_dim=fsq_dim,
            fsq_input_levels=fsq_input_levels,
            fsq_input_num_quantizers=fsq_input_num_quantizers,
            **common,
            **kwargs,
        )
        self.detokenizer = AudioTokenDetokenizer(
            **common,
            **kwargs,
        )

    def encode(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """VAE latent -> discrete tokens.

        NOTE(review): this forwards straight into the tokenizer, which expects
        already-patched input; for flat [B, T, 64] latents use tokenize().
        """
        return self.tokenizer(hidden_states)

    def decode(self, quantized: torch.Tensor) -> torch.Tensor:
        """Discrete tokens [B, T/5, hidden_size] -> continuous [B, T, 64]."""
        return self.detokenizer(quantized)

    def tokenize(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Convenience: [B, T, 64] -> quantized + indices via patch rearrangement."""
        return self.tokenizer.tokenize(x)
torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Convenience: [B, T, 64] → quantized + indices via patch rearrangement.""" + return self.tokenizer.tokenize(x) diff --git a/diffsynth/models/ace_step_vae.py b/diffsynth/models/ace_step_vae.py new file mode 100644 index 0000000..dd78a0a --- /dev/null +++ b/diffsynth/models/ace_step_vae.py @@ -0,0 +1,241 @@ +# Copyright 2025 The ACESTEO Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ACE-Step Audio VAE (AutoencoderOobleck CNN architecture). + +This is a CNN-based VAE for audio waveform encoding/decoding. +It uses weight-normalized convolutions and Snake1d activations. +Does NOT depend on diffusers — pure nn.Module implementation. 
+""" +import math +from typing import Optional + +import torch +import torch.nn as nn +from torch.nn.utils import weight_norm + + +class Snake1d(nn.Module): + """Snake activation: x + 1/(beta+eps) * sin(alpha*x)^2.""" + + def __init__(self, hidden_dim: int, logscale: bool = True): + super().__init__() + self.alpha = nn.Parameter(torch.zeros(1, hidden_dim, 1)) + self.beta = nn.Parameter(torch.zeros(1, hidden_dim, 1)) + self.logscale = logscale + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + shape = hidden_states.shape + alpha = torch.exp(self.alpha) if self.logscale else self.alpha + beta = torch.exp(self.beta) if self.logscale else self.beta + hidden_states = hidden_states.reshape(shape[0], shape[1], -1) + hidden_states = hidden_states + (beta + 1e-9).reciprocal() * torch.sin(alpha * hidden_states).pow(2) + return hidden_states.reshape(shape) + + +class OobleckResidualUnit(nn.Module): + """Residual unit: Snake1d → Conv1d(dilated) → Snake1d → Conv1d(1×1) + skip.""" + + def __init__(self, dimension: int = 16, dilation: int = 1): + super().__init__() + pad = ((7 - 1) * dilation) // 2 + self.snake1 = Snake1d(dimension) + self.conv1 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad)) + self.snake2 = Snake1d(dimension) + self.conv2 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=1)) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + output = self.conv1(self.snake1(hidden_state)) + output = self.conv2(self.snake2(output)) + padding = (hidden_state.shape[-1] - output.shape[-1]) // 2 + if padding > 0: + hidden_state = hidden_state[..., padding:-padding] + return hidden_state + output + + +class OobleckEncoderBlock(nn.Module): + """Encoder block: 3 residual units + downsampling conv.""" + + def __init__(self, input_dim: int, output_dim: int, stride: int = 1): + super().__init__() + self.res_unit1 = OobleckResidualUnit(input_dim, dilation=1) + self.res_unit2 = 
OobleckResidualUnit(input_dim, dilation=3) + self.res_unit3 = OobleckResidualUnit(input_dim, dilation=9) + self.snake1 = Snake1d(input_dim) + self.conv1 = weight_norm( + nn.Conv1d(input_dim, output_dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)) + ) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.res_unit1(hidden_state) + hidden_state = self.res_unit2(hidden_state) + hidden_state = self.snake1(self.res_unit3(hidden_state)) + return self.conv1(hidden_state) + + +class OobleckDecoderBlock(nn.Module): + """Decoder block: upsampling conv + 3 residual units.""" + + def __init__(self, input_dim: int, output_dim: int, stride: int = 1): + super().__init__() + self.snake1 = Snake1d(input_dim) + self.conv_t1 = weight_norm( + nn.ConvTranspose1d( + input_dim, output_dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2), + ) + ) + self.res_unit1 = OobleckResidualUnit(output_dim, dilation=1) + self.res_unit2 = OobleckResidualUnit(output_dim, dilation=3) + self.res_unit3 = OobleckResidualUnit(output_dim, dilation=9) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.snake1(hidden_state) + hidden_state = self.conv_t1(hidden_state) + hidden_state = self.res_unit1(hidden_state) + hidden_state = self.res_unit2(hidden_state) + return self.res_unit3(hidden_state) + + +class OobleckEncoder(nn.Module): + """Full encoder: audio → latent representation [B, encoder_hidden_size, T']. 
+ + conv1 → [blocks] → snake1 → conv2 + """ + + def __init__( + self, + encoder_hidden_size: int = 128, + audio_channels: int = 2, + downsampling_ratios: list = None, + channel_multiples: list = None, + ): + super().__init__() + downsampling_ratios = downsampling_ratios or [2, 4, 4, 6, 10] + channel_multiples = channel_multiples or [1, 2, 4, 8, 16] + channel_multiples = [1] + channel_multiples + + self.conv1 = weight_norm(nn.Conv1d(audio_channels, encoder_hidden_size, kernel_size=7, padding=3)) + + self.block = nn.ModuleList() + for stride_index, stride in enumerate(downsampling_ratios): + self.block.append( + OobleckEncoderBlock( + input_dim=encoder_hidden_size * channel_multiples[stride_index], + output_dim=encoder_hidden_size * channel_multiples[stride_index + 1], + stride=stride, + ) + ) + + d_model = encoder_hidden_size * channel_multiples[-1] + self.snake1 = Snake1d(d_model) + self.conv2 = weight_norm(nn.Conv1d(d_model, encoder_hidden_size, kernel_size=3, padding=1)) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.conv1(hidden_state) + for block in self.block: + hidden_state = block(hidden_state) + hidden_state = self.snake1(hidden_state) + return self.conv2(hidden_state) + + +class OobleckDecoder(nn.Module): + """Full decoder: latent → audio waveform [B, audio_channels, T]. 
+ + conv1 → [blocks] → snake1 → conv2(no bias) + """ + + def __init__( + self, + channels: int = 128, + input_channels: int = 64, + audio_channels: int = 2, + upsampling_ratios: list = None, + channel_multiples: list = None, + ): + super().__init__() + upsampling_ratios = upsampling_ratios or [10, 6, 4, 4, 2] + channel_multiples = channel_multiples or [1, 2, 4, 8, 16] + channel_multiples = [1] + channel_multiples + + self.conv1 = weight_norm(nn.Conv1d(input_channels, channels * channel_multiples[-1], kernel_size=7, padding=3)) + + self.block = nn.ModuleList() + for stride_index, stride in enumerate(upsampling_ratios): + self.block.append( + OobleckDecoderBlock( + input_dim=channels * channel_multiples[len(upsampling_ratios) - stride_index], + output_dim=channels * channel_multiples[len(upsampling_ratios) - stride_index - 1], + stride=stride, + ) + ) + + self.snake1 = Snake1d(channels) + # conv2 has no bias (matches checkpoint: only weight_g/weight_v, no bias key) + self.conv2 = weight_norm(nn.Conv1d(channels, audio_channels, kernel_size=7, padding=3, bias=False)) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.conv1(hidden_state) + for block in self.block: + hidden_state = block(hidden_state) + hidden_state = self.snake1(hidden_state) + return self.conv2(hidden_state) + + +class AceStepVAE(nn.Module): + """Audio VAE for ACE-Step (AutoencoderOobleck architecture). + + Encodes audio waveform → latent, decodes latent → audio waveform. + Uses Snake1d activations and weight-normalized convolutions. 
+ """ + + def __init__( + self, + encoder_hidden_size: int = 128, + downsampling_ratios: list = None, + channel_multiples: list = None, + decoder_channels: int = 128, + decoder_input_channels: int = 64, + audio_channels: int = 2, + sampling_rate: int = 48000, + ): + super().__init__() + downsampling_ratios = downsampling_ratios or [2, 4, 4, 6, 10] + channel_multiples = channel_multiples or [1, 2, 4, 8, 16] + upsampling_ratios = downsampling_ratios[::-1] + + self.encoder = OobleckEncoder( + encoder_hidden_size=encoder_hidden_size, + audio_channels=audio_channels, + downsampling_ratios=downsampling_ratios, + channel_multiples=channel_multiples, + ) + self.decoder = OobleckDecoder( + channels=decoder_channels, + input_channels=decoder_input_channels, + audio_channels=audio_channels, + upsampling_ratios=upsampling_ratios, + channel_multiples=channel_multiples, + ) + + def encode(self, x: torch.Tensor) -> torch.Tensor: + """Audio waveform [B, audio_channels, T] → latent [B, encoder_hidden_size, T'].""" + return self.encoder(x) + + def decode(self, z: torch.Tensor) -> torch.Tensor: + """Latent [B, encoder_hidden_size, T] → audio waveform [B, audio_channels, T'].""" + return self.decoder(z) + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + """Full round-trip: encode → decode.""" + z = self.encode(sample) + return self.decoder(z) diff --git a/diffsynth/pipelines/ace_step.py b/diffsynth/pipelines/ace_step.py new file mode 100644 index 0000000..09c3178 --- /dev/null +++ b/diffsynth/pipelines/ace_step.py @@ -0,0 +1,527 @@ +""" +ACE-Step Pipeline for DiffSynth-Studio. + +Text-to-Music generation pipeline using ACE-Step 1.5 model. 
+""" +import torch +from typing import Optional +from tqdm import tqdm + +from ..core.device.npu_compatible_device import get_device_type +from ..diffusion import FlowMatchScheduler +from ..core import ModelConfig +from ..diffusion.base_pipeline import BasePipeline, PipelineUnit + +from ..models.ace_step_dit import AceStepDiTModel +from ..models.ace_step_conditioner import AceStepConditionEncoder +from ..models.ace_step_text_encoder import AceStepTextEncoder +from ..models.ace_step_vae import AceStepVAE + + +class AceStepPipeline(BasePipeline): + """Pipeline for ACE-Step text-to-music generation.""" + + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): + super().__init__( + device=device, + torch_dtype=torch_dtype, + height_division_factor=1, + width_division_factor=1, + ) + self.scheduler = FlowMatchScheduler("ACE-Step") + self.text_encoder: AceStepTextEncoder = None + self.conditioner: AceStepConditionEncoder = None + self.dit: AceStepDiTModel = None + self.vae = None # AutoencoderOobleck (diffusers) or AceStepVAE + + # Unit chain order — 7 units total + # + # 1. ShapeChecker: duration → seq_len + # 2. PromptEmbedder: prompt/lyrics → text/lyric embeddings (shared for CFG) + # 3. SilenceLatentInitializer: seq_len → src_latents + chunk_masks + # 4. ContextLatentBuilder: src_latents + chunk_masks → context_latents (shared, same for CFG+) + # 5. ConditionEmbedder: text/lyric → encoder_hidden_states (separate for CFG+/-) + # 6. NoiseInitializer: context_latents → noise + # 7. InputAudioEmbedder: noise → latents + # + # ContextLatentBuilder runs before ConditionEmbedder so that + # context_latents is available for noise shape computation. 
+ self.in_iteration_models = ("dit",) + self.units = [ + AceStepUnit_ShapeChecker(), + AceStepUnit_PromptEmbedder(), + AceStepUnit_SilenceLatentInitializer(), + AceStepUnit_ContextLatentBuilder(), + AceStepUnit_ConditionEmbedder(), + AceStepUnit_NoiseInitializer(), + AceStepUnit_InputAudioEmbedder(), + ] + self.model_fn = model_fn_ace_step + self.compilable_models = ["dit"] + + self.sample_rate = 48000 + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: str = get_device_type(), + model_configs: list[ModelConfig] = [], + text_tokenizer_config: ModelConfig = None, + vram_limit: float = None, + ): + """Load pipeline from pretrained checkpoints.""" + pipe = AceStepPipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + pipe.text_encoder = model_pool.fetch_model("ace_step_text_encoder") + pipe.conditioner = model_pool.fetch_model("ace_step_conditioner") + pipe.dit = model_pool.fetch_model("ace_step_dit") + pipe.vae = model_pool.fetch_model("ace_step_vae") + + if text_tokenizer_config is not None: + text_tokenizer_config.download_if_necessary() + from transformers import AutoTokenizer + pipe.tokenizer = AutoTokenizer.from_pretrained(text_tokenizer_config.path) + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: str = "", + cfg_scale: float = 1.0, + # Lyrics + lyrics: str = "", + # Reference audio (optional, for timbre conditioning) + reference_audio = None, + # Shape + duration: float = 60.0, + # Randomness + seed: int = None, + rand_device: str = "cpu", + # Steps + num_inference_steps: int = 8, + # Scheduler-specific parameters + shift: float = 3.0, + # Progress + progress_bar_cmd=tqdm, + ): + # 1. 
Scheduler + self.scheduler.set_timesteps( + num_inference_steps=num_inference_steps, + denoising_strength=1.0, + shift=shift, + ) + + # 2. 三字典输入 + inputs_posi = {"prompt": prompt} + inputs_nega = {"negative_prompt": negative_prompt} + inputs_shared = { + "cfg_scale": cfg_scale, + "lyrics": lyrics, + "reference_audio": reference_audio, + "duration": duration, + "seed": seed, + "rand_device": rand_device, + "num_inference_steps": num_inference_steps, + "shift": shift, + } + + # 3. Unit 链执行 + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner( + unit, self, inputs_shared, inputs_posi, inputs_nega + ) + + # 4. Denoise loop + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.to(dtype=self.torch_dtype, device=self.device) + noise_pred = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["latents"] = self.step( + self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared + ) + + # 5. VAE 解码 + self.load_models_to_device(['vae']) + # DiT output is [B, T, 64] (channels-last), VAE expects [B, 64, T] (channels-first) + latents = inputs_shared["latents"].transpose(1, 2) + vae_output = self.vae.decode(latents) + # VAE returns OobleckDecoderOutput with .sample attribute + audio_output = vae_output.sample if hasattr(vae_output, 'sample') else vae_output + audio = self.output_audio_format_check(audio_output) + self.load_models_to_device([]) + return audio + + def output_audio_format_check(self, audio_output): + """Convert VAE output to standard audio format [C, T], float32. + + VAE decode outputs [B, C, T] (audio waveform). + We squeeze batch dim and return [C, T]. 
+ """ + if audio_output.ndim == 3: + audio_output = audio_output.squeeze(0) + return audio_output.float() + + +class AceStepUnit_ShapeChecker(PipelineUnit): + """Check and compute sequence length from duration.""" + def __init__(self): + super().__init__( + input_params=("duration",), + output_params=("duration", "seq_len"), + ) + + def process(self, pipe, duration): + # ACE-Step: 25 Hz latent rate + seq_len = int(duration * 25) + return {"duration": duration, "seq_len": seq_len} + + +class AceStepUnit_PromptEmbedder(PipelineUnit): + """Encode prompt and lyrics using Qwen3-Embedding. + + Uses seperate_cfg=True to read prompt from inputs_posi (not inputs_shared). + The negative condition uses null_condition_emb (handled by ConditionEmbedder), + so negative text encoding is not needed here. + """ + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt"}, + input_params_nega={}, + input_params=("lyrics",), + output_params=("text_hidden_states", "text_attention_mask", "lyric_hidden_states", "lyric_attention_mask"), + onload_model_names=("text_encoder",) + ) + + def _encode_text(self, pipe, text): + """Encode text using Qwen3-Embedding → [B, T, 1024].""" + if pipe.tokenizer is None: + return None, None + text_inputs = pipe.tokenizer( + text, + padding="max_length", + max_length=512, + truncation=True, + return_tensors="pt", + ) + input_ids = text_inputs.input_ids.to(pipe.device) + attention_mask = text_inputs.attention_mask.to(pipe.device) + hidden_states = pipe.text_encoder(input_ids, attention_mask) + return hidden_states, attention_mask + + def process(self, pipe, prompt, lyrics, negative_prompt=None): + pipe.load_models_to_device(['text_encoder']) + + text_hidden_states, text_attention_mask = self._encode_text(pipe, prompt) + + # Lyrics encoding — use empty string if not provided + lyric_text = lyrics if lyrics else "" + lyric_hidden_states, lyric_attention_mask = self._encode_text(pipe, lyric_text) + + if 
text_hidden_states is not None and lyric_hidden_states is not None: + return { + "text_hidden_states": text_hidden_states, + "text_attention_mask": text_attention_mask, + "lyric_hidden_states": lyric_hidden_states, + "lyric_attention_mask": lyric_attention_mask, + } + return {} + + +class AceStepUnit_SilenceLatentInitializer(PipelineUnit): + """Generate silence latent (all zeros) and chunk_masks for text2music. + + Target library reference: `prepare_condition()` line 1698-1699: + context_latents = torch.cat([src_latents, chunk_masks.to(dtype)], dim=-1) + + For text2music mode: + - src_latents = zeros [B, T, 64] (VAE latent dimension) + - chunk_masks = ones [B, T, 64] (full visibility mask for text2music) + - context_latents = [B, T, 128] (concat of src_latents + chunk_masks) + """ + def __init__(self): + super().__init__( + input_params=("seq_len",), + output_params=("silence_latent", "src_latents", "chunk_masks"), + ) + + def process(self, pipe, seq_len): + # silence_latent shape: [B, T, 64] — 64 is the VAE latent dimension + silence_latent = torch.zeros(1, seq_len, 64, device=pipe.device, dtype=pipe.torch_dtype) + # For text2music: src_latents = silence_latent + src_latents = silence_latent.clone() + + # chunk_masks: [B, T, 64] of ones (same shape as src_latents) + # In text2music mode (is_covers=0), chunk_masks are all 1.0 + # This matches the target library's behavior at line 1699 + chunk_masks = torch.ones(1, seq_len, 64, device=pipe.device, dtype=pipe.torch_dtype) + + return {"silence_latent": silence_latent, "src_latents": src_latents, "chunk_masks": chunk_masks} + + +class AceStepUnit_ContextLatentBuilder(PipelineUnit): + """Build context_latents from src_latents and chunk_masks. + + Target library reference: `prepare_condition()` line 1699: + context_latents = torch.cat([src_latents, chunk_masks.to(dtype)], dim=-1) + + context_latents is the SAME for positive and negative CFG paths + (it comes from src_latents + chunk_masks, not from text encoding). 
class AceStepUnit_ContextLatentBuilder(PipelineUnit):
    """Assemble context_latents by concatenating src_latents with chunk_masks.

    Target library reference: `prepare_condition()` line 1699:
        context_latents = torch.cat([src_latents, chunk_masks.to(dtype)], dim=-1)

    The result is identical for the positive and negative CFG branches —
    it depends only on src_latents/chunk_masks, never on text encodings —
    so this is a plain (shared-mode) unit whose outputs go to inputs_shared.
    """

    def __init__(self):
        super().__init__(
            input_params=("src_latents", "chunk_masks"),
            output_params=("context_latents", "attention_mask"),
        )

    def process(self, pipe, src_latents, chunk_masks):
        # [B, T, 64] ++ [B, T, 64] → [B, T, 128] along the channel dim.
        context = torch.cat((src_latents, chunk_masks), dim=-1)

        # Full-visibility [B, T] mask, consumed by the DiT when
        # cross-attending over context_latents.
        batch, length = src_latents.shape[0], src_latents.shape[1]
        mask = torch.ones(batch, length, device=pipe.device, dtype=pipe.torch_dtype)

        return {"context_latents": context, "attention_mask": mask}
+ """ + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={ + "text_hidden_states": "text_hidden_states", + "text_attention_mask": "text_attention_mask", + "lyric_hidden_states": "lyric_hidden_states", + "lyric_attention_mask": "lyric_attention_mask", + "reference_audio": "reference_audio", + "refer_audio_order_mask": "refer_audio_order_mask", + }, + input_params_nega={}, + input_params=("cfg_scale",), + output_params=( + "encoder_hidden_states", "encoder_attention_mask", + "negative_encoder_hidden_states", "negative_encoder_attention_mask", + ), + onload_model_names=("conditioner",) + ) + + def _prepare_condition(self, pipe, text_hidden_states, text_attention_mask, + lyric_hidden_states, lyric_attention_mask, + refer_audio_acoustic_hidden_states_packed=None, + refer_audio_order_mask=None): + """Call ACEStepConditioner forward to produce encoder_hidden_states.""" + pipe.load_models_to_device(['conditioner']) + + # Handle reference audio + if refer_audio_acoustic_hidden_states_packed is None: + # No reference audio: create 2D packed zeros [N=1, d=64] + # TimbreEncoder.unpack expects [N, d], not [B, T, d] + refer_audio_acoustic_hidden_states_packed = torch.zeros( + 1, 64, device=pipe.device, dtype=pipe.torch_dtype + ) + refer_audio_order_mask = torch.LongTensor([0]).to(pipe.device) + + encoder_hidden_states, encoder_attention_mask = pipe.conditioner( + text_hidden_states=text_hidden_states, + text_attention_mask=text_attention_mask, + lyric_hidden_states=lyric_hidden_states, + lyric_attention_mask=lyric_attention_mask, + refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed, + refer_audio_order_mask=refer_audio_order_mask, + ) + + return encoder_hidden_states, encoder_attention_mask + + def _prepare_negative_condition(self, pipe, lyric_hidden_states, lyric_attention_mask, + refer_audio_acoustic_hidden_states_packed=None, + refer_audio_order_mask=None): + """Generate negative condition using 
null_condition_emb.""" + if pipe.conditioner is None or not hasattr(pipe.conditioner, 'null_condition_emb'): + return None, None + + null_emb = pipe.conditioner.null_condition_emb # [1, 1, hidden_size] + bsz = 1 + if lyric_hidden_states is not None: + bsz = lyric_hidden_states.shape[0] + null_hidden_states = null_emb.expand(bsz, -1, -1) + null_attn_mask = torch.ones(bsz, 1, device=pipe.device, dtype=pipe.torch_dtype) + + # For negative: use null_condition_emb as text, keep lyrics and timbre + neg_encoder_hidden_states, neg_encoder_attention_mask = pipe.conditioner( + text_hidden_states=null_hidden_states, + text_attention_mask=null_attn_mask, + lyric_hidden_states=lyric_hidden_states, + lyric_attention_mask=lyric_attention_mask, + refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed, + refer_audio_order_mask=refer_audio_order_mask, + ) + + return neg_encoder_hidden_states, neg_encoder_attention_mask + + def process(self, pipe, text_hidden_states, text_attention_mask, + lyric_hidden_states, lyric_attention_mask, + reference_audio=None, refer_audio_order_mask=None, + negative_prompt=None, cfg_scale=1.0): + + # Positive condition + pos_enc_hs, pos_enc_mask = self._prepare_condition( + pipe, text_hidden_states, text_attention_mask, + lyric_hidden_states, lyric_attention_mask, + None, refer_audio_order_mask, + ) + + # Negative condition: only needed when CFG is active (cfg_scale > 1.0) + # For cfg_scale=1.0 (turbo), skip to avoid null_condition_emb dimension mismatch + result = { + "encoder_hidden_states": pos_enc_hs, + "encoder_attention_mask": pos_enc_mask, + } + + if cfg_scale > 1.0: + neg_enc_hs, neg_enc_mask = self._prepare_negative_condition( + pipe, lyric_hidden_states, lyric_attention_mask, + None, refer_audio_order_mask, + ) + if neg_enc_hs is not None: + result["negative_encoder_hidden_states"] = neg_enc_hs + result["negative_encoder_attention_mask"] = neg_enc_mask + + return result + + +class 
class AceStepUnit_NoiseInitializer(PipelineUnit):
    """Draw the initial Gaussian noise for the denoise loop.

    Target library reference: `prepare_noise()` line 1781-1818:
        src_latents_shape = (bsz, context_latents.shape[1], context_latents.shape[-1] // 2)

    context_latents is [B, T, 128], so the noise comes out as [B, T, 64].
    """

    def __init__(self):
        super().__init__(
            input_params=("seed", "seq_len", "rand_device", "context_latents"),
            output_params=("noise",),
        )

    def process(self, pipe, seed, seq_len, rand_device, context_latents):
        # Halve the last dim: context_latents stacks src_latents ++ chunk_masks,
        # and the noise must match the src_latents half only (line 1796).
        batch_size = context_latents.shape[0]
        num_frames = context_latents.shape[1]
        latent_dim = context_latents.shape[-1] // 2
        sampled = pipe.generate_noise(
            (batch_size, num_frames, latent_dim),
            seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype
        )
        return {"noise": sampled}


class AceStepUnit_InputAudioEmbedder(PipelineUnit):
    """Seed the denoise loop's latents.

    Text2music has no input audio, so the loop starts from pure noise and
    input_latents stays None (target library `generate_audio()` line 1972:
    xt = noise when cover_noise_strength == 0).
    """

    def __init__(self):
        super().__init__(
            input_params=("noise",),
            output_params=("latents", "input_latents"),
        )

    def process(self, pipe, noise):
        # Pure-noise start for text2music.
        return {"latents": noise, "input_latents": None}
+ """ + # ACE-Step uses timestep directly in [0, 1] range — no /1000 scaling + timestep = timestep.squeeze() + + # Expand timestep to match batch size + bsz = latents.shape[0] + timestep = timestep.expand(bsz) + + decoder_outputs = dit( + hidden_states=latents, + timestep=timestep, + timestep_r=timestep, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + context_latents=context_latents, + use_cache=True, + past_key_values=past_key_values, + ) + + # Return velocity prediction (first element of decoder_outputs) + return decoder_outputs[0] diff --git a/diffsynth/utils/state_dict_converters/ace_step_conditioner.py b/diffsynth/utils/state_dict_converters/ace_step_conditioner.py new file mode 100644 index 0000000..b6984b8 --- /dev/null +++ b/diffsynth/utils/state_dict_converters/ace_step_conditioner.py @@ -0,0 +1,48 @@ +""" +State dict converter for ACE-Step Conditioner model. + +The original checkpoint stores all model weights in a single file +(nested in AceStepConditionGenerationModel). The Conditioner weights are +prefixed with 'encoder.'. + +This converter extracts only keys starting with 'encoder.' and strips +the prefix to match the standalone AceStepConditionEncoder in DiffSynth. +""" + + +def ace_step_conditioner_converter(state_dict): + """ + Convert ACE-Step Conditioner checkpoint keys to DiffSynth format. + + 参数 state_dict 是 DiskMap 类型。 + 遍历时,key 是 key 名,state_dict[key] 获取实际值。 + + Original checkpoint contains all model weights under prefixes: + - decoder.* (DiT) + - encoder.* (Conditioner) + - tokenizer.* (Audio Tokenizer) + - detokenizer.* (Audio Detokenizer) + - null_condition_emb (CFG null embedding) + + This extracts only 'encoder.' keys and strips the prefix. 
def ace_step_conditioner_converter(state_dict):
    """Convert ACE-Step Conditioner checkpoint keys to DiffSynth format.

    ``state_dict`` behaves like a mapping (DiskMap): iterating yields key
    names and indexing fetches the tensors lazily.

    The combined checkpoint stores weights under the prefixes decoder.*
    (DiT), encoder.* (Conditioner), tokenizer.*, detokenizer.*, plus a
    top-level null_condition_emb (CFG null embedding). Only the encoder.*
    keys — with the prefix stripped — plus null_condition_emb are relevant
    to the standalone AceStepConditionEncoder.

    Example mapping:
        encoder.lyric_encoder.layers.0.self_attn.q_proj.weight
            -> lyric_encoder.layers.0.self_attn.q_proj.weight
        encoder.timbre_encoder.layers.0.self_attn.q_proj.weight
            -> timbre_encoder.layers.0.self_attn.q_proj.weight
    """
    prefix = "encoder."
    converted = {
        key[len(prefix):]: state_dict[key]
        for key in state_dict
        if key.startswith(prefix)
    }

    # The CFG null embedding lives at the checkpoint's top level.
    if "null_condition_emb" in state_dict:
        converted["null_condition_emb"] = state_dict["null_condition_emb"]

    return converted


def ace_step_dit_converter(state_dict):
    """Convert ACE-Step DiT checkpoint keys to DiffSynth format.

    ``state_dict`` behaves like a mapping (DiskMap): iterating yields key
    names and indexing fetches the tensors lazily.

    The combined checkpoint nests everything inside
    AceStepConditionGenerationModel; the DiT weights carry a decoder.*
    prefix. Only those keys — with the prefix stripped — match the
    standalone AceStepDiTModel.

    Example mapping:
        decoder.layers.0.self_attn.q_proj.weight -> layers.0.self_attn.q_proj.weight
        decoder.time_embed.linear_1.weight       -> time_embed.linear_1.weight
        decoder.rotary_emb.inv_freq              -> rotary_emb.inv_freq
    """
    prefix = "decoder."
    return {
        key[len(prefix):]: state_dict[key]
        for key in state_dict
        if key.startswith(prefix)
    }
def ace_step_lm_converter(state_dict):
    """Convert ACE-Step LLM checkpoint keys to the Qwen3ForCausalLM layout.

    ``state_dict`` behaves like a mapping (DiskMap): iterating yields key
    names and indexing fetches the tensors lazily.

    Target layout (AceStepLM wraps Qwen3ForCausalLM as ``.model``):
        model.model.embed_tokens.weight
        model.model.layers.0.self_attn.q_proj.weight
        model.model.norm.weight
        model.lm_head.weight

    Bug fixed: ``lm_head`` keys now map to ``model.lm_head.*`` — the head
    lives on the causal-LM wrapper, not the inner Qwen3Model — whereas the
    previous blanket prefixing produced the dead key
    ``model.model.lm_head.weight`` and relied solely on weight tying.

    Example mapping:
        embed_tokens.weight       -> model.model.embed_tokens.weight
        model.layers.0.q.weight   -> model.model.layers.0.q.weight
        lm_head.weight            -> model.lm_head.weight
    """
    new_state_dict = {}
    model_prefix = "model."
    nested_prefix = "model.model."

    for key in state_dict:
        # Normalize away any existing wrapper prefixes, then re-apply the
        # correct one depending on whether the key belongs to the head.
        base = key
        if base.startswith(nested_prefix):
            base = base[len(nested_prefix):]
        elif base.startswith(model_prefix):
            base = base[len(model_prefix):]
        if base.startswith("lm_head."):
            new_key = model_prefix + base
        else:
            new_key = nested_prefix + base
        new_state_dict[new_key] = state_dict[key]

    # Tied word embeddings: materialize lm_head from embed_tokens only when
    # the checkpoint did not ship an explicit head (Qwen3 small variants tie
    # them, so the values coincide either way).
    if ("model.model.embed_tokens.weight" in new_state_dict
            and "model.lm_head.weight" not in new_state_dict):
        new_state_dict["model.lm_head.weight"] = new_state_dict["model.model.embed_tokens.weight"]

    return new_state_dict
def ace_step_text_encoder_converter(state_dict):
    """Convert ACE-Step Text Encoder checkpoint keys to the wrapped layout.

    ``state_dict`` behaves like a mapping (DiskMap): iterating yields key
    names and indexing fetches the tensors lazily.

    AceStepTextEncoder wraps a Qwen3Model as ``.model``, so target keys
    carry exactly one ``model.`` prefix:
        model.embed_tokens.weight
        model.layers.0.self_attn.q_proj.weight
        model.norm.weight

    Bug fixed: keys that already carried the ``model.`` prefix were being
    prefixed again (-> ``model.model.*``), which can never match the
    wrapper's state dict; such keys are now kept unchanged.

    Example mapping:
        embed_tokens.weight   -> model.embed_tokens.weight
        model.norm.weight     -> model.norm.weight  (unchanged)
    """
    new_state_dict = {}
    prefix = "model."

    for key in state_dict:
        # Already in the target format → keep; otherwise add the wrapper prefix.
        new_key = key if key.startswith(prefix) else prefix + key
        new_state_dict[new_key] = state_dict[key]

    return new_state_dict
+ """ + new_state_dict = {} + + for key in state_dict: + if key.startswith("tokenizer.") or key.startswith("detokenizer."): + new_state_dict[key] = state_dict[key] + + return new_state_dict diff --git a/examples/ace_step/model_inference/Ace-Step1.5-SimpleMode.py b/examples/ace_step/model_inference/Ace-Step1.5-SimpleMode.py new file mode 100644 index 0000000..edcf2cd --- /dev/null +++ b/examples/ace_step/model_inference/Ace-Step1.5-SimpleMode.py @@ -0,0 +1,180 @@ +""" +Ace-Step 1.5 — Text-to-Music with Simple Mode (LLM expansion). + +Uses the ACE-Step LLM to expand a simple description into structured +parameters (caption, lyrics, bpm, keyscale, etc.), then feeds them +to the DiffSynth Pipeline. + +The LLM expansion uses the target library's LLMHandler. If vLLM is +not available, it falls back to using pre-structured parameters. + +Usage: + python examples/ace_step/model_inference/Ace-Step1.5-SimpleMode.py +""" +import os +import sys +import json +import torch +import soundfile as sf + +from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig + + +# --------------------------------------------------------------------------- +# Simple Mode: LLM expansion +# --------------------------------------------------------------------------- + +def try_load_llm_handler(checkpoint_dir: str, lm_model_path: str = "acestep-5Hz-lm-1.7B", + backend: str = "vllm"): + """Try to load the target library's LLMHandler. 
Returns (handler, success).""" + try: + from acestep.llm_inference import LLMHandler + handler = LLMHandler() + status, success = handler.initialize( + checkpoint_dir=checkpoint_dir, + lm_model_path=lm_model_path, + backend=backend, + ) + if success: + print(f"[Simple Mode] LLM loaded via {backend} backend: {status}") + return handler, True + else: + print(f"[Simple Mode] LLM init failed: {status}") + return None, False + except Exception as e: + print(f"[Simple Mode] LLMHandler not available: {e}") + return None, False + + +def expand_with_llm(llm_handler, description: str, duration: float = 30.0): + """Expand a simple description using LLM Chain-of-Thought.""" + result = llm_handler.generate_with_stop_condition( + caption=description, + lyrics="", + infer_type="dit", # metadata only + temperature=0.85, + cfg_scale=1.0, + use_cot_metas=True, + use_cot_caption=True, + use_cot_language=True, + user_metadata={"duration": int(duration)}, + ) + + if result.get("success") and result.get("metadata"): + meta = result["metadata"] + return { + "caption": meta.get("caption", description), + "lyrics": meta.get("lyrics", ""), + "bpm": meta.get("bpm", 100), + "keyscale": meta.get("keyscale", ""), + "language": meta.get("language", "en"), + "timesignature": meta.get("timesignature", "4"), + "duration": meta.get("duration", duration), + } + + print(f"[Simple Mode] LLM expansion failed: {result.get('error', 'unknown')}") + return None + + +def fallback_expand(description: str, duration: float = 30.0): + """Fallback: use description as caption with default parameters.""" + print(f"[Simple Mode] LLM not available. 
Using description as caption.") + return { + "caption": description, + "lyrics": "", + "bpm": 100, + "keyscale": "", + "language": "en", + "timesignature": "4", + "duration": duration, + } + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main(): + # Target library path (for LLMHandler) + TARGET_LIB = os.path.join(os.path.dirname(__file__), "../../../../ACE-Step-1.5") + if TARGET_LIB not in sys.path: + sys.path.insert(0, TARGET_LIB) + + description = "a soft Bengali love song for a quiet evening" + duration = 30.0 + + # 1. Try to load LLM + print("=" * 60) + print("Ace-Step 1.5 — Simple Mode (LLM expansion)") + print("=" * 60) + print(f"\n[Simple Mode] Input: '{description}'") + + llm_handler, llm_ok = try_load_llm_handler( + checkpoint_dir=TARGET_LIB, + lm_model_path="acestep-5Hz-lm-1.7B", + ) + + # 2. Expand parameters + if llm_ok: + params = expand_with_llm(llm_handler, description, duration=duration) + if params is None: + params = fallback_expand(description, duration) + else: + params = fallback_expand(description, duration) + + print(f"\n[Simple Mode] Parameters:") + print(f" Caption: {params['caption'][:100]}...") + print(f" Lyrics: {len(params['lyrics'])} chars") + print(f" BPM: {params['bpm']}, Keyscale: {params['keyscale']}") + print(f" Language: {params['language']}, Time Sig: {params['timesignature']}") + print(f" Duration: {params['duration']}s") + + # 3. 
Load Pipeline + print(f"\n[Pipeline] Loading Ace-Step 1.5 (turbo)...") + pipe = AceStepPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig( + model_id="ACE-Step/Ace-Step1.5", + origin_file_pattern="acestep-v15-turbo/model.safetensors" + ), + ModelConfig( + model_id="ACE-Step/Ace-Step1.5", + origin_file_pattern="acestep-v15-turbo/model.safetensors" + ), + ModelConfig( + model_id="ACE-Step/Ace-Step1.5", + origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors" + ), + ], + tokenizer_config=ModelConfig( + model_id="ACE-Step/Ace-Step1.5", + origin_file_pattern="Qwen3-Embedding-0.6B/" + ), + vae_config=ModelConfig( + model_id="ACE-Step/Ace-Step1.5", + origin_file_pattern="vae/" + ), + ) + + # 4. Generate + print(f"\n[Generation] Running Pipeline...") + audio = pipe( + prompt=params["caption"], + lyrics=params["lyrics"], + duration=params["duration"], + seed=42, + num_inference_steps=8, + cfg_scale=1.0, + shift=3.0, + ) + + output_path = "Ace-Step1.5-SimpleMode.wav" + sf.write(output_path, audio.cpu().numpy(), pipe.sample_rate) + print(f"\n[Done] Saved to {output_path}") + print(f" Shape: {audio.shape}, Duration: {audio.shape[-1] / pipe.sample_rate:.1f}s") + + +if __name__ == "__main__": + main() diff --git a/examples/ace_step/model_inference/Ace-Step1.5.py b/examples/ace_step/model_inference/Ace-Step1.5.py new file mode 100644 index 0000000..a9fda15 --- /dev/null +++ b/examples/ace_step/model_inference/Ace-Step1.5.py @@ -0,0 +1,67 @@ +""" +Ace-Step 1.5 — Text-to-Music (Turbo) inference example. + +Demonstrates the standard text2music pipeline with structured parameters +(caption, lyrics, duration, etc.) — no LLM expansion needed. 
+
+For Simple Mode (LLM expands a short description), see:
+  - Ace-Step1.5-SimpleMode.py
+"""
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+import torch
+import soundfile as sf
+
+
+pipe = AceStepPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(
+            model_id="ACE-Step/Ace-Step1.5",
+            origin_file_pattern="acestep-v15-turbo/model.safetensors"
+        ),
+        ModelConfig(
+            model_id="ACE-Step/Ace-Step1.5",
+            origin_file_pattern="acestep-v15-turbo/model.safetensors"
+        ),
+        ModelConfig(
+            model_id="ACE-Step/Ace-Step1.5",
+            origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"
+        ),
+    ],
+    tokenizer_config=ModelConfig(
+        model_id="ACE-Step/Ace-Step1.5",
+        origin_file_pattern="Qwen3-Embedding-0.6B/"
+    ),
+    vae_config=ModelConfig(
+        model_id="ACE-Step/Ace-Step1.5",
+        origin_file_pattern="vae/"
+    ),
+)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline."
+lyrics = """[Intro - Synth Brass Fanfare]
+
+[Verse 1]
+黑夜里的风吹过耳畔
+甜蜜时光转瞬即逝
+脚步飘摇在星光上
+
+[Chorus]
+心电感应在震动间
+拥抱未来勇敢冒险
+
+[Outro - Instrumental]"""
+
+audio = pipe(
+    prompt=prompt,
+    lyrics=lyrics,
+    duration=30.0,
+    seed=42,
+    num_inference_steps=8,
+    cfg_scale=1.0,
+    shift=3.0,
+)
+
+sf.write("Ace-Step1.5.wav", audio.cpu().numpy(), pipe.sample_rate)
+print(f"Saved to Ace-Step1.5.wav, shape: {audio.shape}, duration: {audio.shape[-1] / pipe.sample_rate:.1f}s")
diff --git a/examples/ace_step/model_inference/acestep-v15-base.py b/examples/ace_step/model_inference/acestep-v15-base.py
new file mode 100644
index 0000000..480a6fe
--- /dev/null
+++ b/examples/ace_step/model_inference/acestep-v15-base.py
@@ -0,0 +1,52 @@
+"""
+Ace-Step 1.5 Base (non-turbo, 24 layers) — Text-to-Music inference example.
+
+Uses cfg_scale=7.0 (standard CFG guidance) and more steps for higher quality.
+""" +from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig +import torch +import soundfile as sf + + +pipe = AceStepPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig( + model_id="ACE-Step/Ace-Step1.5", + origin_file_pattern="acestep-v15-base/model.safetensors" + ), + ModelConfig( + model_id="ACE-Step/Ace-Step1.5", + origin_file_pattern="acestep-v15-base/model.safetensors" + ), + ModelConfig( + model_id="ACE-Step/Ace-Step1.5", + origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors" + ), + ], + tokenizer_config=ModelConfig( + model_id="ACE-Step/Ace-Step1.5", + origin_file_pattern="Qwen3-Embedding-0.6B/" + ), + vae_config=ModelConfig( + model_id="ACE-Step/Ace-Step1.5", + origin_file_pattern="vae/" + ), +) + +prompt = "A cinematic orchestral piece with soaring strings and heroic brass" +lyrics = "[Intro - Orchestra]\n\n[Verse 1]\nAcross the mountains, through the valley\nA journey of a thousand miles\n\n[Chorus]\nRise above the stormy skies\nLet the music carry you" + +audio = pipe( + prompt=prompt, + lyrics=lyrics, + duration=30.0, + seed=42, + num_inference_steps=20, + cfg_scale=7.0, # Base model uses CFG + shift=3.0, +) + +sf.write("acestep-v15-base.wav", audio.cpu().numpy(), pipe.sample_rate) +print(f"Saved, shape: {audio.shape}") diff --git a/examples/ace_step/model_inference/acestep-v15-sft.py b/examples/ace_step/model_inference/acestep-v15-sft.py new file mode 100644 index 0000000..c9ec0ff --- /dev/null +++ b/examples/ace_step/model_inference/acestep-v15-sft.py @@ -0,0 +1,52 @@ +""" +Ace-Step 1.5 SFT (supervised fine-tuned, 24 layers) — Text-to-Music inference example. + +SFT variant is fine-tuned for specific music styles. 
+""" +from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig +import torch +import soundfile as sf + + +pipe = AceStepPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig( + model_id="ACE-Step/Ace-Step1.5", + origin_file_pattern="acestep-v15-sft/model.safetensors" + ), + ModelConfig( + model_id="ACE-Step/Ace-Step1.5", + origin_file_pattern="acestep-v15-sft/model.safetensors" + ), + ModelConfig( + model_id="ACE-Step/Ace-Step1.5", + origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors" + ), + ], + tokenizer_config=ModelConfig( + model_id="ACE-Step/Ace-Step1.5", + origin_file_pattern="Qwen3-Embedding-0.6B/" + ), + vae_config=ModelConfig( + model_id="ACE-Step/Ace-Step1.5", + origin_file_pattern="vae/" + ), +) + +prompt = "A jazzy lo-fi beat with smooth saxophone and vinyl crackle, late night vibes" +lyrics = "[Intro - Vinyl crackle]\n\n[Verse 1]\nMidnight city, neon glow\nSmooth jazz flowing to and fro\n\n[Chorus]\nLay back, let the music play\nJazzy nights, dreams drift away" + +audio = pipe( + prompt=prompt, + lyrics=lyrics, + duration=30.0, + seed=42, + num_inference_steps=20, + cfg_scale=7.0, + shift=3.0, +) + +sf.write("acestep-v15-sft.wav", audio.cpu().numpy(), pipe.sample_rate) +print(f"Saved, shape: {audio.shape}") diff --git a/examples/ace_step/model_inference/acestep-v15-turbo-shift1.py b/examples/ace_step/model_inference/acestep-v15-turbo-shift1.py new file mode 100644 index 0000000..447f6b0 --- /dev/null +++ b/examples/ace_step/model_inference/acestep-v15-turbo-shift1.py @@ -0,0 +1,52 @@ +""" +Ace-Step 1.5 Turbo (shift=1) — Text-to-Music inference example. + +Uses shift=1.0 (no timestep transformation) for smoother, slower denoising. 
+""" +from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig +import torch +import soundfile as sf + + +pipe = AceStepPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig( + model_id="ACE-Step/Ace-Step1.5", + origin_file_pattern="acestep-v15-turbo/model.safetensors" + ), + ModelConfig( + model_id="ACE-Step/Ace-Step1.5", + origin_file_pattern="acestep-v15-turbo/model.safetensors" + ), + ModelConfig( + model_id="ACE-Step/Ace-Step1.5", + origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors" + ), + ], + tokenizer_config=ModelConfig( + model_id="ACE-Step/Ace-Step1.5", + origin_file_pattern="Qwen3-Embedding-0.6B/" + ), + vae_config=ModelConfig( + model_id="ACE-Step/Ace-Step1.5", + origin_file_pattern="vae/" + ), +) + +prompt = "A gentle acoustic guitar melody with soft piano accompaniment, peaceful and warm atmosphere" +lyrics = "[Verse 1]\nSunlight filtering through the trees\nA quiet moment, just the breeze\n\n[Chorus]\nPeaceful heart, open mind\nLeaving all the noise behind" + +audio = pipe( + prompt=prompt, + lyrics=lyrics, + duration=30.0, + seed=42, + num_inference_steps=8, + cfg_scale=1.0, + shift=1.0, # shift=1: no timestep transformation +) + +sf.write("acestep-v15-turbo-shift1.wav", audio.cpu().numpy(), pipe.sample_rate) +print(f"Saved, shape: {audio.shape}") diff --git a/examples/ace_step/model_inference/acestep-v15-turbo-shift3.py b/examples/ace_step/model_inference/acestep-v15-turbo-shift3.py new file mode 100644 index 0000000..8091500 --- /dev/null +++ b/examples/ace_step/model_inference/acestep-v15-turbo-shift3.py @@ -0,0 +1,52 @@ +""" +Ace-Step 1.5 Turbo (shift=3) — Text-to-Music inference example. + +Uses shift=3.0 (default turbo shift) for faster denoising convergence. 
+""" +from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig +import torch +import soundfile as sf + + +pipe = AceStepPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig( + model_id="ACE-Step/Ace-Step1.5", + origin_file_pattern="acestep-v15-turbo/model.safetensors" + ), + ModelConfig( + model_id="ACE-Step/Ace-Step1.5", + origin_file_pattern="acestep-v15-turbo/model.safetensors" + ), + ModelConfig( + model_id="ACE-Step/Ace-Step1.5", + origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors" + ), + ], + tokenizer_config=ModelConfig( + model_id="ACE-Step/Ace-Step1.5", + origin_file_pattern="Qwen3-Embedding-0.6B/" + ), + vae_config=ModelConfig( + model_id="ACE-Step/Ace-Step1.5", + origin_file_pattern="vae/" + ), +) + +prompt = "An explosive, high-energy pop-rock track with anime theme song feel" +lyrics = "[Intro]\n\n[Verse 1]\nRunning through the neon lights\nChasing dreams across the night\n\n[Chorus]\nFeel the fire in my soul\nMusic takes complete control" + +audio = pipe( + prompt=prompt, + lyrics=lyrics, + duration=30.0, + seed=42, + num_inference_steps=8, + cfg_scale=1.0, + shift=3.0, +) + +sf.write("acestep-v15-turbo-shift3.wav", audio.cpu().numpy(), pipe.sample_rate) +print(f"Saved, shape: {audio.shape}") diff --git a/examples/ace_step/model_inference/acestep-v15-xl-base.py b/examples/ace_step/model_inference/acestep-v15-xl-base.py new file mode 100644 index 0000000..f1c5b4e --- /dev/null +++ b/examples/ace_step/model_inference/acestep-v15-xl-base.py @@ -0,0 +1,52 @@ +""" +Ace-Step 1.5 XL Base (32 layers, hidden_size=2560) — Text-to-Music inference example. + +XL variant with larger capacity for higher quality generation. 
+""" +from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig +import torch +import soundfile as sf + + +pipe = AceStepPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig( + model_id="ACE-Step/acestep-v15-xl-base", + origin_file_pattern="model-*.safetensors" + ), + ModelConfig( + model_id="ACE-Step/acestep-v15-xl-base", + origin_file_pattern="model-*.safetensors" + ), + ModelConfig( + model_id="ACE-Step/acestep-v15-xl-base", + origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors" + ), + ], + tokenizer_config=ModelConfig( + model_id="ACE-Step/acestep-v15-xl-base", + origin_file_pattern="Qwen3-Embedding-0.6B/" + ), + vae_config=ModelConfig( + model_id="ACE-Step/acestep-v15-xl-base", + origin_file_pattern="vae/" + ), +) + +prompt = "An epic symphonic metal track with double bass drums and soaring vocals" +lyrics = "[Intro - Heavy guitar riff]\n\n[Verse 1]\nSteel and thunder, fire and rain\nBurning through the endless pain\n\n[Chorus]\nRise up, break the chains\nUnleash the fire in your veins" + +audio = pipe( + prompt=prompt, + lyrics=lyrics, + duration=30.0, + seed=42, + num_inference_steps=20, + cfg_scale=7.0, + shift=3.0, +) + +sf.write("acestep-v15-xl-base.wav", audio.cpu().numpy(), pipe.sample_rate) +print(f"Saved, shape: {audio.shape}") diff --git a/examples/ace_step/model_inference/acestep-v15-xl-sft.py b/examples/ace_step/model_inference/acestep-v15-xl-sft.py new file mode 100644 index 0000000..73d54d9 --- /dev/null +++ b/examples/ace_step/model_inference/acestep-v15-xl-sft.py @@ -0,0 +1,50 @@ +""" +Ace-Step 1.5 XL SFT (32 layers, supervised fine-tuned) — Text-to-Music inference example. 
+""" +from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig +import torch +import soundfile as sf + + +pipe = AceStepPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig( + model_id="ACE-Step/acestep-v15-xl-sft", + origin_file_pattern="model-*.safetensors" + ), + ModelConfig( + model_id="ACE-Step/acestep-v15-xl-sft", + origin_file_pattern="model-*.safetensors" + ), + ModelConfig( + model_id="ACE-Step/acestep-v15-xl-sft", + origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors" + ), + ], + tokenizer_config=ModelConfig( + model_id="ACE-Step/acestep-v15-xl-sft", + origin_file_pattern="Qwen3-Embedding-0.6B/" + ), + vae_config=ModelConfig( + model_id="ACE-Step/acestep-v15-xl-sft", + origin_file_pattern="vae/" + ), +) + +prompt = "A beautiful piano ballad with lush strings and emotional vocals, cinematic feel" +lyrics = "[Intro - Solo piano]\n\n[Verse 1]\nWhispers of a distant shore\nMemories I hold so dear\n\n[Chorus]\nIn your eyes I see the dawn\nAll my fears are gone" + +audio = pipe( + prompt=prompt, + lyrics=lyrics, + duration=30.0, + seed=42, + num_inference_steps=20, + cfg_scale=7.0, + shift=3.0, +) + +sf.write("acestep-v15-xl-sft.wav", audio.cpu().numpy(), pipe.sample_rate) +print(f"Saved, shape: {audio.shape}") diff --git a/examples/ace_step/model_inference/acestep-v15-xl-turbo.py b/examples/ace_step/model_inference/acestep-v15-xl-turbo.py new file mode 100644 index 0000000..9116567 --- /dev/null +++ b/examples/ace_step/model_inference/acestep-v15-xl-turbo.py @@ -0,0 +1,52 @@ +""" +Ace-Step 1.5 XL Turbo (32 layers) — Text-to-Music inference example. + +XL turbo with fast generation (8 steps, shift=3.0, no CFG). 
+""" +from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig +import torch +import soundfile as sf + + +pipe = AceStepPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig( + model_id="ACE-Step/acestep-v15-xl-turbo", + origin_file_pattern="model-*.safetensors" + ), + ModelConfig( + model_id="ACE-Step/acestep-v15-xl-turbo", + origin_file_pattern="model-*.safetensors" + ), + ModelConfig( + model_id="ACE-Step/acestep-v15-xl-turbo", + origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors" + ), + ], + tokenizer_config=ModelConfig( + model_id="ACE-Step/acestep-v15-xl-turbo", + origin_file_pattern="Qwen3-Embedding-0.6B/" + ), + vae_config=ModelConfig( + model_id="ACE-Step/acestep-v15-xl-turbo", + origin_file_pattern="vae/" + ), +) + +prompt = "An upbeat electronic dance track with pulsing synths and driving bassline" +lyrics = "[Intro - Synth build]\n\n[Verse 1]\nFeel the rhythm in the air\nElectric beats are everywhere\n\n[Drop]\n\n[Chorus]\nDance until the break of dawn\nMove your body, carry on" + +audio = pipe( + prompt=prompt, + lyrics=lyrics, + duration=30.0, + seed=42, + num_inference_steps=8, + cfg_scale=1.0, # turbo: no CFG + shift=3.0, +) + +sf.write("acestep-v15-xl-turbo.wav", audio.cpu().numpy(), pipe.sample_rate) +print(f"Saved, shape: {audio.shape}")