model-code

This commit is contained in:
mi804
2026-04-17 17:06:26 +08:00
parent 079e51c9f3
commit 36c203da57
23 changed files with 4230 additions and 2 deletions

View File

@@ -0,0 +1,48 @@
"""
State dict converter for ACE-Step Conditioner model.
The original checkpoint stores all model weights in a single file
(nested in AceStepConditionGenerationModel). The Conditioner weights are
prefixed with 'encoder.'.
This converter extracts only keys starting with 'encoder.' and strips
the prefix to match the standalone AceStepConditionEncoder in DiffSynth.
"""
def ace_step_conditioner_converter(state_dict):
    """
    Convert ACE-Step Conditioner checkpoint keys to DiffSynth format.

    `state_dict` is a DiskMap-like mapping: iterating over it yields key
    names, and `state_dict[key]` fetches the actual tensor value.

    The original checkpoint stores every sub-model under a prefix:
      - decoder.*            (DiT)
      - encoder.*            (Conditioner)
      - tokenizer.*          (Audio Tokenizer)
      - detokenizer.*        (Audio Detokenizer)
      - null_condition_emb   (CFG null embedding)

    Only 'encoder.' keys are kept, with the prefix stripped, to match the
    standalone AceStepConditionEncoder. Example mappings:
      encoder.lyric_encoder.layers.0.self_attn.q_proj.weight -> lyric_encoder.layers.0.self_attn.q_proj.weight
      encoder.attention_pooler.layers.0.self_attn.q_proj.weight -> attention_pooler.layers.0.self_attn.q_proj.weight
      encoder.audio_tokenizer.audio_acoustic_proj.weight -> audio_tokenizer.audio_acoustic_proj.weight
    """
    prefix = "encoder."
    converted = {
        key[len(prefix):]: state_dict[key]
        for key in state_dict
        if key.startswith(prefix)
    }
    # The CFG null embedding lives at the checkpoint's top level (no prefix)
    # and is carried over unchanged for the negative condition.
    if "null_condition_emb" in state_dict:
        converted["null_condition_emb"] = state_dict["null_condition_emb"]
    return converted

View File

@@ -0,0 +1,43 @@
"""
State dict converter for ACE-Step DiT model.
The original checkpoint stores all model weights in a single file
(nested in AceStepConditionGenerationModel). The DiT weights are
prefixed with 'decoder.'.
This converter extracts only keys starting with 'decoder.' and strips
the prefix to match the standalone AceStepDiTModel in DiffSynth.
"""
def ace_step_dit_converter(state_dict):
    """
    Convert ACE-Step DiT checkpoint keys to DiffSynth format.

    `state_dict` is a DiskMap-like mapping: iterating over it yields key
    names, and `state_dict[key]` fetches the actual tensor value.

    The original checkpoint stores every sub-model under a prefix:
      - decoder.*            (DiT)
      - encoder.*            (Conditioner)
      - tokenizer.*          (Audio Tokenizer)
      - detokenizer.*        (Audio Detokenizer)
      - null_condition_emb   (CFG null embedding)

    Only 'decoder.' keys are kept, with the prefix stripped, to match the
    standalone AceStepDiTModel. Example mappings:
      decoder.layers.0.self_attn.q_proj.weight -> layers.0.self_attn.q_proj.weight
      decoder.proj_in.0.linear_1.weight        -> proj_in.0.linear_1.weight
      decoder.time_embed.linear_1.weight       -> time_embed.linear_1.weight
      decoder.rotary_emb.inv_freq              -> rotary_emb.inv_freq
    """
    prefix = "decoder."
    return {
        key[len(prefix):]: state_dict[key]
        for key in state_dict
        if key.startswith(prefix)
    }

View File

