mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-24 15:06:17 +00:00
style
This commit is contained in:
@@ -925,7 +925,7 @@ ace_step_series = [
|
||||
"model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
|
||||
"model_name": "ace_step_dit",
|
||||
"model_class": "diffsynth.models.ace_step_dit.AceStepDiTModel",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.ace_step_dit_converter",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.AceStepDiTModelStateDictConverter",
|
||||
},
|
||||
# === XL DiT variants (32 layers, hidden_size=2560) ===
|
||||
# Covers: xl-base, xl-sft, xl-turbo
|
||||
@@ -934,7 +934,7 @@ ace_step_series = [
|
||||
"model_hash": "3a28a410c2246f125153ef792d8bc828",
|
||||
"model_name": "ace_step_dit",
|
||||
"model_class": "diffsynth.models.ace_step_dit.AceStepDiTModel",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.ace_step_dit_converter",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.AceStepDiTModelStateDictConverter",
|
||||
"extra_kwargs": {
|
||||
"hidden_size": 2560,
|
||||
"intermediate_size": 9728,
|
||||
@@ -952,7 +952,7 @@ ace_step_series = [
|
||||
"model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
|
||||
"model_name": "ace_step_conditioner",
|
||||
"model_class": "diffsynth.models.ace_step_conditioner.AceStepConditionEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.ace_step_conditioner_converter",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.AceStepConditionEncoderStateDictConverter",
|
||||
},
|
||||
# === XL Conditioner (same architecture, but checkpoint includes XL decoder → different file hash) ===
|
||||
{
|
||||
@@ -960,7 +960,7 @@ ace_step_series = [
|
||||
"model_hash": "3a28a410c2246f125153ef792d8bc828",
|
||||
"model_name": "ace_step_conditioner",
|
||||
"model_class": "diffsynth.models.ace_step_conditioner.AceStepConditionEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.ace_step_conditioner_converter",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.AceStepConditionEncoderStateDictConverter",
|
||||
},
|
||||
# === Qwen3-Embedding (text encoder) ===
|
||||
{
|
||||
@@ -968,7 +968,7 @@ ace_step_series = [
|
||||
"model_hash": "3509bea17b0e8cffc3dd4a15cc7899d0",
|
||||
"model_name": "ace_step_text_encoder",
|
||||
"model_class": "diffsynth.models.ace_step_text_encoder.AceStepTextEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_text_encoder.ace_step_text_encoder_converter",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_text_encoder.AceStepTextEncoderStateDictConverter",
|
||||
},
|
||||
# === VAE (AutoencoderOobleck CNN) ===
|
||||
{
|
||||
@@ -983,7 +983,7 @@ ace_step_series = [
|
||||
"model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
|
||||
"model_name": "ace_step_tokenizer",
|
||||
"model_class": "diffsynth.models.ace_step_tokenizer.AceStepTokenizer",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_tokenizer.ace_step_tokenizer_converter",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_tokenizer.AceStepTokenizerStateDictConverter",
|
||||
},
|
||||
# === XL Tokenizer (XL models share same tokenizer architecture) ===
|
||||
{
|
||||
@@ -991,8 +991,11 @@ ace_step_series = [
|
||||
"model_hash": "3a28a410c2246f125153ef792d8bc828",
|
||||
"model_name": "ace_step_tokenizer",
|
||||
"model_class": "diffsynth.models.ace_step_tokenizer.AceStepTokenizer",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_tokenizer.ace_step_tokenizer_converter",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_tokenizer.AceStepTokenizerStateDictConverter",
|
||||
},
|
||||
]
|
||||
|
||||
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series + z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series + ace_step_series
|
||||
MODEL_CONFIGS = (
|
||||
qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series
|
||||
+ z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series + ace_step_series
|
||||
)
|
||||
|
||||
@@ -3,12 +3,10 @@ ACE-Step Pipeline for DiffSynth-Studio.
|
||||
|
||||
Text-to-Music generation pipeline using ACE-Step 1.5 model.
|
||||
"""
|
||||
import re
|
||||
import torch
|
||||
import re, torch
|
||||
from typing import Optional, Dict, Any, List, Tuple
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
import math
|
||||
import random, math
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
@@ -39,7 +37,7 @@ class AceStepPipeline(BasePipeline):
|
||||
self.conditioner: AceStepConditionEncoder = None
|
||||
self.dit: AceStepDiTModel = None
|
||||
self.vae: AceStepVAE = None
|
||||
self.tokenizer_model: AceStepTokenizer = None # AceStepTokenizer (tokenizer + detokenizer)
|
||||
self.tokenizer_model: AceStepTokenizer = None
|
||||
|
||||
self.in_iteration_models = ("dit",)
|
||||
self.units = [
|
||||
@@ -65,7 +63,6 @@ class AceStepPipeline(BasePipeline):
|
||||
silence_latent_config: ModelConfig = ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
|
||||
vram_limit: float = None,
|
||||
):
|
||||
"""Load pipeline from pretrained checkpoints."""
|
||||
pipe = AceStepPipeline(device=device, torch_dtype=torch_dtype)
|
||||
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
||||
|
||||
@@ -102,7 +99,7 @@ class AceStepPipeline(BasePipeline):
|
||||
reference_audios: List[torch.Tensor] = None,
|
||||
# Source audio
|
||||
src_audio: torch.Tensor = None,
|
||||
denoising_strength: float = 1.0, # denoising_strength = 1 - cover_noise_strength
|
||||
denoising_strength: float = 1.0, # denoising_strength = 1 - cover_noise_strength
|
||||
audio_cover_strength: float = 1.0,
|
||||
# Audio codes
|
||||
audio_code_string: Optional[str] = None,
|
||||
@@ -115,7 +112,7 @@ class AceStepPipeline(BasePipeline):
|
||||
bpm: Optional[int] = 100,
|
||||
keyscale: Optional[str] = "B minor",
|
||||
timesignature: Optional[str] = "4",
|
||||
vocal_language: Optional[str] = 'unknown',
|
||||
vocal_language: Optional[str] = "unknown",
|
||||
# Randomness
|
||||
seed: int = None,
|
||||
rand_device: str = "cpu",
|
||||
@@ -126,10 +123,10 @@ class AceStepPipeline(BasePipeline):
|
||||
# Progress
|
||||
progress_bar_cmd=tqdm,
|
||||
):
|
||||
# 1. Scheduler
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, denoising_strength=denoising_strength, shift=shift)
|
||||
|
||||
# 2. 三字典输入
|
||||
# Parameters
|
||||
inputs_posi = {"prompt": prompt, "positive": True}
|
||||
inputs_nega = {"positive": False}
|
||||
inputs_shared = {
|
||||
@@ -147,13 +144,12 @@ class AceStepPipeline(BasePipeline):
|
||||
"shift": shift,
|
||||
}
|
||||
|
||||
# 3. Unit 链执行
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
|
||||
unit, self, inputs_shared, inputs_posi, inputs_nega
|
||||
)
|
||||
|
||||
# 4. Denoise loop
|
||||
# Denoise
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
@@ -164,17 +160,17 @@ class AceStepPipeline(BasePipeline):
|
||||
inputs_shared, inputs_posi, inputs_nega,
|
||||
**models, timestep=timestep, progress_id=progress_id,
|
||||
)
|
||||
inputs_shared["latents"] = self.step(self.scheduler, inpaint_mask=inputs_shared.get("denoise_mask", None), input_latents=inputs_shared.get("src_latents", None),
|
||||
progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
|
||||
inputs_shared["latents"] = self.step(
|
||||
self.scheduler, inpaint_mask=inputs_shared.get("denoise_mask", None), input_latents=inputs_shared.get("src_latents", None),
|
||||
progress_id=progress_id, noise_pred=noise_pred, **inputs_shared,
|
||||
)
|
||||
|
||||
# 5. VAE 解码
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
# DiT output is [B, T, 64] (channels-last), VAE expects [B, 64, T] (channels-first)
|
||||
latents = inputs_shared["latents"].transpose(1, 2)
|
||||
vae_output = self.vae.decode(latents)
|
||||
# VAE returns OobleckDecoderOutput with .sample attribute
|
||||
audio_output = vae_output.sample if hasattr(vae_output, 'sample') else vae_output
|
||||
audio_output = self.normalize_audio(audio_output, target_db=-1.0)
|
||||
audio_output = self.normalize_audio(vae_output, target_db=-1.0)
|
||||
audio = self.output_audio_format_check(audio_output)
|
||||
self.load_models_to_device([])
|
||||
return audio
|
||||
@@ -188,7 +184,9 @@ class AceStepPipeline(BasePipeline):
|
||||
return audio * gain
|
||||
|
||||
def switch_noncover_condition(self, inputs_shared, inputs_posi, inputs_nega, progress_id):
|
||||
if inputs_shared["task_type"] != "cover" or inputs_shared["audio_cover_strength"] >= 1.0 or inputs_shared.get("shared_noncover", None) is None:
|
||||
if inputs_shared["task_type"] != "cover" or inputs_shared["audio_cover_strength"] >= 1.0:
|
||||
return
|
||||
if inputs_shared.get("shared_noncover", None) is None:
|
||||
return
|
||||
cover_steps = int(len(self.scheduler.timesteps) * inputs_shared["audio_cover_strength"])
|
||||
if progress_id >= cover_steps:
|
||||
@@ -312,7 +310,10 @@ class AceStepUnit_ReferenceAudioEmbedder(PipelineUnit):
|
||||
def process(self, pipe, reference_audios):
|
||||
if reference_audios is not None:
|
||||
pipe.load_models_to_device(['vae'])
|
||||
reference_audios = [self.process_reference_audio(reference_audio).to(dtype=pipe.torch_dtype, device=pipe.device) for reference_audio in reference_audios]
|
||||
reference_audios = [
|
||||
self.process_reference_audio(reference_audio).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
for reference_audio in reference_audios
|
||||
]
|
||||
reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, [reference_audios])
|
||||
else:
|
||||
reference_audios = [[torch.zeros(2, 30 * pipe.vae.sampling_rate).to(dtype=pipe.torch_dtype, device=pipe.device)]]
|
||||
@@ -357,7 +358,6 @@ class AceStepUnit_ReferenceAudioEmbedder(PipelineUnit):
|
||||
|
||||
|
||||
class AceStepUnit_ConditionEmbedder(PipelineUnit):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
take_over=True,
|
||||
@@ -378,7 +378,9 @@ class AceStepUnit_ConditionEmbedder(PipelineUnit):
|
||||
inputs_posi["encoder_hidden_states"] = encoder_hidden_states
|
||||
inputs_posi["encoder_attention_mask"] = encoder_attention_mask
|
||||
if inputs_shared["cfg_scale"] != 1.0:
|
||||
inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states).to(dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device)
|
||||
inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states).to(
|
||||
dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device,
|
||||
)
|
||||
inputs_nega["encoder_attention_mask"] = encoder_attention_mask
|
||||
if inputs_shared["task_type"] == "cover" and inputs_shared["audio_cover_strength"] < 1.0:
|
||||
hidden_states_noncover = AceStepUnit_PromptEmbedder().process(
|
||||
@@ -396,7 +398,9 @@ class AceStepUnit_ConditionEmbedder(PipelineUnit):
|
||||
inputs_shared["posi_noncover"] = {"encoder_hidden_states": encoder_hidden_states_noncover, "encoder_attention_mask": encoder_attention_mask_noncover}
|
||||
if inputs_shared["cfg_scale"] != 1.0:
|
||||
inputs_shared["nega_noncover"] = {
|
||||
"encoder_hidden_states": pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states_noncover).to(dtype=encoder_hidden_states_noncover.dtype, device=encoder_hidden_states_noncover.device),
|
||||
"encoder_hidden_states": pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states_noncover).to(
|
||||
dtype=encoder_hidden_states_noncover.dtype, device=encoder_hidden_states_noncover.device,
|
||||
),
|
||||
"encoder_attention_mask": encoder_attention_mask_noncover,
|
||||
}
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
@@ -483,7 +487,7 @@ class AceStepUnit_ContextLatentBuilder(PipelineUnit):
|
||||
# Use audio_code_string to get src_latents.
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
code_ids = self._parse_audio_code_string(audio_code_string)
|
||||
quantizer = pipe.tokenizer_model.tokenizer.quantizer
|
||||
quantizer = pipe.tokenizer_model.tokenizer.quantizer.to(device=pipe.device)
|
||||
indices = torch.tensor(code_ids, device=quantizer.codebooks.device, dtype=torch.long).unsqueeze(0).unsqueeze(-1)
|
||||
codes = quantizer.get_codes_from_indices(indices)
|
||||
quantized = codes.sum(dim=0).to(pipe.torch_dtype).to(pipe.device)
|
||||
|
||||
@@ -1,38 +1,4 @@
|
||||
"""
|
||||
State dict converter for ACE-Step Conditioner model.
|
||||
|
||||
The original checkpoint stores all model weights in a single file
|
||||
(nested in AceStepConditionGenerationModel). The Conditioner weights are
|
||||
prefixed with 'encoder.'.
|
||||
|
||||
This converter extracts only keys starting with 'encoder.' and strips
|
||||
the prefix to match the standalone AceStepConditionEncoder in DiffSynth.
|
||||
"""
|
||||
|
||||
|
||||
def ace_step_conditioner_converter(state_dict):
|
||||
"""
|
||||
Convert ACE-Step Conditioner checkpoint keys to DiffSynth format.
|
||||
|
||||
参数 state_dict 是 DiskMap 类型。
|
||||
遍历时,key 是 key 名,state_dict[key] 获取实际值。
|
||||
|
||||
Original checkpoint contains all model weights under prefixes:
|
||||
- decoder.* (DiT)
|
||||
- encoder.* (Conditioner)
|
||||
- tokenizer.* (Audio Tokenizer)
|
||||
- detokenizer.* (Audio Detokenizer)
|
||||
- null_condition_emb (CFG null embedding)
|
||||
|
||||
This extracts only 'encoder.' keys and strips the prefix.
|
||||
|
||||
Example mapping:
|
||||
encoder.lyric_encoder.layers.0.self_attn.q_proj.weight -> lyric_encoder.layers.0.self_attn.q_proj.weight
|
||||
encoder.attention_pooler.layers.0.self_attn.q_proj.weight -> attention_pooler.layers.0.self_attn.q_proj.weight
|
||||
encoder.timbre_encoder.layers.0.self_attn.q_proj.weight -> timbre_encoder.layers.0.self_attn.q_proj.weight
|
||||
encoder.audio_tokenizer.audio_acoustic_proj.weight -> audio_tokenizer.audio_acoustic_proj.weight
|
||||
encoder.detokenizer.layers.0.self_attn.q_proj.weight -> detokenizer.layers.0.self_attn.q_proj.weight
|
||||
"""
|
||||
def AceStepConditionEncoderStateDictConverter(state_dict):
|
||||
new_state_dict = {}
|
||||
prefix = "encoder."
|
||||
|
||||
@@ -41,7 +7,6 @@ def ace_step_conditioner_converter(state_dict):
|
||||
new_key = key[len(prefix):]
|
||||
new_state_dict[new_key] = state_dict[key]
|
||||
|
||||
# Extract null_condition_emb from top level (used for CFG negative condition)
|
||||
if "null_condition_emb" in state_dict:
|
||||
new_state_dict["null_condition_emb"] = state_dict["null_condition_emb"]
|
||||
|
||||
|
||||
@@ -1,37 +1,4 @@
|
||||
"""
|
||||
State dict converter for ACE-Step DiT model.
|
||||
|
||||
The original checkpoint stores all model weights in a single file
|
||||
(nested in AceStepConditionGenerationModel). The DiT weights are
|
||||
prefixed with 'decoder.'.
|
||||
|
||||
This converter extracts only keys starting with 'decoder.' and strips
|
||||
the prefix to match the standalone AceStepDiTModel in DiffSynth.
|
||||
"""
|
||||
|
||||
|
||||
def ace_step_dit_converter(state_dict):
|
||||
"""
|
||||
Convert ACE-Step DiT checkpoint keys to DiffSynth format.
|
||||
|
||||
参数 state_dict 是 DiskMap 类型。
|
||||
遍历时,key 是 key 名,state_dict[key] 获取实际值。
|
||||
|
||||
Original checkpoint contains all model weights under prefixes:
|
||||
- decoder.* (DiT)
|
||||
- encoder.* (Conditioner)
|
||||
- tokenizer.* (Audio Tokenizer)
|
||||
- detokenizer.* (Audio Detokenizer)
|
||||
- null_condition_emb (CFG null embedding)
|
||||
|
||||
This extracts only 'decoder.' keys and strips the prefix.
|
||||
|
||||
Example mapping:
|
||||
decoder.layers.0.self_attn.q_proj.weight -> layers.0.self_attn.q_proj.weight
|
||||
decoder.proj_in.0.linear_1.weight -> proj_in.0.linear_1.weight
|
||||
decoder.time_embed.linear_1.weight -> time_embed.linear_1.weight
|
||||
decoder.rotary_emb.inv_freq -> rotary_emb.inv_freq
|
||||
"""
|
||||
def AceStepDiTModelStateDictConverter(state_dict):
|
||||
new_state_dict = {}
|
||||
prefix = "decoder."
|
||||
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
"""
|
||||
State dict converter for ACE-Step LLM (Qwen3-based).
|
||||
|
||||
The safetensors file stores Qwen3 model weights. Different checkpoints
|
||||
may have different key formats:
|
||||
- Qwen3ForCausalLM format: model.embed_tokens.weight, model.layers.0.*
|
||||
- Qwen3Model format: embed_tokens.weight, layers.0.*
|
||||
|
||||
Qwen3ForCausalLM wraps a .model attribute (Qwen3Model), so its
|
||||
state_dict() has keys:
|
||||
model.model.embed_tokens.weight
|
||||
model.model.layers.0.self_attn.q_proj.weight
|
||||
model.model.norm.weight
|
||||
model.lm_head.weight (tied to model.model.embed_tokens)
|
||||
|
||||
This converter normalizes all keys to the Qwen3ForCausalLM format.
|
||||
|
||||
Example mapping:
|
||||
model.embed_tokens.weight -> model.model.embed_tokens.weight
|
||||
embed_tokens.weight -> model.model.embed_tokens.weight
|
||||
model.layers.0.self_attn.q_proj.weight -> model.model.layers.0.self_attn.q_proj.weight
|
||||
layers.0.self_attn.q_proj.weight -> model.model.layers.0.self_attn.q_proj.weight
|
||||
model.norm.weight -> model.model.norm.weight
|
||||
norm.weight -> model.model.norm.weight
|
||||
"""
|
||||
|
||||
|
||||
def ace_step_lm_converter(state_dict):
    """
    Convert ACE-Step LLM checkpoint keys to the wrapped Qwen3ForCausalLM layout.

    Target layout (see the module docstring): backbone weights live under
    ``model.model.*`` and the language-model head under ``model.lm_head.*``.

    Args:
        state_dict: Mapping of checkpoint key -> value. It is expected to be a
            DiskMap: iterating it yields key names, and ``state_dict[key]``
            fetches the actual value. Any mapping with those semantics works.

    Returns:
        A new ``dict`` with normalized keys. The input mapping is not modified.

    Key handling:
        model.model.<x>  -> unchanged (already nested)
        model.lm_head.<x> -> unchanged (already in target form)
        lm_head.<x>      -> model.lm_head.<x>
        model.<x>        -> model.model.<x>
        <x>              -> model.model.<x>
    """
    nested_prefix = "model.model."
    model_prefix = "model."
    head_prefix = "lm_head."
    wrapped_head_prefix = "model.lm_head."

    new_state_dict = {}
    for key in state_dict:
        if key.startswith((nested_prefix, wrapped_head_prefix)):
            # Already in the target layout; keep as is.
            new_key = key
        elif key.startswith((model_prefix, head_prefix)):
            # One wrapper level is missing: backbone keys gain a second
            # "model.", head keys gain their first (the head is a sibling of
            # the backbone, not a child of it).
            new_key = "model." + key
        else:
            # Bare Qwen3Model keys: add both wrapper levels.
            new_key = "model.model." + key
        new_state_dict[new_key] = state_dict[key]

    # Tied word embeddings: when the checkpoint carries no explicit head
    # weight, share the embedding matrix. An explicit head weight from the
    # checkpoint is never overwritten.
    embed_key = "model.model.embed_tokens.weight"
    head_key = "model.lm_head.weight"
    if embed_key in new_state_dict and head_key not in new_state_dict:
        new_state_dict[head_key] = new_state_dict[embed_key]

    return new_state_dict
|
||||
@@ -1,28 +1,4 @@
|
||||
"""
|
||||
State dict converter for ACE-Step Text Encoder (Qwen3-Embedding-0.6B).
|
||||
|
||||
The safetensors stores Qwen3Model weights with keys:
|
||||
embed_tokens.weight
|
||||
layers.0.self_attn.q_proj.weight
|
||||
norm.weight
|
||||
|
||||
AceStepTextEncoder wraps a .model attribute (Qwen3Model), so its
|
||||
state_dict() has keys with 'model.' prefix:
|
||||
model.embed_tokens.weight
|
||||
model.layers.0.self_attn.q_proj.weight
|
||||
model.norm.weight
|
||||
|
||||
This converter adds 'model.' prefix to match the nested structure.
|
||||
"""
|
||||
|
||||
|
||||
def ace_step_text_encoder_converter(state_dict):
|
||||
"""
|
||||
Convert ACE-Step Text Encoder checkpoint keys to match Qwen3Model wrapped state dict.
|
||||
|
||||
参数 state_dict 是 DiskMap 类型。
|
||||
遍历时,key 是 key 名,state_dict[key] 获取实际值。
|
||||
"""
|
||||
def AceStepTextEncoderStateDictConverter(state_dict):
|
||||
new_state_dict = {}
|
||||
prefix = "model."
|
||||
nested_prefix = "model.model."
|
||||
|
||||
@@ -1,23 +1,4 @@
|
||||
"""
|
||||
State dict converter for ACE-Step Tokenizer model.
|
||||
|
||||
The original checkpoint stores tokenizer and detokenizer weights at the top level:
|
||||
- tokenizer.* (AceStepAudioTokenizer: audio_acoustic_proj, attention_pooler, quantizer)
|
||||
- detokenizer.* (AudioTokenDetokenizer: embed_tokens, layers, proj_out)
|
||||
|
||||
These map directly to the AceStepTokenizer class which wraps both as
|
||||
self.tokenizer and self.detokenizer submodules.
|
||||
"""
|
||||
|
||||
|
||||
def ace_step_tokenizer_converter(state_dict):
|
||||
"""
|
||||
Convert ACE-Step Tokenizer checkpoint keys to DiffSynth format.
|
||||
|
||||
The checkpoint keys `tokenizer.*` and `detokenizer.*` already match
|
||||
the DiffSynth AceStepTokenizer module structure (self.tokenizer, self.detokenizer).
|
||||
No key remapping needed — just extract the relevant keys.
|
||||
"""
|
||||
def AceStepTokenizerStateDictConverter(state_dict):
|
||||
new_state_dict = {}
|
||||
|
||||
for key in state_dict:
|
||||
|
||||
Reference in New Issue
Block a user