mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-24 15:06:17 +00:00
model-code
This commit is contained in:
@@ -0,0 +1,48 @@
|
||||
"""
|
||||
State dict converter for ACE-Step Conditioner model.
|
||||
|
||||
The original checkpoint stores all model weights in a single file
|
||||
(nested in AceStepConditionGenerationModel). The Conditioner weights are
|
||||
prefixed with 'encoder.'.
|
||||
|
||||
This converter extracts only keys starting with 'encoder.' and strips
|
||||
the prefix to match the standalone AceStepConditionEncoder in DiffSynth.
|
||||
"""
|
||||
|
||||
|
||||
def ace_step_conditioner_converter(state_dict):
    """
    Convert ACE-Step Conditioner checkpoint keys to DiffSynth format.

    ``state_dict`` is a DiskMap-like mapping: iterating it yields key
    names, and ``state_dict[key]`` fetches the actual tensor.

    The original checkpoint stores every sub-model in one file under a
    prefix (``decoder.`` for the DiT, ``encoder.`` for the Conditioner,
    ``tokenizer.`` / ``detokenizer.`` for the audio codec) plus a
    top-level ``null_condition_emb`` used as the CFG negative condition.
    Only the ``encoder.`` entries belong to the standalone
    AceStepConditionEncoder; this strips that prefix, e.g.

        encoder.lyric_encoder.layers.0.self_attn.q_proj.weight
            -> lyric_encoder.layers.0.self_attn.q_proj.weight
        encoder.timbre_encoder.layers.0.self_attn.q_proj.weight
            -> timbre_encoder.layers.0.self_attn.q_proj.weight
    """
    prefix = "encoder."
    converted = {
        key[len(prefix):]: state_dict[key]
        for key in state_dict
        if key.startswith(prefix)
    }
    # The CFG null embedding lives at the checkpoint's top level, outside
    # any sub-model prefix, and is carried over unchanged.
    if "null_condition_emb" in state_dict:
        converted["null_condition_emb"] = state_dict["null_condition_emb"]
    return converted
|
||||
43
diffsynth/utils/state_dict_converters/ace_step_dit.py
Normal file
43
diffsynth/utils/state_dict_converters/ace_step_dit.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
State dict converter for ACE-Step DiT model.
|
||||
|
||||
The original checkpoint stores all model weights in a single file
|
||||
(nested in AceStepConditionGenerationModel). The DiT weights are
|
||||
prefixed with 'decoder.'.
|
||||
|
||||
This converter extracts only keys starting with 'decoder.' and strips
|
||||
the prefix to match the standalone AceStepDiTModel in DiffSynth.
|
||||
"""
|
||||
|
||||
|
||||
def ace_step_dit_converter(state_dict):
    """
    Convert ACE-Step DiT checkpoint keys to DiffSynth format.

    ``state_dict`` is a DiskMap-like mapping: iterating it yields key
    names, and ``state_dict[key]`` fetches the actual tensor.

    The original checkpoint stores every sub-model in one file under a
    prefix (``decoder.`` for the DiT, ``encoder.`` for the Conditioner,
    ``tokenizer.`` / ``detokenizer.`` for the audio codec) plus a
    top-level ``null_condition_emb``.  Only the ``decoder.`` entries
    belong to the standalone AceStepDiTModel; this strips that prefix,
    e.g.

        decoder.layers.0.self_attn.q_proj.weight -> layers.0.self_attn.q_proj.weight
        decoder.time_embed.linear_1.weight       -> time_embed.linear_1.weight
        decoder.rotary_emb.inv_freq              -> rotary_emb.inv_freq
    """
    prefix = "decoder."
    return {
        key[len(prefix):]: state_dict[key]
        for key in state_dict
        if key.startswith(prefix)
    }
|
||||
55
diffsynth/utils/state_dict_converters/ace_step_lm.py
Normal file
55
diffsynth/utils/state_dict_converters/ace_step_lm.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""
|
||||
State dict converter for ACE-Step LLM (Qwen3-based).
|
||||
|
||||
The safetensors file stores Qwen3 model weights. Different checkpoints
|
||||
may have different key formats:
|
||||
- Qwen3ForCausalLM format: model.embed_tokens.weight, model.layers.0.*
|
||||
- Qwen3Model format: embed_tokens.weight, layers.0.*
|
||||
|
||||
Qwen3ForCausalLM wraps a .model attribute (Qwen3Model), so its
|
||||
state_dict() has keys:
|
||||
model.model.embed_tokens.weight
|
||||
model.model.layers.0.self_attn.q_proj.weight
|
||||
model.model.norm.weight
|
||||
model.lm_head.weight (tied to model.model.embed_tokens)
|
||||
|
||||
This converter normalizes all keys to the Qwen3ForCausalLM format.
|
||||
|
||||
Example mapping:
|
||||
model.embed_tokens.weight -> model.model.embed_tokens.weight
|
||||
embed_tokens.weight -> model.model.embed_tokens.weight
|
||||
model.layers.0.self_attn.q_proj.weight -> model.model.layers.0.self_attn.q_proj.weight
|
||||
layers.0.self_attn.q_proj.weight -> model.model.layers.0.self_attn.q_proj.weight
|
||||
model.norm.weight -> model.model.norm.weight
|
||||
norm.weight -> model.model.norm.weight
|
||||
"""
|
||||
|
||||
|
||||
def ace_step_lm_converter(state_dict):
    """
    Convert ACE-Step LLM checkpoint keys to match the Qwen3ForCausalLM
    wrapper's state dict.

    ``state_dict`` is a DiskMap-like mapping: iterating it yields key
    names, and ``state_dict[key]`` fetches the actual tensor.

    Checkpoints may use either the Qwen3ForCausalLM layout
    (``model.embed_tokens.weight``, ..., ``lm_head.weight``) or the bare
    Qwen3Model layout (``embed_tokens.weight``, ...).  Backbone keys are
    normalized to ``model.model.*`` and the LM head to ``model.lm_head.*``:

        model.embed_tokens.weight -> model.model.embed_tokens.weight
        embed_tokens.weight       -> model.model.embed_tokens.weight
        layers.0.self_attn.q_proj.weight -> model.model.layers.0.self_attn.q_proj.weight
        norm.weight               -> model.model.norm.weight
        lm_head.weight            -> model.lm_head.weight

    If the checkpoint has no LM head, it is tied to the token embedding
    (standard tied-word-embeddings behavior for Qwen3).
    """
    new_state_dict = {}

    for key in state_dict:
        if key.startswith("model.model.") or key.startswith("model.lm_head."):
            # Already fully qualified for the wrapper; keep as-is.
            new_key = key
        elif key.startswith("lm_head."):
            # The LM head hangs off the wrapper itself, NOT the inner
            # Qwen3Model; previously this key fell into the no-prefix
            # branch and became the dead key "model.model.lm_head.*".
            new_key = "model." + key
        elif key.startswith("model."):
            # Qwen3ForCausalLM backbone key: nest one level deeper.
            new_key = "model." + key
        else:
            # Bare Qwen3Model key: add both wrapper levels.
            new_key = "model.model." + key
        new_state_dict[new_key] = state_dict[key]

    # Tied word embeddings: materialize lm_head from embed_tokens only
    # when the checkpoint does not provide its own head, so an untied
    # head is never clobbered.
    if (
        "model.lm_head.weight" not in new_state_dict
        and "model.model.embed_tokens.weight" in new_state_dict
    ):
        new_state_dict["model.lm_head.weight"] = new_state_dict["model.model.embed_tokens.weight"]

    return new_state_dict