@@ -0,0 +1,55 @@
"""
State dict converter for ACE-Step LLM (Qwen3-based).
The safetensors file stores Qwen3 model weights. Different checkpoints
may have different key formats:
- Qwen3ForCausalLM format: model.embed_tokens.weight, model.layers.0.*
- Qwen3Model format: embed_tokens.weight, layers.0.*
Qwen3ForCausalLM wraps a .model attribute (Qwen3Model), so its
state_dict() has keys:
model.model.embed_tokens.weight
model.model.layers.0.self_attn.q_proj.weight
model.model.norm.weight
model.lm_head.weight (tied to model.model.embed_tokens)
This converter normalizes all keys to the Qwen3ForCausalLM format.
Example mapping:
model.embed_tokens.weight -> model.model.embed_tokens.weight
embed_tokens.weight -> model.model.embed_tokens.weight
model.layers.0.self_attn.q_proj.weight -> model.model.layers.0.self_attn.q_proj.weight
layers.0.self_attn.q_proj.weight -> model.model.layers.0.self_attn.q_proj.weight
model.norm.weight -> model.model.norm.weight
norm.weight -> model.model.norm.weight
"""
def ace_step_lm_converter(state_dict):
    """
    Convert ACE-Step LLM checkpoint keys to match Qwen3ForCausalLM's state dict.

    `state_dict` is a DiskMap-like mapping: iterating over it yields key
    names, and `state_dict[key]` fetches the actual tensor value.

    Checkpoints may come in several key formats; all are normalized to the
    nested Qwen3ForCausalLM layout (`model.model.*` for the inner Qwen3Model,
    `model.lm_head.*` for the output head):
      model.embed_tokens.weight  -> model.model.embed_tokens.weight
      embed_tokens.weight        -> model.model.embed_tokens.weight
      layers.0.self_attn.q_proj.weight -> model.model.layers.0.self_attn.q_proj.weight
      lm_head.weight             -> model.lm_head.weight
      model.model.norm.weight    -> model.model.norm.weight (unchanged)
    """
    new_state_dict = {}
    model_prefix = "model."
    nested_prefix = "model.model."
    for key in state_dict:
        if key.startswith(nested_prefix) or key.startswith("model.lm_head."):
            # Already in the fully nested target format, keep as is.
            new_key = key
        elif key.startswith("lm_head."):
            # Bug fix: lm_head lives on the ForCausalLM wrapper, NOT on the
            # inner Qwen3Model, so it gets a single 'model.' prefix. The old
            # code fell into the no-prefix branch and produced the spurious
            # key 'model.model.lm_head.weight'.
            new_key = "model." + key
        elif key.startswith(model_prefix):
            # Has one 'model.', nest it one level deeper.
            new_key = "model." + key
        else:
            # Bare Qwen3Model key, add the full nested prefix.
            new_key = "model.model." + key
        new_state_dict[new_key] = state_dict[key]
    # Handle tied word embeddings: lm_head.weight shares with embed_tokens,
    # so always (re)point it at the embedding matrix when one is present.
    if "model.model.embed_tokens.weight" in new_state_dict:
        new_state_dict["model.lm_head.weight"] = new_state_dict["model.model.embed_tokens.weight"]
    return new_state_dict

View File

@@ -0,0 +1,39 @@
"""
State dict converter for ACE-Step Text Encoder (Qwen3-Embedding-0.6B).
The safetensors stores Qwen3Model weights with keys:
embed_tokens.weight
layers.0.self_attn.q_proj.weight
norm.weight
AceStepTextEncoder wraps a .model attribute (Qwen3Model), so its
state_dict() has keys with 'model.' prefix:
model.embed_tokens.weight
model.layers.0.self_attn.q_proj.weight
model.norm.weight
This converter adds 'model.' prefix to match the nested structure.
"""
def ace_step_text_encoder_converter(state_dict):
    """
    Convert ACE-Step Text Encoder checkpoint keys to the wrapped Qwen3Model layout.

    `state_dict` is a DiskMap-like mapping: iterating over it yields key
    names, and `state_dict[key]` fetches the actual tensor value.

    AceStepTextEncoder wraps a Qwen3Model as `self.model`, so the target
    state dict uses a single 'model.' prefix:
      embed_tokens.weight       -> model.embed_tokens.weight
      layers.0.self_attn.q_proj.weight -> model.layers.0.self_attn.q_proj.weight
      model.norm.weight         -> model.norm.weight (already correct, unchanged)
    """
    prefix = "model."
    new_state_dict = {}
    for key in state_dict:
        # Bug fix: keys already carrying the 'model.' prefix are in the
        # target format and must be kept as-is. The old code re-prefixed
        # them, producing 'model.model.*' keys that do not exist in the
        # wrapped Qwen3Model (the target has a single nesting level).
        new_key = key if key.startswith(prefix) else prefix + key
        new_state_dict[new_key] = state_dict[key]
    return new_state_dict

View File

@@ -0,0 +1,27 @@
"""
State dict converter for ACE-Step Tokenizer model.
The original checkpoint stores tokenizer and detokenizer weights at the top level:
- tokenizer.* (AceStepAudioTokenizer: audio_acoustic_proj, attention_pooler, quantizer)
- detokenizer.* (AudioTokenDetokenizer: embed_tokens, layers, proj_out)
These map directly to the AceStepTokenizer class which wraps both as
self.tokenizer and self.detokenizer submodules.
"""
def ace_step_tokenizer_converter(state_dict):
    """
    Convert ACE-Step Tokenizer checkpoint keys to DiffSynth format.

    The checkpoint keys `tokenizer.*` and `detokenizer.*` already match
    the DiffSynth AceStepTokenizer module structure (self.tokenizer,
    self.detokenizer), so no remapping is needed — only the relevant keys
    are extracted and everything else is dropped.
    """
    kept_prefixes = ("tokenizer.", "detokenizer.")
    return {
        key: state_dict[key]
        for key in state_dict
        if key.startswith(kept_prefixes)
    }