mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-02 15:48:11 +00:00
add acestep models
This commit is contained in:
15
diffsynth/utils/state_dict_converters/ace_step_dit.py
Normal file
15
diffsynth/utils/state_dict_converters/ace_step_dit.py
Normal file
@@ -0,0 +1,15 @@
|
||||
def AceStepDiTStateDictConverter(state_dict):
|
||||
"""
|
||||
Convert ACE-Step DiT state dict to add 'model.' prefix for wrapper class.
|
||||
|
||||
The wrapper class has self.model = AceStepConditionGenerationModel(config),
|
||||
so all keys need to be prefixed with 'model.'
|
||||
"""
|
||||
state_dict_ = {}
|
||||
keys = state_dict.keys() if hasattr(state_dict, 'keys') else state_dict
|
||||
for k in keys:
|
||||
v = state_dict[k]
|
||||
if not k.startswith("model."):
|
||||
k = "model." + k
|
||||
state_dict_[k] = v
|
||||
return state_dict_
|
||||
@@ -0,0 +1,19 @@
|
||||
def AceStepTextEncoderStateDictConverter(state_dict):
|
||||
"""
|
||||
将 ACE-Step Text Encoder 权重添加 model. 前缀
|
||||
|
||||
Args:
|
||||
state_dict: 原始的 state dict(可能是 dict 或 DiskMap)
|
||||
|
||||
Returns:
|
||||
转换后的 state dict,所有 key 添加 "model." 前缀
|
||||
"""
|
||||
state_dict_ = {}
|
||||
# 处理 DiskMap 或普通 dict
|
||||
keys = state_dict.keys() if hasattr(state_dict, 'keys') else state_dict
|
||||
for k in keys:
|
||||
v = state_dict[k]
|
||||
if not k.startswith("model."):
|
||||
k = "model." + k
|
||||
state_dict_[k] = v
|
||||
return state_dict_
|
||||
Reference in New Issue
Block a user