|
||||
@@ -0,0 +1,39 @@
|
||||
"""
|
||||
State dict converter for ACE-Step Text Encoder (Qwen3-Embedding-0.6B).
|
||||
|
||||
The safetensors stores Qwen3Model weights with keys:
|
||||
embed_tokens.weight
|
||||
layers.0.self_attn.q_proj.weight
|
||||
norm.weight
|
||||
|
||||
AceStepTextEncoder wraps a .model attribute (Qwen3Model), so its
|
||||
state_dict() has keys with 'model.' prefix:
|
||||
model.embed_tokens.weight
|
||||
model.layers.0.self_attn.q_proj.weight
|
||||
model.norm.weight
|
||||
|
||||
This converter adds 'model.' prefix to match the nested structure.
|
||||
"""
|
||||
|
||||
|
||||
def ace_step_text_encoder_converter(state_dict):
    """
    Convert ACE-Step Text Encoder checkpoint keys to the wrapped
    Qwen3Model state dict layout.

    ``state_dict`` is a DiskMap-like mapping: iterating it yields key
    names, and ``state_dict[key]`` fetches the actual tensor.

    The checkpoint stores bare Qwen3Model keys (``embed_tokens.weight``,
    ``layers.0.*``, ``norm.weight``).  AceStepTextEncoder holds the model
    as ``self.model``, so the target keys carry a single ``model.``
    prefix:

        embed_tokens.weight       -> model.embed_tokens.weight
        layers.0.self_attn.q_proj.weight -> model.layers.0.self_attn.q_proj.weight
        model.norm.weight         -> model.norm.weight  (already prefixed)
    """
    prefix = "model."
    new_state_dict = {}

    for key in state_dict:
        # Keys that already carry the wrapper prefix are taken verbatim.
        # Previously a second "model." was prepended to them, producing
        # "model.model.*" keys that never match the target module.
        new_key = key if key.startswith(prefix) else prefix + key
        new_state_dict[new_key] = state_dict[key]

    return new_state_dict
|
||||
27
diffsynth/utils/state_dict_converters/ace_step_tokenizer.py
Normal file
27
diffsynth/utils/state_dict_converters/ace_step_tokenizer.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""
|
||||
State dict converter for ACE-Step Tokenizer model.
|
||||
|
||||
The original checkpoint stores tokenizer and detokenizer weights at the top level:
|
||||
- tokenizer.* (AceStepAudioTokenizer: audio_acoustic_proj, attention_pooler, quantizer)
|
||||
- detokenizer.* (AudioTokenDetokenizer: embed_tokens, layers, proj_out)
|
||||
|
||||
These map directly to the AceStepTokenizer class which wraps both as
|
||||
self.tokenizer and self.detokenizer submodules.
|
||||
"""
|
||||
|
||||
|
||||
def ace_step_tokenizer_converter(state_dict):
    """
    Convert ACE-Step Tokenizer checkpoint keys to DiffSynth format.

    The checkpoint keys ``tokenizer.*`` and ``detokenizer.*`` already
    match the DiffSynth AceStepTokenizer module structure
    (``self.tokenizer`` / ``self.detokenizer``), so no remapping is
    needed — the relevant keys are simply extracted and everything else
    (``encoder.*``, ``decoder.*``, ``null_condition_emb``) is dropped.
    """
    wanted = ("tokenizer.", "detokenizer.")
    return {
        key: state_dict[key]
        for key in state_dict
        if key.startswith(wanted)
    }
|
||||
Reference in New Issue
Block a user